update readme.md for environment configuration

This commit is contained in:
wls2002
2024-05-20 20:45:48 +08:00
parent 0e89ed1d7c
commit d1559317d1
3 changed files with 53 additions and 5 deletions

View File

@@ -23,12 +23,15 @@
TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framewrok.
## Requirements
TensorNEAT requires:
- jax (version >= 0.4.16)
- jaxlib (version >= 0.3.0)
- brax [optional]
- gymnax [optional]
Due to the rapid iteration of JAX versions, configuring the runtime environment for tensorNEAT can be challenging. We recommend the following versions for the relevant libraries:
- jax (0.4.28)
- jaxlib (0.4.28+cuda12.cudnn89)
- brax (0.10.3)
- gymnax (0.0.8)
We provide detailed JAX-related environment references in [recommend_environment](recommend_environment.txt). If you encounter any issues while configuring the environment yourself, you can use this as a reference.
## Example
Simple Example for XOR problem:
```python

View File

@@ -0,0 +1,9 @@
brax==0.10.3
flax==0.8.4
gymnax==0.0.8
jax==0.4.28
jaxlib==0.4.28+cuda12.cudnn89
jaxopt==0.8.3
mujoco==3.1.4
mujoco-mjx==3.1.4
optax==0.2.2

View File

@@ -0,0 +1,36 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=17,
num_outputs=6,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=10,
species_size=10,
),
),
problem=BraxEnv(
env_name='walker2d',
),
generation_limit=10000,
fitness_target=5000
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)