From d1559317d18b64629fb68a052c31486fcff598df Mon Sep 17 00:00:00 2001 From: wls2002 Date: Mon, 20 May 2024 20:45:48 +0800 Subject: [PATCH] update readme.md for environment configuration --- README.md | 13 ++++++----- recommend_environment.txt | 9 ++++++++ tensorneat/examples/brax/walker.py | 36 ++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 recommend_environment.txt create mode 100644 tensorneat/examples/brax/walker.py diff --git a/README.md b/README.md index 1eac198..54002dd 100644 --- a/README.md +++ b/README.md @@ -23,12 +23,15 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framewrok. ## Requirements -TensorNEAT requires: -- jax (version >= 0.4.16) -- jaxlib (version >= 0.3.0) -- brax [optional] -- gymnax [optional] +Due to the rapid iteration of JAX versions, configuring the runtime environment for tensorNEAT can be challenging. We recommend the following versions for the relevant libraries: + +- jax (0.4.28) +- jaxlib (0.4.28+cuda12.cudnn89) +- brax (0.10.3) +- gymnax (0.0.8) +We provide detailed JAX-related environment references in [recommend_environment](recommend_environment.txt). If you encounter any issues while configuring the environment yourself, you can use this as a reference. + ## Example Simple Example for XOR problem: ```python diff --git a/recommend_environment.txt b/recommend_environment.txt new file mode 100644 index 0000000..449aae4 --- /dev/null +++ b/recommend_environment.txt @@ -0,0 +1,9 @@ +brax==0.10.3 +flax==0.8.4 +gymnax==0.0.8 +jax==0.4.28 +jaxlib==0.4.28+cuda12.cudnn89 +jaxopt==0.8.3 +mujoco==3.1.4 +mujoco-mjx==3.1.4 +optax==0.2.2 \ No newline at end of file diff --git a/tensorneat/examples/brax/walker.py b/tensorneat/examples/brax/walker.py new file mode 100644 index 0000000..4cdd645 --- /dev/null +++ b/tensorneat/examples/brax/walker.py @@ -0,0 +1,36 @@ +from pipeline import Pipeline +from algorithm.neat import * + +from problem.rl_env import BraxEnv +from utils import Act + +if __name__ == '__main__': + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=17, + num_outputs=6, + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene( + activation_options=(Act.tanh,), + activation_default=Act.tanh, + ) + ), + pop_size=10, + species_size=10, + ), + ), + problem=BraxEnv( + env_name='walker2d', + ), + generation_limit=10000, + fitness_target=5000 + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file