disable activation in the output node of the network;
we recommend using output_transform; change hyperparameters (strongly) in the XOR example;
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import State, Act, Agg
|
||||
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = Act.sigmoid,
|
||||
):
|
||||
assert substrate.query_coors.shape[1] == neat.num_inputs, \
|
||||
"Substrate input size should be equal to NEAT input size"
|
||||
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
node_gene=HyperNodeGene(activation, aggregation),
|
||||
conn_gene=HyperNEATConnGene(),
|
||||
activate_time=activate_time,
|
||||
output_transform=output_transform
|
||||
)
|
||||
|
||||
def setup(self, randkey):
|
||||
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
|
||||
self.activation = activation
|
||||
self.aggregation = aggregation
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
return self.activation(
|
||||
self.aggregation(inputs)
|
||||
)
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
return jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: self.aggregation(inputs), # output node does not need activation
|
||||
lambda: self.activation(self.aggregation(inputs))
|
||||
|
||||
)
|
||||
|
||||
class HyperNEATConnGene(BaseConnGene):
|
||||
custom_attrs = ['weight']
|
||||
|
||||
@@ -92,7 +92,10 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
z = bias + res * z
|
||||
|
||||
# the last output node should not be activated
|
||||
if not is_output_node:
|
||||
z = act(act_idx, z, self.activation_options)
|
||||
z = jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: z,
|
||||
lambda: act(act_idx, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
|
||||
@@ -58,7 +58,7 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
def hit():
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
|
||||
z = self.node_gene.forward(nodes_attrs[i], ins)
|
||||
z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns
|
||||
|
||||
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = None
|
||||
):
|
||||
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||
self.activate_time = activate_time
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
|
||||
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
|
||||
)(conns, values)
|
||||
|
||||
# calculate nodes
|
||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
|
||||
is_output_nodes = jnp.isin(
|
||||
jnp.arange(N),
|
||||
self.output_idx
|
||||
)
|
||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
||||
|
||||
Reference in New Issue
Block a user