diff --git a/examples/brax/halfcheetah.py b/examples/brax/halfcheetah.py index 7886e31..2827324 100644 --- a/examples/brax/halfcheetah.py +++ b/examples/brax/halfcheetah.py @@ -20,16 +20,16 @@ if __name__ == "__main__": survival_threshold=0.1, compatibility_threshold=1.0, genome=DefaultGenome( - max_nodes=100, + max_nodes=50, max_conns=200, num_inputs=17, num_outputs=6, init_hidden_layers=(), node_gene=BiasNode( - activation_options=ACT.tanh, + activation_options=ACT.scaled_tanh, aggregation_options=AGG.sum, ), - output_transform=ACT.standard_tanh, + output_transform=ACT.tanh, ), ), problem=BraxEnv( diff --git a/examples/brax/walker2d.py b/examples/brax/walker2d.py index 48e1ab7..17ad745 100644 --- a/examples/brax/walker2d.py +++ b/examples/brax/walker2d.py @@ -20,7 +20,7 @@ if __name__ == "__main__": survival_threshold=0.1, compatibility_threshold=1.0, genome=DefaultGenome( - max_nodes=100, + max_nodes=50, max_conns=200, num_inputs=17, num_outputs=6, @@ -29,7 +29,7 @@ if __name__ == "__main__": activation_options=ACT.tanh, aggregation_options=AGG.sum, ), - output_transform=ACT.standard_tanh, + output_transform=ACT.tanh, ), ), problem=BraxEnv( diff --git a/examples/func_fit/custom_func_fit.py b/examples/func_fit/custom_func_fit.py index 1088c86..9bbea1d 100644 --- a/examples/func_fit/custom_func_fit.py +++ b/examples/func_fit/custom_func_fit.py @@ -6,7 +6,7 @@ from tensorneat.genome import DefaultGenome, DefaultNode, DefaultMutation, BiasN from tensorneat.problem.func_fit import CustomFuncFit from tensorneat.common import ACT, AGG - +# define a custom function fit problem def pagie_polynomial(inputs): x, y = inputs res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4)) @@ -14,9 +14,12 @@ def pagie_polynomial(inputs): # important! returns an array, NOT a scalar return jnp.array([res]) +# define custom activate function and register it +def square(x): + return x ** 2 +ACT.add_func("square", square) if __name__ == "__main__": - custom_problem = CustomFuncFit( func=pagie_polynomial, low_bounds=[-1, -1], diff --git a/examples/func_fit/xor.py b/examples/func_fit/xor.py index c575985..141f163 100644 --- a/examples/func_fit/xor.py +++ b/examples/func_fit/xor.py @@ -14,7 +14,7 @@ if __name__ == "__main__": num_inputs=3, num_outputs=1, init_hidden_layers=(), - output_transform=ACT.standard_sigmoid, + output_transform=ACT.sigmoid, ), ), problem=XOR3d(), diff --git a/examples/func_fit/xor_hyperneat.py b/examples/func_fit/xor_hyperneat.py index bcdd816..ca716b8 100644 --- a/examples/func_fit/xor_hyperneat.py +++ b/examples/func_fit/xor_hyperneat.py @@ -22,12 +22,12 @@ if __name__ == "__main__": num_inputs=4, # size of query coors num_outputs=1, init_hidden_layers=(), - output_transform=ACT.standard_tanh, + output_transform=ACT.tanh, ), ), activation=ACT.tanh, activate_time=10, - output_transform=ACT.standard_sigmoid, + output_transform=ACT.sigmoid, ), problem=XOR3d(), generation_limit=300, diff --git a/examples/func_fit/xor_recurrent.py b/examples/func_fit/xor_recurrent.py index d878e9e..62dcfc0 100644 --- a/examples/func_fit/xor_recurrent.py +++ b/examples/func_fit/xor_recurrent.py @@ -14,7 +14,7 @@ if __name__ == "__main__": num_inputs=3, num_outputs=1, init_hidden_layers=(), - output_transform=ACT.standard_sigmoid, + output_transform=ACT.sigmoid, activate_time=10, ), ), diff --git a/examples/gymnax/cartpole_hyperneat.py b/examples/gymnax/cartpole_hyperneat.py index e47fa3e..55e8e77 100644 --- a/examples/gymnax/cartpole_hyperneat.py +++ b/examples/gymnax/cartpole_hyperneat.py @@ -27,7 +27,7 @@ if __name__ == "__main__": num_inputs=4, # size of query coors num_outputs=1, init_hidden_layers=(), - output_transform=ACT.standard_tanh, + output_transform=ACT.tanh, ), ), activation=ACT.tanh, diff --git a/examples/gymnax/mountain_car_continuous.py b/examples/gymnax/mountain_car_continuous.py index 919ccff..5401566 100644 --- a/examples/gymnax/mountain_car_continuous.py +++ b/examples/gymnax/mountain_car_continuous.py @@ -24,7 +24,7 @@ if __name__ == "__main__": activation_options=ACT.tanh, aggregation_options=AGG.sum, ), - output_transform=ACT.standard_tanh, + output_transform=ACT.tanh, ), ), problem=GymNaxEnv( diff --git a/examples/interpret_visualize/genome_sympy.py b/examples/interpret_visualize/genome_sympy.py index 73d0c7f..30c8da3 100644 --- a/examples/interpret_visualize/genome_sympy.py +++ b/examples/interpret_visualize/genome_sympy.py @@ -1,16 +1,16 @@ import jax, jax.numpy as jnp -from algorithm.neat import * -from algorithm.neat.genome.dense import DenseInitialize -from utils.graph import topological_sort_python +from tensorneat.genome import DefaultGenome from tensorneat.common import * +from tensorneat.common.functions import SympySigmoid if __name__ == "__main__": - genome = DenseInitialize( + genome = DefaultGenome( num_inputs=3, num_outputs=1, max_nodes=50, max_conns=500, + output_transform=ACT.sigmoid, ) state = genome.setup() @@ -22,7 +22,7 @@ if __name__ == "__main__": input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx() - res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid) + res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999*x, sympy_output_transform=SympySigmoid) (symbols, args_symbols, input_symbols, @@ -35,3 +35,11 @@ if __name__ == "__main__": inputs = jnp.zeros(3) print(forward_func(inputs)) + + print(genome.forward(state, genome.transform(state, nodes, conns), inputs)) + + print(AGG.sympy_module("jax")) + print(AGG.sympy_module("numpy")) + + print(ACT.sympy_module("jax")) + print(ACT.sympy_module("numpy")) \ No newline at end of file diff --git a/examples/with_evox/example.py b/examples/with_evox/walker2d_evox.py similarity index 52% rename from examples/with_evox/example.py rename to examples/with_evox/walker2d_evox.py index 4d4f5e4..263ba1c 100644 --- a/examples/with_evox/example.py +++ b/examples/with_evox/walker2d_evox.py @@ -1,29 +1,29 @@ import jax import jax.numpy as jnp -from evox import workflows, algorithms, problems +from evox import workflows, problems -from tensorneat.examples.with_evox.evox_algorithm_adaptor import EvoXAlgorithmAdaptor -from tensorneat.examples.with_evox.tensorneat_monitor import TensorNEATMonitor +from tensorneat.common.evox_adaptors import EvoXAlgorithmAdaptor, TensorNEATMonitor from tensorneat.algorithm import NEAT -from tensorneat.algorithm.neat import DefaultSpecies, DefaultGenome, DefaultNodeGene -from tensorneat.common import ACT +from tensorneat.genome import DefaultGenome, BiasNode +from tensorneat.common import ACT, AGG neat_algorithm = NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=17, - num_outputs=6, - max_nodes=200, - max_conns=500, - node_gene=DefaultNodeGene( - activation_options=(ACT.standard_tanh,), - activation_default=ACT.standard_tanh, - ), - output_transform=ACT.tanh, + pop_size=1000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=1.0, + genome=DefaultGenome( + max_nodes=50, + max_conns=200, + num_inputs=17, + num_outputs=6, + init_hidden_layers=(), + node_gene=BiasNode( + activation_options=ACT.tanh, + aggregation_options=AGG.sum, ), - pop_size=10000, - species_size=10, + output_transform=ACT.tanh, ), ) evox_algorithm = EvoXAlgorithmAdaptor(neat_algorithm) @@ -37,12 +37,13 @@ problem = problems.neuroevolution.Brax( policy=evox_algorithm.forward, max_episode_length=1000, num_episodes=1, - backend="mjx" ) + def nan2inf(x): return jnp.where(jnp.isnan(x), -jnp.inf, x) + # create a workflow workflow = workflows.StdWorkflow( algorithm=evox_algorithm, @@ -55,11 +56,11 @@ workflow = workflows.StdWorkflow( # init the workflow state = workflow.init(workflow_key) -# state = workflow.enable_multi_devices(state) -# run the workflow for 100 steps -import time +# enable multi devices +state = workflow.enable_multi_devices(state) + +# run the workflow for 100 steps for i in range(100): - tic = time.time() train_info, state = workflow.step(state) - monitor.show() \ No newline at end of file + monitor.show() diff --git a/pyproject.toml b/pyproject.toml index 21fa98c..2a69b75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,9 @@ dependencies = [ "flax >= 0.8.4", "mujoco >= 3.1.4", "mujoco-mjx >= 3.1.4", + "networkx >= 3.3", + "matplotlib >= 3.9.0", + "sympy >= 1.12.1", ] [project.urls] diff --git a/src/tensorneat.egg-info/PKG-INFO b/src/tensorneat.egg-info/PKG-INFO deleted file mode 100644 index 01b41bf..0000000 --- a/src/tensorneat.egg-info/PKG-INFO +++ /dev/null @@ -1,180 +0,0 @@ -Metadata-Version: 2.1 -Name: tensorneat -Version: 0.1.0 -Summary: tensorneat -Author-email: Lishuang Wang -License: BSD 3-Clause License - - Copyright (c) 2024, EMI-Group - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - 3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -Project-URL: Homepage, https://github.com/EMI-Group/tensorneat -Project-URL: Bug Tracker, https://github.com/EMI-Group/tensorneat/issues -Classifier: Programming Language :: Python :: 3 -Classifier: License :: OSI Approved :: BSD License -Classifier: Intended Audience :: Science/Research -Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence -Requires-Python: >=3.9 -Description-Content-Type: text/markdown -License-File: LICENSE -Requires-Dist: brax>=0.10.3 -Requires-Dist: jax>=0.4.28 -Requires-Dist: gymnax>=0.0.8 -Requires-Dist: jaxopt>=0.8.3 -Requires-Dist: optax>=0.2.2 -Requires-Dist: flax>=0.8.4 -Requires-Dist: mujoco>=3.1.4 -Requires-Dist: mujoco-mjx>=3.1.4 - -

- - - - - EvoX Logo - - -
-

- -

-🌟 TensorNEAT: Tensorized NEAT Implementation in JAX 🌟 -

- -

- - TensorNEAT Paper on arXiv - -

- -## Introduction -TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framewrok. - -## Requirements -Due to the rapid iteration of JAX versions, configuring the runtime environment for TensorNEAT can be challenging. We recommend the following versions for the relevant libraries: - -- jax (0.4.28) -- jaxlib (0.4.28+cuda12.cudnn89) -- brax (0.10.3) -- gymnax (0.0.8) - -We provide detailed JAX-related environment references in [recommend_environment](recommend_environment.txt). If you encounter any issues while configuring the environment yourself, you can use this as a reference. - -## Example -Simple Example for XOR problem: -```python -from pipeline import Pipeline -from algorithm.neat import * - -from problem.func_fit import XOR3d - -if __name__ == '__main__': - pipeline = Pipeline( - algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=3, - num_outputs=1, - max_nodes=50, - max_conns=100, - ), - pop_size=10000, - species_size=10, - compatibility_threshold=3.5, - ), - ), - problem=XOR3d(), - generation_limit=10000, - fitness_target=-1e-8 - ) - - # initialize state - state = pipeline.setup() - # print(state) - # run until terminate - state, best = pipeline.auto_run(state) - # show result - pipeline.show(state, best) -``` - -Simple Example for RL envs in Brax (Ant): -```python -from pipeline import Pipeline -from algorithm.neat import * - -from problem.rl_env import BraxEnv -from tensorneat.utils import ACT - -if __name__ == '__main__': - pipeline = Pipeline( - algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=27, - num_outputs=8, - max_nodes=50, - max_conns=100, - node_gene=DefaultNodeGene( - activation_options=(ACT.tanh,), - activation_default=ACT.tanh, - ) - ), - pop_size=1000, - species_size=10, - ), - ), - problem=BraxEnv( - env_name='ant', - ), - generation_limit=10000, - fitness_target=5000 - ) - - # initialize state - state = pipeline.setup() - # print(state) - # run until terminate - state, best = pipeline.auto_run(state) -``` - -more examples are in `tensorneat/examples`. - -## Community & Support - -- Engage in discussions and share your experiences on [GitHub Discussion Board](https://github.com/EMI-Group/evox/discussions). -- Join our QQ group (ID: 297969717). - -## Citing TensorNEAT - -If you use TensorNEAT in your research and want to cite it in your work, please use: -``` -@article{tensorneat, - title = {{Tensorized} {NeuroEvolution} of {Augmenting} {Topologies} for {GPU} {Acceleration}}, - author = {Wang, Lishuang and Zhao, Mengfei and Liu, Enyu and Sun, Kebin and Cheng, Ran}, - booktitle = {Proceedings of the Genetic and Evolutionary Computation Conference (GECCO)}, - year = {2024} -} diff --git a/src/tensorneat.egg-info/SOURCES.txt b/src/tensorneat.egg-info/SOURCES.txt deleted file mode 100644 index ab4684f..0000000 --- a/src/tensorneat.egg-info/SOURCES.txt +++ /dev/null @@ -1,69 +0,0 @@ -LICENSE -README.md -pyproject.toml -src/tensorneat/pipeline.py -src/tensorneat.egg-info/PKG-INFO -src/tensorneat.egg-info/SOURCES.txt -src/tensorneat.egg-info/dependency_links.txt -src/tensorneat.egg-info/requires.txt -src/tensorneat.egg-info/top_level.txt -src/tensorneat/algorithm/__init__.py -src/tensorneat/algorithm/base.py -src/tensorneat/algorithm/hyperneat/__init__.py -src/tensorneat/algorithm/hyperneat/hyperneat.py -src/tensorneat/algorithm/hyperneat/substrate/__init__.py -src/tensorneat/algorithm/hyperneat/substrate/base.py -src/tensorneat/algorithm/hyperneat/substrate/default.py -src/tensorneat/algorithm/hyperneat/substrate/full.py -src/tensorneat/algorithm/neat/__init__.py -src/tensorneat/algorithm/neat/neat.py -src/tensorneat/algorithm/neat/species.py -src/tensorneat/common/__init__.py -src/tensorneat/common/graph.py -src/tensorneat/common/state.py -src/tensorneat/common/stateful_class.py -src/tensorneat/common/tools.py -src/tensorneat/common/functions/__init__.py -src/tensorneat/common/functions/act_jnp.py -src/tensorneat/common/functions/act_sympy.py -src/tensorneat/common/functions/agg_jnp.py -src/tensorneat/common/functions/agg_sympy.py -src/tensorneat/common/functions/manager.py -src/tensorneat/genome/__init__.py -src/tensorneat/genome/base.py -src/tensorneat/genome/default.py -src/tensorneat/genome/recurrent.py -src/tensorneat/genome/utils.py -src/tensorneat/genome/gene/__init__.py -src/tensorneat/genome/gene/base.py -src/tensorneat/genome/gene/conn/__init__.py -src/tensorneat/genome/gene/conn/base.py -src/tensorneat/genome/gene/conn/default.py -src/tensorneat/genome/gene/node/__init__.py -src/tensorneat/genome/gene/node/base.py -src/tensorneat/genome/gene/node/bias.py -src/tensorneat/genome/gene/node/default.py -src/tensorneat/genome/operations/__init__.py -src/tensorneat/genome/operations/crossover/__init__.py -src/tensorneat/genome/operations/crossover/base.py -src/tensorneat/genome/operations/crossover/default.py -src/tensorneat/genome/operations/distance/__init__.py -src/tensorneat/genome/operations/distance/base.py -src/tensorneat/genome/operations/distance/default.py -src/tensorneat/genome/operations/mutation/__init__.py -src/tensorneat/genome/operations/mutation/base.py -src/tensorneat/genome/operations/mutation/default.py -src/tensorneat/problem/__init__.py -src/tensorneat/problem/base.py -src/tensorneat/problem/func_fit/__init__.py -src/tensorneat/problem/func_fit/custom.py -src/tensorneat/problem/func_fit/func_fit.py -src/tensorneat/problem/func_fit/xor.py -src/tensorneat/problem/func_fit/xor3d.py -src/tensorneat/problem/rl/__init__.py -src/tensorneat/problem/rl/brax.py -src/tensorneat/problem/rl/gymnax.py -src/tensorneat/problem/rl/rl_jit.py -test/test_genome.py -test/test_nan_fitness.py -test/test_record_episode.py \ No newline at end of file diff --git a/src/tensorneat.egg-info/dependency_links.txt b/src/tensorneat.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/src/tensorneat.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/tensorneat.egg-info/requires.txt b/src/tensorneat.egg-info/requires.txt deleted file mode 100644 index e5517ba..0000000 --- a/src/tensorneat.egg-info/requires.txt +++ /dev/null @@ -1,8 +0,0 @@ -brax>=0.10.3 -jax>=0.4.28 -gymnax>=0.0.8 -jaxopt>=0.8.3 -optax>=0.2.2 -flax>=0.8.4 -mujoco>=3.1.4 -mujoco-mjx>=3.1.4 diff --git a/src/tensorneat.egg-info/top_level.txt b/src/tensorneat.egg-info/top_level.txt deleted file mode 100644 index afde8a9..0000000 --- a/src/tensorneat.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -tensorneat diff --git a/src/tensorneat/algorithm/hyperneat/hyperneat.py b/src/tensorneat/algorithm/hyperneat/hyperneat.py index b0744fb..137ca05 100644 --- a/src/tensorneat/algorithm/hyperneat/hyperneat.py +++ b/src/tensorneat/algorithm/hyperneat/hyperneat.py @@ -19,7 +19,7 @@ class HyperNEAT(BaseAlgorithm): aggregation: Callable = AGG.sum, activation: Callable = ACT.sigmoid, activate_time: int = 10, - output_transform: Callable = ACT.standard_sigmoid, + output_transform: Callable = ACT.sigmoid, ): assert ( substrate.query_coors.shape[1] == neat.num_inputs diff --git a/src/tensorneat/common/__init__.py b/src/tensorneat/common/__init__.py index 5d75c2a..4b29928 100644 --- a/src/tensorneat/common/__init__.py +++ b/src/tensorneat/common/__init__.py @@ -3,4 +3,4 @@ from .graph import * from .state import State from .stateful_class import StatefulBaseClass -from .functions import ACT, AGG, apply_activation, apply_aggregation +from .functions import ACT, AGG, apply_activation, apply_aggregation, get_func_name diff --git a/src/tensorneat/common/evox_adaptors/__init__.py b/src/tensorneat/common/evox_adaptors/__init__.py new file mode 100644 index 0000000..4f8c31c --- /dev/null +++ b/src/tensorneat/common/evox_adaptors/__init__.py @@ -0,0 +1,2 @@ +from .algorithm_adaptor import EvoXAlgorithmAdaptor +from .tensorneat_monitor import TensorNEATMonitor diff --git a/examples/with_evox/evox_algorithm_adaptor.py b/src/tensorneat/common/evox_adaptors/algorithm_adaptor.py similarity index 100% rename from examples/with_evox/evox_algorithm_adaptor.py rename to src/tensorneat/common/evox_adaptors/algorithm_adaptor.py diff --git a/examples/with_evox/tensorneat_monitor.py b/src/tensorneat/common/evox_adaptors/tensorneat_monitor.py similarity index 62% rename from examples/with_evox/tensorneat_monitor.py rename to src/tensorneat/common/evox_adaptors/tensorneat_monitor.py index 05261fe..c9bc9b3 100644 --- a/examples/with_evox/tensorneat_monitor.py +++ b/src/tensorneat/common/evox_adaptors/tensorneat_monitor.py @@ -16,12 +16,12 @@ class TensorNEATMonitor(Monitor): def __init__( self, - neat_algorithm: TensorNEATAlgorithm, + tensorneat_algorithm: TensorNEATAlgorithm, save_dir: str = None, is_save: bool = False, ): super().__init__() - self.neat_algorithm = neat_algorithm + self.tensorneat_algorithm = tensorneat_algorithm self.generation_timestamp = time.time() self.alg_state: TensorNEATState = None @@ -60,7 +60,9 @@ class TensorNEATMonitor(Monitor): self.fitness = jax.device_get(fitness) def show(self): - pop = self.neat_algorithm.ask(self.alg_state) + pop = self.tensorneat_algorithm.ask(self.alg_state) + generation = int(self.alg_state.generation) + valid_fitnesses = self.fitness[~np.isinf(self.fitness)] max_f, min_f, mean_f, std_f = ( @@ -73,22 +75,20 @@ class TensorNEATMonitor(Monitor): new_timestamp = time.time() cost_time = new_timestamp - self.generation_timestamp - self.generation_timestamp = new_timestamp - + self.generation_timestamp = time.time() + max_idx = np.argmax(self.fitness) if self.fitness[max_idx] > self.best_fitness: self.best_fitness = self.fitness[max_idx] self.best_genome = pop[0][max_idx], pop[1][max_idx] if self.is_save: + # save best best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx])) - with open( - os.path.join( - self.genome_dir, - f"{int(self.neat_algorithm.generation(self.alg_state))}.npz", - ), - "wb", - ) as f: + file_name = os.path.join( + self.genome_dir, f"{generation}.npz" + ) + with open(file_name, "wb") as f: np.savez( f, nodes=best_genome[0], @@ -96,38 +96,15 @@ class TensorNEATMonitor(Monitor): fitness=self.best_fitness, ) - # save best if save path is not None - member_count = jax.device_get(self.neat_algorithm.member_count(self.alg_state)) - species_sizes = [int(i) for i in member_count if i > 0] - - pop = jax.device_get(pop) - pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL) - nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,) - conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,) - - max_node_cnt, min_node_cnt, mean_node_cnt = ( - max(nodes_cnt), - min(nodes_cnt), - np.mean(nodes_cnt), - ) - - max_conn_cnt, min_conn_cnt, mean_conn_cnt = ( - max(conns_cnt), - min(conns_cnt), - np.mean(conns_cnt), - ) + # append log + with open(os.path.join(self.save_dir, "log.txt"), "a") as f: + f.write( + f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n" + ) print( - f"Generation: {self.neat_algorithm.generation(self.alg_state)}, Cost time: {cost_time * 1000:.2f}ms\n", - f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n", - f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n", - f"\tspecies: {len(species_sizes)}, {species_sizes}\n", + f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n", f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n", ) - # append log - if self.is_save: - with open(os.path.join(self.save_dir, "log.txt"), "a") as f: - f.write( - f"{self.neat_algorithm.generation(self.alg_state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n" - ) + self.tensorneat_algorithm.show_details(self.alg_state, self.fitness) \ No newline at end of file diff --git a/src/tensorneat/common/functions/__init__.py b/src/tensorneat/common/functions/__init__.py index 6e23462..7c51ea0 100644 --- a/src/tensorneat/common/functions/__init__.py +++ b/src/tensorneat/common/functions/__init__.py @@ -1,3 +1,5 @@ +import jax, jax.numpy as jnp + from .act_jnp import * from .act_sympy import * from .agg_jnp import * @@ -32,6 +34,7 @@ act_name2sympy = { "log": SympyLog, "exp": SympyExp, "abs": SympyAbs, + "clip": SympyClip, } agg_name2jnp = { @@ -40,7 +43,6 @@ agg_name2jnp = { "max": max_, "min": min_, "maxabs": maxabs_, - "median": median_, "mean": mean_, } @@ -50,9 +52,42 @@ agg_name2sympy = { "max": SympyMax, "min": SympyMin, "maxabs": SympyMaxabs, - "median": SympyMedian, "mean": SympyMean, } ACT = FunctionManager(act_name2jnp, act_name2sympy) AGG = FunctionManager(agg_name2jnp, agg_name2sympy) + +def apply_activation(idx, z, act_funcs): + """ + calculate activation function for each node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + # change idx from float to int + + # -1 means identity activation + res = jax.lax.cond( + idx == -1, + lambda: z, + lambda: jax.lax.switch(idx, act_funcs, z), + ) + + return res + +def apply_aggregation(idx, z, agg_funcs): + """ + calculate activation function for inputs of node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + + return jax.lax.cond( + jnp.all(jnp.isnan(z)), + lambda: jnp.nan, # all inputs are nan + lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise + ) + +def get_func_name(func): + name = func.__name__ + if name.endswith("_"): + name = name[:-1] + return name \ No newline at end of file diff --git a/src/tensorneat/common/functions/act_jnp.py b/src/tensorneat/common/functions/act_jnp.py index 75b0e86..e13f440 100644 --- a/src/tensorneat/common/functions/act_jnp.py +++ b/src/tensorneat/common/functions/act_jnp.py @@ -1,6 +1,6 @@ import jax.numpy as jnp -SCALE = 5 +SCALE = 3 def scaled_sigmoid_(z): diff --git a/src/tensorneat/common/functions/act_sympy.py b/src/tensorneat/common/functions/act_sympy.py index 1c80594..317d50a 100644 --- a/src/tensorneat/common/functions/act_sympy.py +++ b/src/tensorneat/common/functions/act_sympy.py @@ -1,7 +1,7 @@ import sympy as sp import numpy as np -SCALE = 5 +SCALE = 3 class SympySigmoid(sp.Function): @classmethod diff --git a/src/tensorneat/common/functions/agg_jnp.py b/src/tensorneat/common/functions/agg_jnp.py index 53ca931..38fe8ab 100644 --- a/src/tensorneat/common/functions/agg_jnp.py +++ b/src/tensorneat/common/functions/agg_jnp.py @@ -23,18 +23,6 @@ def maxabs_(z): max_abs_index = jnp.argmax(abs_z) return z[max_abs_index] - -def median_(z): - n = jnp.sum(~jnp.isnan(z), axis=0) - - z = jnp.sort(z) # sort - - idx1, idx2 = (n - 1) // 2, n // 2 - median = (z[idx1] + z[idx2]) / 2 - - return median - - def mean_(z): sumation = sum_(z) valid_count = jnp.sum(~jnp.isnan(z), axis=0) diff --git a/src/tensorneat/common/functions/agg_sympy.py b/src/tensorneat/common/functions/agg_sympy.py index 85df179..35389d8 100644 --- a/src/tensorneat/common/functions/agg_sympy.py +++ b/src/tensorneat/common/functions/agg_sympy.py @@ -7,30 +7,18 @@ class SympySum(sp.Function): def eval(cls, z): return sp.Add(*z) - @classmethod - def numerical_eval(cls, z, backend=np): - return backend.sum(z) - class SympyProduct(sp.Function): @classmethod def eval(cls, z): return sp.Mul(*z) - @classmethod - def numerical_eval(cls, z, backend=np): - return backend.product(z) - class SympyMax(sp.Function): @classmethod def eval(cls, z): return sp.Max(*z) - @classmethod - def numerical_eval(cls, z, backend=np): - return backend.max(z) - class SympyMin(sp.Function): @classmethod @@ -48,26 +36,3 @@ class SympyMean(sp.Function): @classmethod def eval(cls, z): return sp.Add(*z) / len(z) - - -class SympyMedian(sp.Function): - @classmethod - def eval(cls, args): - - if all(arg.is_number for arg in args): - sorted_args = sorted(args) - n = len(sorted_args) - if n % 2 == 1: - return sorted_args[n // 2] - else: - return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2 - - return None - - def _sympystr(self, printer): - return f"median({', '.join(map(str, self.args))})" - - def _latex(self, printer): - return ( - r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)" - ) diff --git a/src/tensorneat/common/functions/manager.py b/src/tensorneat/common/functions/manager.py index cc8b2cc..33a69f6 100644 --- a/src/tensorneat/common/functions/manager.py +++ b/src/tensorneat/common/functions/manager.py @@ -1,28 +1,32 @@ +from functools import partial +import numpy as np +import jax.numpy as jnp from typing import Union, Callable import sympy as sp + class FunctionManager: def __init__(self, name2jnp, name2sympy): self.name2jnp = name2jnp self.name2sympy = name2sympy + for name, func in name2jnp.items(): + setattr(self, name, func) def get_all_funcs(self): all_funcs = [] - for name in self.names: + for name in self.name2jnp: all_funcs.append(getattr(self, name)) return all_funcs - def __getattribute__(self, name: str): - return self.name2jnp[name] - def add_func(self, name, func): if not callable(func): raise ValueError("The provided function is not callable") - if name in self.names: + if name in self.name2jnp: raise ValueError(f"The provided name={name} is already in use") self.name2jnp[name] = func + setattr(self, name, func) def update_sympy(self, name, sympy_cls: sp.Function): self.name2sympy[name] = sympy_cls @@ -47,3 +51,16 @@ class FunctionManager: if name not in self.name2sympy: raise ValueError(f"Func {name} doesn't have a sympy representation.") return self.name2sympy[name] + + def sympy_module(self, backend: str): + assert backend in ["jax", "numpy"] + if backend == "jax": + backend = jnp + elif backend == "numpy": + backend = np + module = {} + for sympy_cls in self.name2sympy.values(): + if hasattr(sympy_cls, "numerical_eval"): + module[sympy_cls.__name__] = partial(sympy_cls.numerical_eval, backend) + + return module diff --git a/src/tensorneat/genome/default.py b/src/tensorneat/genome/default.py index e8a218d..2464fc9 100644 --- a/src/tensorneat/genome/default.py +++ b/src/tensorneat/genome/default.py @@ -15,8 +15,8 @@ from tensorneat.common import ( topological_sort_python, I_INF, attach_with_inf, - SYMPY_FUNCS_MODULE_NP, - SYMPY_FUNCS_MODULE_JNP, + ACT, + AGG ) @@ -92,7 +92,9 @@ class DefaultGenome(BaseGenome): def otherwise(): # calculate connections conn_indices = u_conns[:, i] - hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs + hit_attrs = attach_with_inf( + conns_attrs, conn_indices + ) # fetch conn attrs ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))( state, hit_attrs, values ) @@ -102,7 +104,9 @@ class DefaultGenome(BaseGenome): state, nodes_attrs[i], ins, - is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes + is_output_node=jnp.isin( + nodes[i, 0], self.output_idx + ), # nodes[0] -> the key of nodes ) # set new value @@ -139,7 +143,6 @@ class DefaultGenome(BaseGenome): ): assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'" - module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP if sympy_input_transform is None and self.input_transform is not None: warnings.warn( @@ -224,7 +227,7 @@ class DefaultGenome(BaseGenome): sp.lambdify( input_symbols + list(args_symbols.keys()), exprs, - modules=[backend, module], + modules=[backend, AGG.sympy_module(backend), ACT.sympy_module(backend)], ) for exprs in output_exprs ] @@ -256,7 +259,12 @@ class DefaultGenome(BaseGenome): rotate=0, reverse_node_order=False, size=(300, 300, 300), - color=("blue", "blue", "blue"), + color=("yellow", "white", "blue"), + with_labels=False, + edgecolors="k", + arrowstyle="->", + arrowsize=3, + edge_color=(0.3, 0.3, 0.3), save_path="network.svg", save_dpi=800, **kwargs, @@ -264,7 +272,6 @@ class DefaultGenome(BaseGenome): import networkx as nx from matplotlib import pyplot as plt - nodes_list = list(network["nodes"]) conns_list = list(network["conns"]) input_idx = self.get_input_idx() output_idx = self.get_output_idx() @@ -316,6 +323,11 @@ class DefaultGenome(BaseGenome): pos=rotated_pos, node_size=node_sizes, node_color=node_colors, + with_labels=with_labels, + edgecolors=edgecolors, + arrowstyle=arrowstyle, + arrowsize=arrowsize, + edge_color=edge_color, **kwargs, ) plt.savefig(save_path, dpi=save_dpi) diff --git a/src/tensorneat/genome/gene/node/bias.py b/src/tensorneat/genome/gene/node/bias.py index b28eada..37f5ae5 100644 --- a/src/tensorneat/genome/gene/node/bias.py +++ b/src/tensorneat/genome/gene/node/bias.py @@ -10,7 +10,7 @@ from tensorneat.common import ( apply_aggregation, mutate_int, mutate_float, - convert_to_sympy, + get_func_name ) from . import BaseNode @@ -141,8 +141,8 @@ class BiasNode(BaseNode): self.__class__.__name__, idx, bias, - self.aggregation_options[agg].__name__, - act_func.__name__, + get_func_name(self.aggregation_options[agg]), + get_func_name(act_func), idx_width=idx_width, float_width=precision + 3, func_width=func_width, @@ -165,21 +165,19 @@ class BiasNode(BaseNode): return { "idx": idx, "bias": bias, - "agg": self.aggregation_options[int(agg)].__name__, - "act": act_func.__name__, + "agg": get_func_name(self.aggregation_options[agg]), + "act": get_func_name(act_func), } def sympy_func(self, state, node_dict, inputs, is_output_node=False): - nd = node_dict + bias = sp.symbols(f"n_{node_dict['idx']}_b") - bias = sp.symbols(f"n_{nd['idx']}_b") - - z = convert_to_sympy(nd["agg"])(inputs) + z = AGG.obtain_sympy(node_dict["agg"])(inputs) z = bias + z if is_output_node: pass else: - z = convert_to_sympy(nd["act"])(z) + z = ACT.obtain_sympy(node_dict["act"])(z) - return z, {bias: nd["bias"]} + return z, {bias: node_dict["bias"]} diff --git a/src/tensorneat/genome/gene/node/default.py b/src/tensorneat/genome/gene/node/default.py index dc1219a..855b071 100644 --- a/src/tensorneat/genome/gene/node/default.py +++ b/src/tensorneat/genome/gene/node/default.py @@ -11,7 +11,7 @@ from tensorneat.common import ( apply_aggregation, mutate_int, mutate_float, - convert_to_sympy, + get_func_name ) from .base import BaseNode @@ -176,8 +176,8 @@ class DefaultNode(BaseNode): idx, bias, res, - self.aggregation_options[agg].__name__, - act_func.__name__, + get_func_name(self.aggregation_options[agg]), + get_func_name(act_func), idx_width=idx_width, float_width=precision + 3, func_width=func_width, @@ -200,8 +200,8 @@ class DefaultNode(BaseNode): "idx": idx, "bias": bias, "res": res, - "agg": self.aggregation_options[int(agg)].__name__, - "act": act_func.__name__, + "agg": get_func_name(self.aggregation_options[agg]), + "act": get_func_name(act_func), } def sympy_func(self, state, node_dict, inputs, is_output_node=False): @@ -209,12 +209,13 @@ class DefaultNode(BaseNode): bias = sp.symbols(f"n_{nd['idx']}_b") res = sp.symbols(f"n_{nd['idx']}_r") - z = convert_to_sympy(nd["agg"])(inputs) + print(nd["agg"]) + z = AGG.obtain_sympy(nd["agg"])(inputs) z = bias + res * z if is_output_node: pass else: - z = convert_to_sympy(nd["act"])(z) + z = ACT.obtain_sympy(nd["act"])(z) return z, {bias: nd["bias"], res: nd["res"]} diff --git a/tutorials/.ipynb_checkpoints/tutorial-01-genome-checkpoint.ipynb b/tutorials/.ipynb_checkpoints/tutorial-01-genome-checkpoint.ipynb new file mode 100644 index 0000000..363fcab --- /dev/null +++ b/tutorials/.ipynb_checkpoints/tutorial-01-genome-checkpoint.ipynb @@ -0,0 +1,6 @@ +{ + "cells": [], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/tutorial-01-genome.ipynb b/tutorials/tutorial-01-genome.ipynb new file mode 100644 index 0000000..d0dbbca --- /dev/null +++ b/tutorials/tutorial-01-genome.ipynb @@ -0,0 +1,23 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "286b5fc2-e629-4461-ad95-b69d3d606a78", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "", + "name": "" + }, + "language_info": { + "name": "" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}