Update README.md

This commit is contained in:
WLS2002
2024-07-15 21:25:31 +08:00
committed by GitHub
parent a2dd842431
commit 6491d329af

View File

@@ -41,48 +41,27 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (N
Using the NEAT algorithm to solve RL tasks. Here are some results:
The following animations show the behaviors in Brax environments:
<div style="display: flex; justify-content: space-around; align-items: center;">
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/halfcheetah_animation_200.gif" alt="halfcheetah_animation" width="200">
<figcaption>halfcheetah</figcaption>
</figure>
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/hopper_animation_200.gif" alt="hopper" width="200">
<figcaption>hopper</figcaption>
</figure>
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/walker2d_animation_200.gif" alt="walker2d" width="200">
<figcaption>walker2d</figcaption>
</figure>
</div>
| <img src="./imgs/halfcheetah_animation_200.gif" alt="halfcheetah" width="200"> | <img src="./imgs/hopper_animation_200.gif" alt="hopper" width="200"> | <img src="./imgs/walker2d_animation_200.gif" alt="walker2d" width="200"> |
|:-----------------------------------------------------------------------------:|:--------------------------------------------------------------------:|:------------------------------------------------------------------------:|
| halfcheetah | hopper | walker2d |
The following graphs show the network of the control policy generated by the NEAT algorithm:
<div style="display: flex; justify-content: space-around; align-items: center;">
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/halfcheetah_network.svg" alt="halfcheetah_network" width="200">
<figcaption>halfcheetah</figcaption>
</figure>
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/hopper_network.svg" alt="hopper_network" width="200">
<figcaption>hopper</figcaption>
</figure>
<figure style="text-align: center; margin: 10px;">
<img src="./imgs/walker2d_network.svg" alt="walker2d_network" width="200">
<figcaption>walker2d</figcaption>
</figure>
</div>
You can use these codes for running RL task (Brax Hopper) in TensorNEAT:
| <img src="./imgs/halfcheetah_network.svg" alt="halfcheetah_network" width="200"> | <img src="./imgs/hopper_network.svg" alt="hopper_network" width="200"> | <img src="./imgs/walker2d_network.svg" alt="walker2d_network" width="200"> |
|:-----------------------------------------------------------------------------:|:--------------------------------------------------------------------:|:------------------------------------------------------------------------:|
| halfcheetah | hopper | walker2d |
You can use these codes for running an RL task (Brax Hopper) in TensorNEAT:
```python
# Import necessary modules
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.rl import BraxEnv
from tensorneat.common import ACT, AGG
# define the pipeline
# Define the pipeline
pipeline = Pipeline(
algorithm=NEAT(
pop_size=1000,
@@ -109,21 +88,20 @@ pipeline = Pipeline(
fitness_target=5000,
)
# initialize state
# Initialize state
state = pipeline.setup()
# run until terminate
# Run until termination
state, best = pipeline.auto_run(state)
```
More example about RL tasks in TensorNEAT are shown in `./examples/brax` and `./examples/gymnax`.
More examples of RL tasks in TensorNEAT can be found in `./examples/brax` and `./examples/gymnax`.
## Solving Function Fitting Tasks (Symbolic Regression)
You can define your custom function and use the NEAT algorithm to solve the function fitting task.
1. Import necessary modules.
1. Import necessary modules:
```python
import jax, jax.numpy as jnp
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
@@ -131,13 +109,13 @@ from tensorneat.problem.func_fit import CustomFuncFit
from tensorneat.common import ACT, AGG
```
2. Define a custom function to be fit, and then create the function fitting problem.
2. Define a custom function to be fit, and then create the function fitting problem:
```python
def pagie_polynomial(inputs):
x, y = inputs
res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))
# important! returns an array with one item, NOT a scalar
# Important! Returns an array with one item, NOT a scalar
return jnp.array([res])
custom_problem = CustomFuncFit(
@@ -167,10 +145,10 @@ algorithm=NEAT(
num_outputs=1,
init_hidden_layers=(),
node_gene=BiasNode(
# using (identity, inversion, squre)
# as possible activation funcion
# Using (identity, inversion, square)
# as possible activation functions
activation_options=[ACT.identity, ACT.inv, ACT.square],
# using (sum, product) as possible aggregation funcion
# Using (sum, product) as possible aggregation functions
aggregation_options=[AGG.sum, AGG.product],
),
output_transform=ACT.identity,
@@ -178,7 +156,7 @@ algorithm=NEAT(
)
```
5. Define the Pipeline and then run it!
5. Define the Pipeline and then run it:
```python
pipeline = Pipeline(
algorithm=algorithm,
@@ -188,14 +166,14 @@ pipeline = Pipeline(
seed=42,
)
# initialize state
# Initialize state
state = pipeline.setup()
# run until terminate
# Run until termination
state, best = pipeline.auto_run(state)
# show result
# Show result
pipeline.show(state, best)
```
More example about function fitting tasks in TensorNEAT are shown in `./examples/func_fit`.
More examples of function fitting tasks in TensorNEAT can be found in `./examples/func_fit`.
## Basic API Usage
Start your journey with TensorNEAT in a few simple steps: