diff --git a/README.md b/README.md index 1eac198..3e527e2 100644 --- a/README.md +++ b/README.md @@ -23,12 +23,15 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framewrok. ## Requirements -TensorNEAT requires: -- jax (version >= 0.4.16) -- jaxlib (version >= 0.3.0) -- brax [optional] -- gymnax [optional] +Due to the rapid iteration of JAX versions, configuring the runtime environment for TensorNEAT can be challenging. We recommend the following versions for the relevant libraries: + +- jax (0.4.28) +- jaxlib (0.4.28+cuda12.cudnn89) +- brax (0.10.3) +- gymnax (0.0.8) +We provide detailed JAX-related environment references in [recommend_environment](recommend_environment.txt). If you encounter any issues while configuring the environment yourself, you can use this as a reference. + ## Example Simple Example for XOR problem: ```python diff --git a/recommend_environment.txt b/recommend_environment.txt new file mode 100644 index 0000000..449aae4 --- /dev/null +++ b/recommend_environment.txt @@ -0,0 +1,9 @@ +brax==0.10.3 +flax==0.8.4 +gymnax==0.0.8 +jax==0.4.28 +jaxlib==0.4.28+cuda12.cudnn89 +jaxopt==0.8.3 +mujoco==3.1.4 +mujoco-mjx==3.1.4 +optax==0.2.2 \ No newline at end of file diff --git a/tensorneat/.idea/.gitignore b/tensorneat/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/tensorneat/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/tensorneat/algorithm/hyperneat/hyperneat.py b/tensorneat/algorithm/hyperneat/hyperneat.py index eab7693..302a3c6 100644 --- a/tensorneat/algorithm/hyperneat/hyperneat.py +++ b/tensorneat/algorithm/hyperneat/hyperneat.py @@ -1,3 +1,5 @@ +from typing import Callable + import jax, jax.numpy as jnp from utils import State, Act, Agg @@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm): activation=Act.sigmoid, aggregation=Agg.sum, activate_time: int = 10, + output_transform: Callable = Act.sigmoid, ): assert substrate.query_coors.shape[1] == neat.num_inputs, \ "Substrate input size should be equal to NEAT input size" @@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm): node_gene=HyperNodeGene(activation, aggregation), conn_gene=HyperNEATConnGene(), activate_time=activate_time, + output_transform=output_transform ) def setup(self, randkey): @@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene): self.activation = activation self.aggregation = aggregation - def forward(self, attrs, inputs): - return self.activation( - self.aggregation(inputs) - ) + def forward(self, attrs, inputs, is_output_node=False): + return jax.lax.cond( + is_output_node, + lambda: self.aggregation(inputs), # output node does not need activation + lambda: self.activation(self.aggregation(inputs)) + ) class HyperNEATConnGene(BaseConnGene): custom_attrs = ['weight'] diff --git a/tensorneat/algorithm/neat/ga/crossover/base.py b/tensorneat/algorithm/neat/ga/crossover/base.py index 9f638a2..9f84ac9 100644 --- a/tensorneat/algorithm/neat/ga/crossover/base.py +++ b/tensorneat/algorithm/neat/ga/crossover/base.py @@ -1,3 +1,3 @@ class BaseCrossover: def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/crossover/default.py b/tensorneat/algorithm/neat/ga/crossover/default.py index 4d01b41..deaf745 100644 --- a/tensorneat/algorithm/neat/ga/crossover/default.py +++ b/tensorneat/algorithm/neat/ga/crossover/default.py @@ -2,6 +2,7 @@ import jax, jax.numpy as jnp from .base import BaseCrossover + class DefaultCrossover(BaseCrossover): def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2): @@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover): # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] # make homologous genes align in nodes2 align with nodes1 - nodes2 = self.align_array(keys1, keys2, nodes2, False) + nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False) # For not homologous genes, use the value of nodes1(winner) # For homologous genes, use the crossover result between nodes1 and nodes2 - new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2)) + new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, + self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False)) # crossover connections con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] - conns2 = self.align_array(con_keys1, con_keys2, conns2, True) + conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True) - new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2)) + new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, + self.crossover_gene(randkey_2, conns1, conns2, is_conn=True)) return new_nodes, new_conns @@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover): return refactor_ar2 - def crossover_gene(self, rand_key, g1, g2): - """ - crossover two genes - :param rand_key: - :param g1: - :param g2: - :return: - only gene with the same key will be crossover, thus don't need to consider change key - """ + def crossover_gene(self, rand_key, g1, g2, is_conn): r = jax.random.uniform(rand_key, shape=g1.shape) - return jnp.where(r > 0.5, g1, g2) + new_gene = jnp.where(r > 0.5, g1, g2) + if is_conn: # fix enabled + enabled = jnp.where( + g1[:, 2] + g2[:, 2] > 0, # any of them is enabled + 1, + 0 + ) + new_gene = new_gene.at[:, 2].set(enabled) + return new_gene diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py index 33ba6fe..331dbf3 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/ga/mutation/default.py @@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation): nodes_keys = jax.random.split(k1, num=nodes.shape[0]) conns_keys = jax.random.split(k2, num=conns.shape[0]) - new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes) - new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns) + new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes) + new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns) # nan nodes not changed new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) diff --git a/tensorneat/algorithm/neat/gene/node/base.py b/tensorneat/algorithm/neat/gene/node/base.py index 465050a..2ebfd1b 100644 --- a/tensorneat/algorithm/neat/gene/node/base.py +++ b/tensorneat/algorithm/neat/gene/node/base.py @@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene): def __init__(self): super().__init__() - def forward(self, attrs, inputs): + def forward(self, attrs, inputs, is_output_node=False): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index a5127e0..c06e7cf 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -95,11 +95,17 @@ class DefaultNodeGene(BaseNodeGene): (node1[4] != node2[4]) ) - def forward(self, attrs, inputs): + def forward(self, attrs, inputs, is_output_node=False): bias, res, act_idx, agg_idx = attrs z = agg(agg_idx, inputs, self.aggregation_options) z = bias + res * z - z = act(act_idx, z, self.activation_options) + + # the last output node should not be activated + z = jax.lax.cond( + is_output_node, + lambda: z, + lambda: act(act_idx, z, self.activation_options) + ) return z diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 768145d..751b52c 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome): if output_transform is not None: try: - aux = output_transform(jnp.zeros(num_outputs)) + _ = output_transform(jnp.zeros(num_outputs)) except Exception as e: raise ValueError(f"Output transform function failed: {e}") self.output_transform = output_transform def transform(self, nodes, conns): u_conns = unflatten_conns(nodes, conns) - - # DONE: Seems like there is a bug in this line - # conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) - # modified: exist conn and enable is true - # conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False) - # advanced modified: when and only when enabled is True conn_enable = u_conns[0] == 1 # remove enable attr @@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome): def hit(): ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values) - # ins = values * weights[:, i] - - z = self.node_gene.forward(nodes_attrs[i], ins) - # z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins) - # z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias - # z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z) - + z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx)) new_values = values.at[i].set(z) return new_values @@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome): return values # the val of input nodes is obtained by the task, not by calculation - values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit) + values = jax.lax.cond( + jnp.isin(i, self.input_idx), + miss, + hit + ) return values, idx + 1 diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 2f4e630..5ed4737 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -1,3 +1,5 @@ +from typing import Callable + import jax, jax.numpy as jnp from utils import unflatten_conns @@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome): node_gene: BaseNodeGene = DefaultNodeGene(), conn_gene: BaseConnGene = DefaultConnGene(), activate_time: int = 10, + output_transform: Callable = None ): super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene) self.activate_time = activate_time + if output_transform is not None: + try: + _ = output_transform(jnp.zeros(num_outputs)) + except Exception as e: + raise ValueError(f"Output transform function failed: {e}") + self.output_transform = output_transform + def transform(self, nodes, conns): u_conns = unflatten_conns(nodes, conns) @@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome): )(conns, values) # calculate nodes - values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T) + is_output_nodes = jnp.isin( + jnp.arange(N), + self.output_idx + ) + values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes) return values vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals) diff --git a/tensorneat/examples/brax/walker.py b/tensorneat/examples/brax/walker.py new file mode 100644 index 0000000..993567d --- /dev/null +++ b/tensorneat/examples/brax/walker.py @@ -0,0 +1,36 @@ +from pipeline import Pipeline +from algorithm.neat import * + +from problem.rl_env import BraxEnv +from utils import Act + +if __name__ == '__main__': + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=17, + num_outputs=6, + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene( + activation_options=(Act.tanh,), + activation_default=Act.tanh, + ) + ), + pop_size=10000, + species_size=10, + ), + ), + problem=BraxEnv( + env_name='walker2d', + ), + generation_limit=10000, + fitness_target=5000 + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index 94be087..89c8a7d 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -2,6 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.func_fit import XOR3d +from utils import Act if __name__ == '__main__': pipeline = Pipeline( @@ -10,13 +11,25 @@ if __name__ == '__main__': genome=DefaultGenome( num_inputs=3, num_outputs=1, - max_nodes=50, - max_conns=100, + max_nodes=100, + max_conns=200, + node_gene=DefaultNodeGene( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + output_transform=Act.sigmoid, # the activation function for output node ), pop_size=10000, species_size=10, compatibility_threshold=3.5, + survival_threshold=0.01, # magic ), + mutation=DefaultMutation( + node_add=0.05, + conn_add=0.2, + node_delete=0, + conn_delete=0, + ) ), problem=XOR3d(), generation_limit=10000, diff --git a/tensorneat/examples/func_fit/xor3d_hyperneat.py b/tensorneat/examples/func_fit/xor3d_hyperneat.py index 6f0e7e2..412066b 100644 --- a/tensorneat/examples/func_fit/xor3d_hyperneat.py +++ b/tensorneat/examples/func_fit/xor3d_hyperneat.py @@ -28,14 +28,17 @@ if __name__ == '__main__': activation_default=Act.tanh, activation_options=(Act.tanh,), ), + output_transform=Act.tanh, # the activation function for output node in NEAT ), pop_size=10000, species_size=10, compatibility_threshold=3.5, + survival_threshold=0.03, ), ), - activation=Act.sigmoid, + activation=Act.tanh, activate_time=10, + output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT ), problem=XOR3d(), generation_limit=300, diff --git a/tensorneat/examples/func_fit/xor_recurrent.py b/tensorneat/examples/func_fit/xor_recurrent.py index 1e1f3bd..41aad28 100644 --- a/tensorneat/examples/func_fit/xor_recurrent.py +++ b/tensorneat/examples/func_fit/xor_recurrent.py @@ -2,8 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.func_fit import XOR3d -from utils.activation import ACT_ALL -from utils.aggregation import AGG_ALL +from utils.activation import ACT_ALL, Act if __name__ == '__main__': pipeline = Pipeline( @@ -18,14 +17,21 @@ if __name__ == '__main__': activate_time=5, node_gene=DefaultNodeGene( activation_options=ACT_ALL, - # aggregation_options=AGG_ALL, activation_replace_rate=0.2 ), + output_transform=Act.sigmoid ), pop_size=10000, species_size=10, compatibility_threshold=3.5, + survival_threshold=0.03, ), + mutation=DefaultMutation( + node_add=0.05, + conn_add=0.2, + node_delete=0, + conn_delete=0, + ) ), problem=XOR3d(), generation_limit=10000, diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index e3c73d6..b45424a 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -69,7 +69,7 @@ class Pipeline: pop_transformed ) - fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses) + # fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses) alg_state = self.algorithm.tell(state.alg, fitnesses) @@ -80,9 +80,12 @@ class Pipeline: def auto_run(self, ini_state): state = ini_state + print("start compile") + tic = time.time() compiled_step = jax.jit(self.step).lower(ini_state).compile() - for w in range(self.generation_limit): + print(f"compile finished, cost time: {time.time() - tic:.6f}s", ) + for _ in range(self.generation_limit): self.generation_timestamp = time.time() @@ -92,11 +95,6 @@ class Pipeline: state, fitnesses = compiled_step(state) fitnesses = jax.device_get(fitnesses) - for idx, fitnesses_i in enumerate(fitnesses): - if np.isnan(fitnesses_i): - print("Fitness is nan") - print(previous_pop[0][idx], previous_pop[1][idx]) - assert False self.analysis(state, previous_pop, fitnesses) diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index 89e1b7c..b9f2329 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl_env/rl_jit.py @@ -8,34 +8,10 @@ from .. import BaseProblem class RLEnv(BaseProblem): jitable = True - # TODO: move output transform to algorithm def __init__(self, max_step=1000): super().__init__() self.max_step = max_step - # def evaluate(self, randkey, state, act_func, params): - # rng_reset, rng_episode = jax.random.split(randkey) - # init_obs, init_env_state = self.reset(rng_reset) - - # def cond_func(carry): - # _, _, _, done, _ = carry - # return ~done - - # def body_func(carry): - # obs, env_state, rng, _, tr = carry # total reward - # action = act_func(obs, params) - # next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) - # next_rng, _ = jax.random.split(rng) - # return next_obs, next_env_state, next_rng, done, tr + reward - - # _, _, _, _, total_reward = jax.lax.while_loop( - # cond_func, - # body_func, - # (init_obs, init_env_state, rng_episode, False, 0.0) - # ) - - # return total_reward - def evaluate(self, randkey, state, act_func, params): rng_reset, rng_episode = jax.random.split(randkey) init_obs, init_env_state = self.reset(rng_reset) @@ -58,6 +34,7 @@ class RLEnv(BaseProblem): ) return total_reward + @partial(jax.jit, static_argnums=(0,)) def step(self, randkey, env_state, action): return self.env_step(randkey, env_state, action)