From d1559317d18b64629fb68a052c31486fcff598df Mon Sep 17 00:00:00 2001 From: wls2002 Date: Mon, 20 May 2024 20:45:48 +0800 Subject: [PATCH 1/4] update readme.md for environment configuration --- README.md | 13 ++++++----- recommend_environment.txt | 9 ++++++++ tensorneat/examples/brax/walker.py | 36 ++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 recommend_environment.txt create mode 100644 tensorneat/examples/brax/walker.py diff --git a/README.md b/README.md index 1eac198..54002dd 100644 --- a/README.md +++ b/README.md @@ -23,12 +23,15 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framewrok. ## Requirements -TensorNEAT requires: -- jax (version >= 0.4.16) -- jaxlib (version >= 0.3.0) -- brax [optional] -- gymnax [optional] +Due to the rapid iteration of JAX versions, configuring the runtime environment for tensorNEAT can be challenging. We recommend the following versions for the relevant libraries: + +- jax (0.4.28) +- jaxlib (0.4.28+cuda12.cudnn89) +- brax (0.10.3) +- gymnax (0.0.8) +We provide detailed JAX-related environment references in [recommend_environment](recommend_environment.txt). If you encounter any issues while configuring the environment yourself, you can use this as a reference. + ## Example Simple Example for XOR problem: ```python diff --git a/recommend_environment.txt b/recommend_environment.txt new file mode 100644 index 0000000..449aae4 --- /dev/null +++ b/recommend_environment.txt @@ -0,0 +1,9 @@ +brax==0.10.3 +flax==0.8.4 +gymnax==0.0.8 +jax==0.4.28 +jaxlib==0.4.28+cuda12.cudnn89 +jaxopt==0.8.3 +mujoco==3.1.4 +mujoco-mjx==3.1.4 +optax==0.2.2 \ No newline at end of file diff --git a/tensorneat/examples/brax/walker.py b/tensorneat/examples/brax/walker.py new file mode 100644 index 0000000..4cdd645 --- /dev/null +++ b/tensorneat/examples/brax/walker.py @@ -0,0 +1,36 @@ +from pipeline import Pipeline +from algorithm.neat import * + +from problem.rl_env import BraxEnv +from utils import Act + +if __name__ == '__main__': + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=17, + num_outputs=6, + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene( + activation_options=(Act.tanh,), + activation_default=Act.tanh, + ) + ), + pop_size=10, + species_size=10, + ), + ), + problem=BraxEnv( + env_name='walker2d', + ), + generation_limit=10000, + fitness_target=5000 + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file From 33d7821c216a011e001ec68903fe86e5cf6d0fe8 Mon Sep 17 00:00:00 2001 From: WLS2002 <64534280+WLS2002@users.noreply.github.com> Date: Tue, 21 May 2024 10:16:31 +0800 Subject: [PATCH 2/4] Change tensorNEAT to TensorNEAT --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 54002dd..3e527e2 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framewrok. ## Requirements -Due to the rapid iteration of JAX versions, configuring the runtime environment for tensorNEAT can be challenging. We recommend the following versions for the relevant libraries: +Due to the rapid iteration of JAX versions, configuring the runtime environment for TensorNEAT can be challenging. We recommend the following versions for the relevant libraries: - jax (0.4.28) - jaxlib (0.4.28+cuda12.cudnn89) From 6a37563696b065396fa996aaa2725be6859cad53 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 22 May 2024 10:27:32 +0800 Subject: [PATCH 3/4] fix bug in crossover: the child from two normal networks should always be normal. --- tensorneat/.idea/.gitignore | 8 +++++ tensorneat/algorithm/neat/__init__.py | 1 + .../algorithm/neat/ga/crossover/base.py | 2 +- .../algorithm/neat/ga/crossover/default.py | 31 ++++++++++--------- .../algorithm/neat/ga/mutation/default.py | 4 +-- tensorneat/algorithm/neat/gene/node/base.py | 2 +- .../algorithm/neat/gene/node/default.py | 7 +++-- tensorneat/algorithm/neat/genome/default.py | 20 ++++-------- tensorneat/examples/brax/walker.py | 2 +- tensorneat/pipeline.py | 11 +++---- tensorneat/problem/rl_env/rl_jit.py | 1 - 11 files changed, 46 insertions(+), 43 deletions(-) create mode 100644 tensorneat/.idea/.gitignore diff --git a/tensorneat/.idea/.gitignore b/tensorneat/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/tensorneat/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/tensorneat/algorithm/neat/__init__.py b/tensorneat/algorithm/neat/__init__.py index 97185ca..e64148a 100644 --- a/tensorneat/algorithm/neat/__init__.py +++ b/tensorneat/algorithm/neat/__init__.py @@ -1,3 +1,4 @@ +from .ga import * from .gene import * from .genome import * from .species import * diff --git a/tensorneat/algorithm/neat/ga/crossover/base.py b/tensorneat/algorithm/neat/ga/crossover/base.py index 9f638a2..9f84ac9 100644 --- a/tensorneat/algorithm/neat/ga/crossover/base.py +++ b/tensorneat/algorithm/neat/ga/crossover/base.py @@ -1,3 +1,3 @@ class BaseCrossover: def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/crossover/default.py b/tensorneat/algorithm/neat/ga/crossover/default.py index 4d01b41..deaf745 100644 --- a/tensorneat/algorithm/neat/ga/crossover/default.py +++ b/tensorneat/algorithm/neat/ga/crossover/default.py @@ -2,6 +2,7 @@ import jax, jax.numpy as jnp from .base import BaseCrossover + class DefaultCrossover(BaseCrossover): def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2): @@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover): # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] # make homologous genes align in nodes2 align with nodes1 - nodes2 = self.align_array(keys1, keys2, nodes2, False) + nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False) # For not homologous genes, use the value of nodes1(winner) # For homologous genes, use the crossover result between nodes1 and nodes2 - new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2)) + new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, + self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False)) # crossover connections con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] - conns2 = self.align_array(con_keys1, con_keys2, conns2, True) + conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True) - new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2)) + new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, + self.crossover_gene(randkey_2, conns1, conns2, is_conn=True)) return new_nodes, new_conns @@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover): return refactor_ar2 - def crossover_gene(self, rand_key, g1, g2): - """ - crossover two genes - :param rand_key: - :param g1: - :param g2: - :return: - only gene with the same key will be crossover, thus don't need to consider change key - """ + def crossover_gene(self, rand_key, g1, g2, is_conn): r = jax.random.uniform(rand_key, shape=g1.shape) - return jnp.where(r > 0.5, g1, g2) + new_gene = jnp.where(r > 0.5, g1, g2) + if is_conn: # fix enabled + enabled = jnp.where( + g1[:, 2] + g2[:, 2] > 0, # any of them is enabled + 1, + 0 + ) + new_gene = new_gene.at[:, 2].set(enabled) + return new_gene diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py index 33ba6fe..331dbf3 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/ga/mutation/default.py @@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation): nodes_keys = jax.random.split(k1, num=nodes.shape[0]) conns_keys = jax.random.split(k2, num=conns.shape[0]) - new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes) - new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns) + new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes) + new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns) # nan nodes not changed new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) diff --git a/tensorneat/algorithm/neat/gene/node/base.py b/tensorneat/algorithm/neat/gene/node/base.py index 465050a..2ebfd1b 100644 --- a/tensorneat/algorithm/neat/gene/node/base.py +++ b/tensorneat/algorithm/neat/gene/node/base.py @@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene): def __init__(self): super().__init__() - def forward(self, attrs, inputs): + def forward(self, attrs, inputs, is_output_node=False): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 3f43d4f..066ec5e 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -85,11 +85,14 @@ class DefaultNodeGene(BaseNodeGene): (node1[4] != node2[4]) ) - def forward(self, attrs, inputs): + def forward(self, attrs, inputs, is_output_node=False): bias, res, act_idx, agg_idx = attrs z = agg(agg_idx, inputs, self.aggregation_options) z = bias + res * z - z = act(act_idx, z, self.activation_options) + + # the last output node should not be activated + if not is_output_node: + z = act(act_idx, z, self.activation_options) return z diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 768145d..fcc273c 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome): if output_transform is not None: try: - aux = output_transform(jnp.zeros(num_outputs)) + _ = output_transform(jnp.zeros(num_outputs)) except Exception as e: raise ValueError(f"Output transform function failed: {e}") self.output_transform = output_transform def transform(self, nodes, conns): u_conns = unflatten_conns(nodes, conns) - - # DONE: Seems like there is a bug in this line - # conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) - # modified: exist conn and enable is true - # conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False) - # advanced modified: when and only when enabled is True conn_enable = u_conns[0] == 1 # remove enable attr @@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome): def hit(): ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values) - # ins = values * weights[:, i] - z = self.node_gene.forward(nodes_attrs[i], ins) - # z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins) - # z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias - # z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z) - new_values = values.at[i].set(z) return new_values @@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome): return values # the val of input nodes is obtained by the task, not by calculation - values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit) + values = jax.lax.cond( + jnp.isin(i, self.input_idx), + miss, + hit + ) return values, idx + 1 diff --git a/tensorneat/examples/brax/walker.py b/tensorneat/examples/brax/walker.py index 4cdd645..993567d 100644 --- a/tensorneat/examples/brax/walker.py +++ b/tensorneat/examples/brax/walker.py @@ -18,7 +18,7 @@ if __name__ == '__main__': activation_default=Act.tanh, ) ), - pop_size=10, + pop_size=10000, species_size=10, ), ), diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index ad33945..4d34322 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -69,7 +69,7 @@ class Pipeline: pop_transformed ) - fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses) + # fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses) alg_state = self.algorithm.tell(state.alg, fitnesses) @@ -80,8 +80,10 @@ class Pipeline: def auto_run(self, ini_state): state = ini_state + print("start compile") + tic = time.time() compiled_step = jax.jit(self.step).lower(ini_state).compile() - + print(f"compile finished, cost time: {time.time() - tic:.6f}s", ) for _ in range(self.generation_limit): self.generation_timestamp = time.time() @@ -91,11 +93,6 @@ class Pipeline: state, fitnesses = compiled_step(state) fitnesses = jax.device_get(fitnesses) - for idx, fitnesses_i in enumerate(fitnesses): - if np.isnan(fitnesses_i): - print("Fitness is nan") - print(previous_pop[0][idx], previous_pop[1][idx]) - assert False self.analysis(state, previous_pop, fitnesses) diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index 128ebfb..b0e5e76 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl_env/rl_jit.py @@ -8,7 +8,6 @@ from .. import BaseProblem class RLEnv(BaseProblem): jitable = True - # TODO: move output transform to algorithm def __init__(self): super().__init__() From 1fe5d5fca27042e0809462b9c0b5365afe37e87a Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 22 May 2024 11:09:25 +0800 Subject: [PATCH 4/4] disable activation in the output node of network; we recommend to use output_transform; change hyperparameters (strong) in XOR example; --- tensorneat/algorithm/hyperneat/hyperneat.py | 14 ++++++++++---- tensorneat/algorithm/neat/gene/node/default.py | 7 +++++-- tensorneat/algorithm/neat/genome/default.py | 2 +- tensorneat/algorithm/neat/genome/recurrent.py | 16 +++++++++++++++- tensorneat/examples/func_fit/xor.py | 17 +++++++++++++++-- tensorneat/examples/func_fit/xor3d_hyperneat.py | 5 ++++- tensorneat/examples/func_fit/xor_recurrent.py | 12 +++++++++--- 7 files changed, 59 insertions(+), 14 deletions(-) diff --git a/tensorneat/algorithm/hyperneat/hyperneat.py b/tensorneat/algorithm/hyperneat/hyperneat.py index eab7693..302a3c6 100644 --- a/tensorneat/algorithm/hyperneat/hyperneat.py +++ b/tensorneat/algorithm/hyperneat/hyperneat.py @@ -1,3 +1,5 @@ +from typing import Callable + import jax, jax.numpy as jnp from utils import State, Act, Agg @@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm): activation=Act.sigmoid, aggregation=Agg.sum, activate_time: int = 10, + output_transform: Callable = Act.sigmoid, ): assert substrate.query_coors.shape[1] == neat.num_inputs, \ "Substrate input size should be equal to NEAT input size" @@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm): node_gene=HyperNodeGene(activation, aggregation), conn_gene=HyperNEATConnGene(), activate_time=activate_time, + output_transform=output_transform ) def setup(self, randkey): @@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene): self.activation = activation self.aggregation = aggregation - def forward(self, attrs, inputs): - return self.activation( - self.aggregation(inputs) - ) + def forward(self, attrs, inputs, is_output_node=False): + return jax.lax.cond( + is_output_node, + lambda: self.aggregation(inputs), # output node does not need activation + lambda: self.activation(self.aggregation(inputs)) + ) class HyperNEATConnGene(BaseConnGene): custom_attrs = ['weight'] diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 066ec5e..0b90f2c 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -92,7 +92,10 @@ class DefaultNodeGene(BaseNodeGene): z = bias + res * z # the last output node should not be activated - if not is_output_node: - z = act(act_idx, z, self.activation_options) + z = jax.lax.cond( + is_output_node, + lambda: z, + lambda: act(act_idx, z, self.activation_options) + ) return z diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index fcc273c..751b52c 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -58,7 +58,7 @@ class DefaultGenome(BaseGenome): def hit(): ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values) - z = self.node_gene.forward(nodes_attrs[i], ins) + z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx)) new_values = values.at[i].set(z) return new_values diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 2f4e630..5ed4737 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -1,3 +1,5 @@ +from typing import Callable + import jax, jax.numpy as jnp from utils import unflatten_conns @@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome): node_gene: BaseNodeGene = DefaultNodeGene(), conn_gene: BaseConnGene = DefaultConnGene(), activate_time: int = 10, + output_transform: Callable = None ): super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene) self.activate_time = activate_time + if output_transform is not None: + try: + _ = output_transform(jnp.zeros(num_outputs)) + except Exception as e: + raise ValueError(f"Output transform function failed: {e}") + self.output_transform = output_transform + def transform(self, nodes, conns): u_conns = unflatten_conns(nodes, conns) @@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome): )(conns, values) # calculate nodes - values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T) + is_output_nodes = jnp.isin( + jnp.arange(N), + self.output_idx + ) + values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes) return values vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals) diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index 94be087..89c8a7d 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -2,6 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.func_fit import XOR3d +from utils import Act if __name__ == '__main__': pipeline = Pipeline( @@ -10,13 +11,25 @@ if __name__ == '__main__': genome=DefaultGenome( num_inputs=3, num_outputs=1, - max_nodes=50, - max_conns=100, + max_nodes=100, + max_conns=200, + node_gene=DefaultNodeGene( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + output_transform=Act.sigmoid, # the activation function for output node ), pop_size=10000, species_size=10, compatibility_threshold=3.5, + survival_threshold=0.01, # magic ), + mutation=DefaultMutation( + node_add=0.05, + conn_add=0.2, + node_delete=0, + conn_delete=0, + ) ), problem=XOR3d(), generation_limit=10000, diff --git a/tensorneat/examples/func_fit/xor3d_hyperneat.py b/tensorneat/examples/func_fit/xor3d_hyperneat.py index 6f0e7e2..412066b 100644 --- a/tensorneat/examples/func_fit/xor3d_hyperneat.py +++ b/tensorneat/examples/func_fit/xor3d_hyperneat.py @@ -28,14 +28,17 @@ if __name__ == '__main__': activation_default=Act.tanh, activation_options=(Act.tanh,), ), + output_transform=Act.tanh, # the activation function for output node in NEAT ), pop_size=10000, species_size=10, compatibility_threshold=3.5, + survival_threshold=0.03, ), ), - activation=Act.sigmoid, + activation=Act.tanh, activate_time=10, + output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT ), problem=XOR3d(), generation_limit=300, diff --git a/tensorneat/examples/func_fit/xor_recurrent.py b/tensorneat/examples/func_fit/xor_recurrent.py index 1e1f3bd..41aad28 100644 --- a/tensorneat/examples/func_fit/xor_recurrent.py +++ b/tensorneat/examples/func_fit/xor_recurrent.py @@ -2,8 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.func_fit import XOR3d -from utils.activation import ACT_ALL -from utils.aggregation import AGG_ALL +from utils.activation import ACT_ALL, Act if __name__ == '__main__': pipeline = Pipeline( @@ -18,14 +17,21 @@ if __name__ == '__main__': activate_time=5, node_gene=DefaultNodeGene( activation_options=ACT_ALL, - # aggregation_options=AGG_ALL, activation_replace_rate=0.2 ), + output_transform=Act.sigmoid ), pop_size=10000, species_size=10, compatibility_threshold=3.5, + survival_threshold=0.03, ), + mutation=DefaultMutation( + node_add=0.05, + conn_add=0.2, + node_delete=0, + conn_delete=0, + ) ), problem=XOR3d(), generation_limit=10000,