update README.md

This commit is contained in:
root
2024-07-15 21:02:58 +08:00
parent 6edf083d4f
commit 0760882aae
9 changed files with 2082 additions and 14 deletions

182
README.md
View File

@@ -36,27 +36,167 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (N
- Compatible with **EvoX** for multi-device and distributed support.
- Test neuroevolution algorithms on advanced **RL tasks** (Brax, Gymnax).
## Installation
## Examples
### Solving RL Tasks
Using the NEAT algorithm to solve RL tasks. Here are some results:
1. Install the correct version of [JAX](https://github.com/google/jax). We recommend `jax >= 0.4.28`.
The following animations show the behaviors in Brax environments:
<div style="display: flex; justify-content: space-around; align-items: center;">
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/halfcheetah_animation_200.gif" alt="halfcheetah_animation" width="200">
<figcaption>halfcheetah</figcaption>
</figure>
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/hopper_animation_200.gif" alt="hopper" width="200">
<figcaption>hopper</figcaption>
</figure>
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/walker2d_animation_200.gif" alt="walker2d" width="200">
<figcaption>walker2d</figcaption>
</figure>
</div>
For cpu version only, you may use:
The following graphs show the network of the control policy generated by the NEAT algorithm:
<div style="display: flex; justify-content: space-around; align-items: center;">
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/halfcheetah_network.svg" alt="halfcheetah_network">
<figcaption>halfcheetah</figcaption>
</figure>
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/hopper_network.svg" alt="hopper_network">
<figcaption>hopper</figcaption>
</figure>
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/walker2d_network.svg" alt="walker2d_network">
<figcaption>walker2d</figcaption>
</figure>
</div>
You can use these codes for running RL task (Brax Hopper) in TensorNEAT:
```python
# Import necessary modules
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.rl import BraxEnv
from tensorneat.common import ACT, AGG
# define the pipeline
pipeline = Pipeline(
algorithm=NEAT(
pop_size=1000,
species_size=20,
survival_threshold=0.1,
compatibility_threshold=1.0,
genome=DefaultGenome(
num_inputs=11,
num_outputs=3,
init_hidden_layers=(),
node_gene=BiasNode(
activation_options=ACT.tanh,
aggregation_options=AGG.sum,
),
output_transform=ACT.tanh,
),
),
problem=BraxEnv(
env_name="hopper",
max_step=1000,
),
seed=42,
generation_limit=100,
fitness_target=5000,
)
# initialize state
state = pipeline.setup()
# run until terminate
state, best = pipeline.auto_run(state)
```
pip install -U jax
More example about RL tasks in TensorNEAT are shown in `./examples/brax` and `./examples/gymnax`.
## Solving Function Fitting Tasks (Symbolic Regression)
You can define your custom function and use the NEAT algorithm to solve the function fitting task.
1. Import necessary modules.
```python
import jax, jax.numpy as jnp
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.func_fit import CustomFuncFit
from tensorneat.common import ACT, AGG
```
For nvidia gpus, you may use:
```
pip install -U "jax[cuda12]"
```
For details of installing jax, please check https://github.com/google/jax.
2. Define a custom function to be fit, and then create the function fitting problem.
```python
def pagie_polynomial(inputs):
x, y = inputs
res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))
# important! returns an array with one item, NOT a scalar
return jnp.array([res])
2. Install `tensorneat` from the GitHub source code:
custom_problem = CustomFuncFit(
func=pagie_polynomial,
low_bounds=[-1, -1],
upper_bounds=[1, 1],
method="sample",
num_samples=100,
)
```
pip install git+https://github.com/EMI-Group/tensorneat.git
3. Define custom activation function for the NEAT algorithm:
```python
def square(x):
return x ** 2
ACT.add_func("square", square)
```
4. Define the NEAT algorithm:
```python
algorithm=NEAT(
pop_size=10000,
species_size=20,
survival_threshold=0.01,
genome=DefaultGenome(
num_inputs=2,
num_outputs=1,
init_hidden_layers=(),
node_gene=BiasNode(
# using (identity, inversion, squre)
# as possible activation funcion
activation_options=[ACT.identity, ACT.inv, ACT.square],
# using (sum, product) as possible aggregation funcion
aggregation_options=[AGG.sum, AGG.product],
),
output_transform=ACT.identity,
),
)
```
5. Define the Pipeline and then run it!
```python
pipeline = Pipeline(
algorithm=algorithm,
problem=custom_problem,
generation_limit=50,
fitness_target=-1e-4,
seed=42,
)
# initialize state
state = pipeline.setup()
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)
```
More example about function fitting tasks in TensorNEAT are shown in `./examples/func_fit`.
## Basic API Usage
Start your journey with TensorNEAT in a few simple steps:
@@ -189,6 +329,26 @@ Using this code, you can run the NEAT algorithm within EvoX and leverage EvoX's
For a complete example, see `./example/with_evox/walker2d_evox.py`, which demonstrates EvoX's multi-device functionality.
## Installation
1. Install the correct version of [JAX](https://github.com/google/jax). We recommend `jax >= 0.4.28`.
For cpu version only, you may use:
```
pip install -U jax
```
For nvidia gpus, you may use:
```
pip install -U "jax[cuda12]"
```
For details of installing jax, please check https://github.com/google/jax.
2. Install `tensorneat` from the GitHub source code:
```
pip install git+https://github.com/EMI-Group/tensorneat.git
```
## Community & Support