disable activation in the output node of network;

we recommend to use output_transform;
change hyperparameters (strong) in XOR example;
This commit is contained in:
wls2002
2024-05-22 11:09:25 +08:00
parent bb80f12640
commit 1fe5d5fca2
7 changed files with 59 additions and 14 deletions

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import unflatten_conns
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10,
output_transform: Callable = None
):
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time
if output_transform is not None:
try:
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
)(conns, values)
# calculate nodes
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
is_output_nodes = jnp.isin(
jnp.arange(N),
self.output_idx
)
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
return values
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)