Update README.md
This commit is contained in:
72
README.md
72
README.md
@@ -41,48 +41,27 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (N
|
|||||||
Using the NEAT algorithm to solve RL tasks. Here are some results:
|
Using the NEAT algorithm to solve RL tasks. Here are some results:
|
||||||
|
|
||||||
The following animations show the behaviors in Brax environments:
|
The following animations show the behaviors in Brax environments:
|
||||||
<div style="display: flex; justify-content: space-around; align-items: center;">
|
|
||||||
<figure style="text-align: center; margin: 10px;">
|
| <img src="./imgs/halfcheetah_animation_200.gif" alt="halfcheetah" width="200"> | <img src="./imgs/hopper_animation_200.gif" alt="hopper" width="200"> | <img src="./imgs/walker2d_animation_200.gif" alt="walker2d" width="200"> |
|
||||||
<img src="./imgs/halfcheetah_animation_200.gif" alt="halfcheetah_animation" width="200">
|
|:-----------------------------------------------------------------------------:|:--------------------------------------------------------------------:|:------------------------------------------------------------------------:|
|
||||||
<figcaption>halfcheetah</figcaption>
|
| halfcheetah | hopper | walker2d |
|
||||||
</figure>
|
|
||||||
<figure style="text-align: center; margin: 10px;">
|
|
||||||
<img src="./imgs/hopper_animation_200.gif" alt="hopper" width="200">
|
|
||||||
<figcaption>hopper</figcaption>
|
|
||||||
</figure>
|
|
||||||
<figure style="text-align: center; margin: 10px;">
|
|
||||||
<img src="./imgs/walker2d_animation_200.gif" alt="walker2d" width="200">
|
|
||||||
<figcaption>walker2d</figcaption>
|
|
||||||
</figure>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
The following graphs show the network of the control policy generated by the NEAT algorithm:
|
The following graphs show the network of the control policy generated by the NEAT algorithm:
|
||||||
<div style="display: flex; justify-content: space-around; align-items: center;">
|
|
||||||
<figure style="text-align: center; margin: 10px;">
|
|
||||||
<img src="./imgs/halfcheetah_network.svg" alt="halfcheetah_network" width="200">
|
|
||||||
<figcaption>halfcheetah</figcaption>
|
|
||||||
</figure>
|
|
||||||
<figure style="text-align: center; margin: 10px;">
|
|
||||||
<img src="./imgs/hopper_network.svg" alt="hopper_network" width="200">
|
|
||||||
<figcaption>hopper</figcaption>
|
|
||||||
</figure>
|
|
||||||
<figure style="text-align: center; margin: 10px;">
|
|
||||||
<img src="./imgs/walker2d_network.svg" alt="walker2d_network" width="200">
|
|
||||||
<figcaption>walker2d</figcaption>
|
|
||||||
</figure>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
You can use these codes for running RL task (Brax Hopper) in TensorNEAT:
|
| <img src="./imgs/halfcheetah_network.svg" alt="halfcheetah_network" width="200"> | <img src="./imgs/hopper_network.svg" alt="hopper_network" width="200"> | <img src="./imgs/walker2d_network.svg" alt="walker2d_network" width="200"> |
|
||||||
|
|:-----------------------------------------------------------------------------:|:--------------------------------------------------------------------:|:------------------------------------------------------------------------:|
|
||||||
|
| halfcheetah | hopper | walker2d |
|
||||||
|
|
||||||
|
You can use these codes for running an RL task (Brax Hopper) in TensorNEAT:
|
||||||
```python
|
```python
|
||||||
# Import necessary modules
|
# Import necessary modules
|
||||||
from tensorneat.pipeline import Pipeline
|
from tensorneat.pipeline import Pipeline
|
||||||
from tensorneat.algorithm.neat import NEAT
|
from tensorneat.algorithm.neat import NEAT
|
||||||
from tensorneat.genome import DefaultGenome, BiasNode
|
from tensorneat.genome import DefaultGenome, BiasNode
|
||||||
|
|
||||||
from tensorneat.problem.rl import BraxEnv
|
from tensorneat.problem.rl import BraxEnv
|
||||||
from tensorneat.common import ACT, AGG
|
from tensorneat.common import ACT, AGG
|
||||||
|
|
||||||
# define the pipeline
|
# Define the pipeline
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
algorithm=NEAT(
|
algorithm=NEAT(
|
||||||
pop_size=1000,
|
pop_size=1000,
|
||||||
@@ -109,21 +88,20 @@ pipeline = Pipeline(
|
|||||||
fitness_target=5000,
|
fitness_target=5000,
|
||||||
)
|
)
|
||||||
|
|
||||||
# initialize state
|
# Initialize state
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
|
|
||||||
# run until terminate
|
# Run until termination
|
||||||
state, best = pipeline.auto_run(state)
|
state, best = pipeline.auto_run(state)
|
||||||
```
|
```
|
||||||
More example about RL tasks in TensorNEAT are shown in `./examples/brax` and `./examples/gymnax`.
|
More examples of RL tasks in TensorNEAT can be found in `./examples/brax` and `./examples/gymnax`.
|
||||||
|
|
||||||
## Solving Function Fitting Tasks (Symbolic Regression)
|
## Solving Function Fitting Tasks (Symbolic Regression)
|
||||||
You can define your custom function and use the NEAT algorithm to solve the function fitting task.
|
You can define your custom function and use the NEAT algorithm to solve the function fitting task.
|
||||||
|
|
||||||
1. Import necessary modules.
|
1. Import necessary modules:
|
||||||
```python
|
```python
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
|
|
||||||
from tensorneat.pipeline import Pipeline
|
from tensorneat.pipeline import Pipeline
|
||||||
from tensorneat.algorithm.neat import NEAT
|
from tensorneat.algorithm.neat import NEAT
|
||||||
from tensorneat.genome import DefaultGenome, BiasNode
|
from tensorneat.genome import DefaultGenome, BiasNode
|
||||||
@@ -131,13 +109,13 @@ from tensorneat.problem.func_fit import CustomFuncFit
|
|||||||
from tensorneat.common import ACT, AGG
|
from tensorneat.common import ACT, AGG
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Define a custom function to be fit, and then create the function fitting problem.
|
2. Define a custom function to be fit, and then create the function fitting problem:
|
||||||
```python
|
```python
|
||||||
def pagie_polynomial(inputs):
|
def pagie_polynomial(inputs):
|
||||||
x, y = inputs
|
x, y = inputs
|
||||||
res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))
|
res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))
|
||||||
|
|
||||||
# important! returns an array with one item, NOT a scalar
|
# Important! Returns an array with one item, NOT a scalar
|
||||||
return jnp.array([res])
|
return jnp.array([res])
|
||||||
|
|
||||||
custom_problem = CustomFuncFit(
|
custom_problem = CustomFuncFit(
|
||||||
@@ -158,7 +136,7 @@ ACT.add_func("square", square)
|
|||||||
|
|
||||||
4. Define the NEAT algorithm:
|
4. Define the NEAT algorithm:
|
||||||
```python
|
```python
|
||||||
algorithm=NEAT(
|
algorithm = NEAT(
|
||||||
pop_size=10000,
|
pop_size=10000,
|
||||||
species_size=20,
|
species_size=20,
|
||||||
survival_threshold=0.01,
|
survival_threshold=0.01,
|
||||||
@@ -167,10 +145,10 @@ algorithm=NEAT(
|
|||||||
num_outputs=1,
|
num_outputs=1,
|
||||||
init_hidden_layers=(),
|
init_hidden_layers=(),
|
||||||
node_gene=BiasNode(
|
node_gene=BiasNode(
|
||||||
# using (identity, inversion, squre)
|
# Using (identity, inversion, square)
|
||||||
# as possible activation funcion
|
# as possible activation functions
|
||||||
activation_options=[ACT.identity, ACT.inv, ACT.square],
|
activation_options=[ACT.identity, ACT.inv, ACT.square],
|
||||||
# using (sum, product) as possible aggregation funcion
|
# Using (sum, product) as possible aggregation functions
|
||||||
aggregation_options=[AGG.sum, AGG.product],
|
aggregation_options=[AGG.sum, AGG.product],
|
||||||
),
|
),
|
||||||
output_transform=ACT.identity,
|
output_transform=ACT.identity,
|
||||||
@@ -178,7 +156,7 @@ algorithm=NEAT(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
5. Define the Pipeline and then run it!
|
5. Define the Pipeline and then run it:
|
||||||
```python
|
```python
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
algorithm=algorithm,
|
algorithm=algorithm,
|
||||||
@@ -188,14 +166,14 @@ pipeline = Pipeline(
|
|||||||
seed=42,
|
seed=42,
|
||||||
)
|
)
|
||||||
|
|
||||||
# initialize state
|
# Initialize state
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
# run until terminate
|
# Run until termination
|
||||||
state, best = pipeline.auto_run(state)
|
state, best = pipeline.auto_run(state)
|
||||||
# show result
|
# Show result
|
||||||
pipeline.show(state, best)
|
pipeline.show(state, best)
|
||||||
```
|
```
|
||||||
More example about function fitting tasks in TensorNEAT are shown in `./examples/func_fit`.
|
More examples of function fitting tasks in TensorNEAT can be found in `./examples/func_fit`.
|
||||||
|
|
||||||
## Basic API Usage
|
## Basic API Usage
|
||||||
Start your journey with TensorNEAT in a few simple steps:
|
Start your journey with TensorNEAT in a few simple steps:
|
||||||
|
|||||||
Reference in New Issue
Block a user