diff --git a/README.md b/README.md
index 9b8d0ea..2a0f4a8 100644
--- a/README.md
+++ b/README.md
@@ -36,27 +36,167 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (N
- Compatible with **EvoX** for multi-device and distributed support.
- Test neuroevolution algorithms on advanced **RL tasks** (Brax, Gymnax).
-## Installation
+## Examples
+### Solving RL Tasks
+Using the NEAT algorithm to solve RL tasks. Here are some results:
-1. Install the correct version of [JAX](https://github.com/google/jax). We recommend `jax >= 0.4.28`.
+The following animations show the behaviors in Brax environments:
+
+
+
+ halfcheetah
+
+
+
+ hopper
+
+
+
+ walker2d
+
+
-For cpu version only, you may use:
+The following graphs show the network of the control policy generated by the NEAT algorithm:
+
+
+
+ halfcheetah
+
+
+
+ hopper
+
+
+
+ walker2d
+
+
+
+You can use these codes for running RL task (Brax Hopper) in TensorNEAT:
+```python
+# Import necessary modules
+from tensorneat.pipeline import Pipeline
+from tensorneat.algorithm.neat import NEAT
+from tensorneat.genome import DefaultGenome, BiasNode
+
+from tensorneat.problem.rl import BraxEnv
+from tensorneat.common import ACT, AGG
+
+# define the pipeline
+pipeline = Pipeline(
+ algorithm=NEAT(
+ pop_size=1000,
+ species_size=20,
+ survival_threshold=0.1,
+ compatibility_threshold=1.0,
+ genome=DefaultGenome(
+ num_inputs=11,
+ num_outputs=3,
+ init_hidden_layers=(),
+ node_gene=BiasNode(
+ activation_options=ACT.tanh,
+ aggregation_options=AGG.sum,
+ ),
+ output_transform=ACT.tanh,
+ ),
+ ),
+ problem=BraxEnv(
+ env_name="hopper",
+ max_step=1000,
+ ),
+ seed=42,
+ generation_limit=100,
+ fitness_target=5000,
+ )
+
+# initialize state
+state = pipeline.setup()
+
+# run until terminate
+state, best = pipeline.auto_run(state)
```
-pip install -U jax
+More examples of RL tasks in TensorNEAT are shown in `./examples/brax` and `./examples/gymnax`.
+
+### Solving Function Fitting Tasks (Symbolic Regression)
+You can define your custom function and use the NEAT algorithm to solve the function fitting task.
+
+1. Import necessary modules.
+```python
+import jax, jax.numpy as jnp
+
+from tensorneat.pipeline import Pipeline
+from tensorneat.algorithm.neat import NEAT
+from tensorneat.genome import DefaultGenome, BiasNode
+from tensorneat.problem.func_fit import CustomFuncFit
+from tensorneat.common import ACT, AGG
```
-For nvidia gpus, you may use:
-```
-pip install -U "jax[cuda12]"
-```
-For details of installing jax, please check https://github.com/google/jax.
+2. Define a custom function to be fit, and then create the function fitting problem.
+```python
+def pagie_polynomial(inputs):
+ x, y = inputs
+ res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))
+ # important! returns an array with one item, NOT a scalar
+ return jnp.array([res])
-2. Install `tensorneat` from the GitHub source code:
+custom_problem = CustomFuncFit(
+ func=pagie_polynomial,
+ low_bounds=[-1, -1],
+ upper_bounds=[1, 1],
+ method="sample",
+ num_samples=100,
+)
```
-pip install git+https://github.com/EMI-Group/tensorneat.git
+
+3. Define custom activation function for the NEAT algorithm:
+```python
+def square(x):
+ return x ** 2
+ACT.add_func("square", square)
```
+4. Define the NEAT algorithm:
+```python
+algorithm=NEAT(
+ pop_size=10000,
+ species_size=20,
+ survival_threshold=0.01,
+ genome=DefaultGenome(
+ num_inputs=2,
+ num_outputs=1,
+ init_hidden_layers=(),
+ node_gene=BiasNode(
+                # using (identity, inversion, square)
+                # as possible activation function
+ activation_options=[ACT.identity, ACT.inv, ACT.square],
+                # using (sum, product) as possible aggregation function
+ aggregation_options=[AGG.sum, AGG.product],
+ ),
+ output_transform=ACT.identity,
+ ),
+)
+```
+
+5. Define the Pipeline and then run it!
+```python
+pipeline = Pipeline(
+ algorithm=algorithm,
+ problem=custom_problem,
+ generation_limit=50,
+ fitness_target=-1e-4,
+ seed=42,
+)
+
+# initialize state
+state = pipeline.setup()
+# run until terminate
+state, best = pipeline.auto_run(state)
+# show result
+pipeline.show(state, best)
+```
+More examples of function fitting tasks in TensorNEAT are shown in `./examples/func_fit`.
+
## Basic API Usage
Start your journey with TensorNEAT in a few simple steps:
@@ -189,6 +329,26 @@ Using this code, you can run the NEAT algorithm within EvoX and leverage EvoX's
For a complete example, see `./example/with_evox/walker2d_evox.py`, which demonstrates EvoX's multi-device functionality.
+## Installation
+
+1. Install the correct version of [JAX](https://github.com/google/jax). We recommend `jax >= 0.4.28`.
+
+For the CPU-only version, you may use:
+```
+pip install -U jax
+```
+
+For NVIDIA GPUs, you may use:
+```
+pip install -U "jax[cuda12]"
+```
+For details of installing jax, please check https://github.com/google/jax.
+
+
+2. Install `tensorneat` from the GitHub source code:
+```
+pip install git+https://github.com/EMI-Group/tensorneat.git
+```
## Community & Support
diff --git a/examples/brax/hopper.py b/examples/brax/hopper.py
new file mode 100644
index 0000000..3a4fe1e
--- /dev/null
+++ b/examples/brax/hopper.py
@@ -0,0 +1,39 @@
+from tensorneat.pipeline import Pipeline
+from tensorneat.algorithm.neat import NEAT
+from tensorneat.genome import DefaultGenome, BiasNode
+
+from tensorneat.problem.rl import BraxEnv
+from tensorneat.common import ACT, AGG
+
+if __name__ == "__main__":
+ pipeline = Pipeline(
+ algorithm=NEAT(
+ pop_size=1000,
+ species_size=20,
+ survival_threshold=0.1,
+ compatibility_threshold=1.0,
+ genome=DefaultGenome(
+ num_inputs=11,
+ num_outputs=3,
+ init_hidden_layers=(),
+ node_gene=BiasNode(
+ activation_options=ACT.tanh,
+ aggregation_options=AGG.sum,
+ ),
+ output_transform=ACT.tanh,
+ ),
+ ),
+ problem=BraxEnv(
+ env_name="hopper",
+ max_step=1000,
+ ),
+ seed=42,
+ generation_limit=100,
+ fitness_target=5000,
+ )
+
+ # initialize state
+ state = pipeline.setup()
+ # print(state)
+ # run until terminate
+ state, best = pipeline.auto_run(state)
diff --git a/examples/func_fit/custom_func_fit.py b/examples/func_fit/custom_func_fit.py
index 9bbea1d..7b6b437 100644
--- a/examples/func_fit/custom_func_fit.py
+++ b/examples/func_fit/custom_func_fit.py
@@ -2,7 +2,7 @@ import jax.numpy as jnp
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
-from tensorneat.genome import DefaultGenome, DefaultNode, DefaultMutation, BiasNode
+from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.func_fit import CustomFuncFit
from tensorneat.common import ACT, AGG
@@ -55,5 +55,4 @@ if __name__ == "__main__":
# run until terminate
state, best = pipeline.auto_run(state)
# show result
- # pipeline.show(state, best)
- print(pipeline.algorithm.genome.repr(state, *best))
+ pipeline.show(state, best)
diff --git a/imgs/halfcheetah_animation_200.gif b/imgs/halfcheetah_animation_200.gif
new file mode 100644
index 0000000..152f8e7
Binary files /dev/null and b/imgs/halfcheetah_animation_200.gif differ
diff --git a/imgs/halfcheetah_network.svg b/imgs/halfcheetah_network.svg
new file mode 100644
index 0000000..41d55df
--- /dev/null
+++ b/imgs/halfcheetah_network.svg
@@ -0,0 +1,623 @@
+
+
+
+
+
+
+
+ 2024-07-15T20:16:30.367839
+ image/svg+xml
+
+
+ Matplotlib v3.9.0, https://matplotlib.org/
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/imgs/hopper_animation_200.gif b/imgs/hopper_animation_200.gif
new file mode 100644
index 0000000..622df83
Binary files /dev/null and b/imgs/hopper_animation_200.gif differ
diff --git a/imgs/hopper_network.svg b/imgs/hopper_network.svg
new file mode 100644
index 0000000..4c33874
--- /dev/null
+++ b/imgs/hopper_network.svg
@@ -0,0 +1,439 @@
+
+
+
+
+
+
+
+ 2024-07-15T19:39:34.749207
+ image/svg+xml
+
+
+ Matplotlib v3.9.0, https://matplotlib.org/
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/imgs/walker2d_animation_200.gif b/imgs/walker2d_animation_200.gif
new file mode 100644
index 0000000..a804e0b
Binary files /dev/null and b/imgs/walker2d_animation_200.gif differ
diff --git a/imgs/walker2d_network.svg b/imgs/walker2d_network.svg
new file mode 100644
index 0000000..a6c4929
--- /dev/null
+++ b/imgs/walker2d_network.svg
@@ -0,0 +1,808 @@
+
+
+
+
+
+
+
+ 2024-07-15T20:07:34.485627
+ image/svg+xml
+
+
+ Matplotlib v3.9.0, https://matplotlib.org/
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+