update readme

This commit is contained in:
root
2024-07-12 06:12:35 +08:00
parent fb60d694ba
commit 58c56ab2ab
2 changed files with 76 additions and 108 deletions

README.md

@@ -22,94 +22,66 @@
## Introduction
TensorNEAT is a JAX-based library for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framework.
## Key Features
- JAX-based network for neuroevolution:
  - **Batch inference** across networks with different architectures, GPU-accelerated.
  - Evolve networks with **irregular structures** and **fully customize** their behavior.
  - Visualize the network and represent it in **mathematical formulas**.
- GPU-accelerated NEAT implementation:
  - Run NEAT and HyperNEAT on GPUs.
  - Achieve **500x** speedup compared to CPU-based NEAT libraries.
- Rich in extended content:
  - Compatible with **EvoX** for multi-device and distributed support.
  - Test neuroevolution algorithms on advanced **RL tasks** (Brax, Gymnax).
## Basic API Usage
Start your journey with TensorNEAT in a few simple steps:

1. **Import necessary modules**:
```python
from tensorneat.pipeline import Pipeline
from tensorneat import algorithm, genome, problem, common
```

2. **Configure the NEAT algorithm and define a problem**:
```python
algorithm = algorithm.NEAT(
    pop_size=10000,
    species_size=20,
    survival_threshold=0.01,
    genome=genome.DefaultGenome(
        num_inputs=3,
        num_outputs=1,
        output_transform=common.ACT.sigmoid,
    ),
)
problem = problem.XOR3d()
```

3. **Initialize the pipeline and run**:
```python
pipeline = Pipeline(
    algorithm,
    problem,
    generation_limit=200,
    fitness_target=-1e-6,
    seed=42,
)
state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show results
pipeline.show(state, best)
```

## Installation
Install `tensorneat` from the GitHub source code:
```
pip install git+https://github.com/EMI-Group/tensorneat.git
```
More examples are in `tensorneat/examples`.
## Community & Support


@@ -1,31 +1,27 @@
from tensorneat.pipeline import Pipeline
from tensorneat import algorithm, genome, problem, common

algorithm = algorithm.NEAT(
    pop_size=10000,
    species_size=20,
    survival_threshold=0.01,
    genome=genome.DefaultGenome(
        num_inputs=3,
        num_outputs=1,
        output_transform=common.ACT.sigmoid,
    ),
)
problem = problem.XOR3d()

pipeline = Pipeline(
    algorithm,
    problem,
    generation_limit=200,
    fitness_target=-1e-6,  # float32 precision
    seed=42,
)
# initialize state
state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show results
pipeline.show(state, best)