diff --git a/tensorneat/algorithm/hyperneat/hyperneat.py b/tensorneat/algorithm/hyperneat/hyperneat.py index eab7693..302a3c6 100644 --- a/tensorneat/algorithm/hyperneat/hyperneat.py +++ b/tensorneat/algorithm/hyperneat/hyperneat.py @@ -1,3 +1,5 @@ +from typing import Callable + import jax, jax.numpy as jnp from utils import State, Act, Agg @@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm): activation=Act.sigmoid, aggregation=Agg.sum, activate_time: int = 10, + output_transform: Callable = Act.sigmoid, ): assert substrate.query_coors.shape[1] == neat.num_inputs, \ "Substrate input size should be equal to NEAT input size" @@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm): node_gene=HyperNodeGene(activation, aggregation), conn_gene=HyperNEATConnGene(), activate_time=activate_time, + output_transform=output_transform ) def setup(self, randkey): @@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene): self.activation = activation self.aggregation = aggregation - def forward(self, attrs, inputs): - return self.activation( - self.aggregation(inputs) - ) + def forward(self, attrs, inputs, is_output_node=False): + return jax.lax.cond( + is_output_node, + lambda: self.aggregation(inputs), # output node does not need activation + lambda: self.activation(self.aggregation(inputs)) + ) class HyperNEATConnGene(BaseConnGene): custom_attrs = ['weight'] diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 066ec5e..0b90f2c 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -92,7 +92,10 @@ class DefaultNodeGene(BaseNodeGene): z = bias + res * z # the last output node should not be activated - if not is_output_node: - z = act(act_idx, z, self.activation_options) + z = jax.lax.cond( + is_output_node, + lambda: z, + lambda: act(act_idx, z, self.activation_options) + ) return z diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index fcc273c..751b52c 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -58,7 +58,7 @@ class DefaultGenome(BaseGenome): def hit(): ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values) - z = self.node_gene.forward(nodes_attrs[i], ins) + z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx)) new_values = values.at[i].set(z) return new_values diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 2f4e630..5ed4737 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -1,3 +1,5 @@ +from typing import Callable + import jax, jax.numpy as jnp from utils import unflatten_conns @@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome): node_gene: BaseNodeGene = DefaultNodeGene(), conn_gene: BaseConnGene = DefaultConnGene(), activate_time: int = 10, + output_transform: Callable = None ): super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene) self.activate_time = activate_time + if output_transform is not None: + try: + _ = output_transform(jnp.zeros(num_outputs)) + except Exception as e: + raise ValueError(f"Output transform function failed: {e}") + self.output_transform = output_transform + def transform(self, nodes, conns): u_conns = unflatten_conns(nodes, conns) @@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome): )(conns, values) # calculate nodes - values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T) + is_output_nodes = jnp.isin( + jnp.arange(N), + self.output_idx + ) + values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes) return values vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals) diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index 94be087..89c8a7d 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -2,6 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.func_fit import XOR3d +from utils import Act if __name__ == '__main__': pipeline = Pipeline( @@ -10,13 +11,25 @@ if __name__ == '__main__': genome=DefaultGenome( num_inputs=3, num_outputs=1, - max_nodes=50, - max_conns=100, + max_nodes=100, + max_conns=200, + node_gene=DefaultNodeGene( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + output_transform=Act.sigmoid, # the activation function for output node ), pop_size=10000, species_size=10, compatibility_threshold=3.5, + survival_threshold=0.01, # magic ), + mutation=DefaultMutation( + node_add=0.05, + conn_add=0.2, + node_delete=0, + conn_delete=0, + ) ), problem=XOR3d(), generation_limit=10000, diff --git a/tensorneat/examples/func_fit/xor3d_hyperneat.py b/tensorneat/examples/func_fit/xor3d_hyperneat.py index 6f0e7e2..412066b 100644 --- a/tensorneat/examples/func_fit/xor3d_hyperneat.py +++ b/tensorneat/examples/func_fit/xor3d_hyperneat.py @@ -28,14 +28,17 @@ if __name__ == '__main__': activation_default=Act.tanh, activation_options=(Act.tanh,), ), + output_transform=Act.tanh, # the activation function for output node in NEAT ), pop_size=10000, species_size=10, compatibility_threshold=3.5, + survival_threshold=0.03, ), ), - activation=Act.sigmoid, + activation=Act.tanh, activate_time=10, + output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT ), problem=XOR3d(), generation_limit=300, diff --git a/tensorneat/examples/func_fit/xor_recurrent.py b/tensorneat/examples/func_fit/xor_recurrent.py index 1e1f3bd..41aad28 100644 --- a/tensorneat/examples/func_fit/xor_recurrent.py +++ b/tensorneat/examples/func_fit/xor_recurrent.py @@ -2,8 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.func_fit import XOR3d -from utils.activation import ACT_ALL -from utils.aggregation import AGG_ALL +from utils.activation import ACT_ALL, Act if __name__ == '__main__': pipeline = Pipeline( @@ -18,14 +17,21 @@ if __name__ == '__main__': activate_time=5, node_gene=DefaultNodeGene( activation_options=ACT_ALL, - # aggregation_options=AGG_ALL, activation_replace_rate=0.2 ), + output_transform=Act.sigmoid ), pop_size=10000, species_size=10, compatibility_threshold=3.5, + survival_threshold=0.03, ), + mutation=DefaultMutation( + node_add=0.05, + conn_add=0.2, + node_delete=0, + conn_delete=0, + ) ), problem=XOR3d(), generation_limit=10000,