diff --git a/README.md b/README.md index 9b8d0ea..2a0f4a8 100644 --- a/README.md +++ b/README.md @@ -36,27 +36,167 @@ TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (N - Compatible with **EvoX** for multi-device and distributed support. - Test neuroevolution algorithms on advanced **RL tasks** (Brax, Gymnax). -## Installation +## Examples +### Solving RL Tasks +Using the NEAT algorithm to solve RL tasks. Here are some results: -1. Install the correct version of [JAX](https://github.com/google/jax). We recommend `jax >= 0.4.28`. +The following animations show the behaviors in Brax environments: +
+
+ halfcheetah_animation +
halfcheetah
+
+
+ hopper +
hopper
+
+
+ walker2d +
walker2d
+
+
-For cpu version only, you may use: +The following graphs show the network of the control policy generated by the NEAT algorithm: +
+
+ halfcheetah_network +
halfcheetah
+
+
+ hopper_network +
hopper
+
+
+ walker2d_network +
walker2d
+
+
+ +You can use these codes for running RL task (Brax Hopper) in TensorNEAT: +```python +# Import necessary modules +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, BiasNode + +from tensorneat.problem.rl import BraxEnv +from tensorneat.common import ACT, AGG + +# define the pipeline +pipeline = Pipeline( + algorithm=NEAT( + pop_size=1000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=1.0, + genome=DefaultGenome( + num_inputs=11, + num_outputs=3, + init_hidden_layers=(), + node_gene=BiasNode( + activation_options=ACT.tanh, + aggregation_options=AGG.sum, + ), + output_transform=ACT.tanh, + ), + ), + problem=BraxEnv( + env_name="hopper", + max_step=1000, + ), + seed=42, + generation_limit=100, + fitness_target=5000, + ) + +# initialize state +state = pipeline.setup() + +# run until terminate +state, best = pipeline.auto_run(state) ``` -pip install -U jax +More example about RL tasks in TensorNEAT are shown in `./examples/brax` and `./examples/gymnax`. + +## Solving Function Fitting Tasks (Symbolic Regression) +You can define your custom function and use the NEAT algorithm to solve the function fitting task. + +1. Import necessary modules. +```python +import jax, jax.numpy as jnp + +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, BiasNode +from tensorneat.problem.func_fit import CustomFuncFit +from tensorneat.common import ACT, AGG ``` -For nvidia gpus, you may use: -``` -pip install -U "jax[cuda12]" -``` -For details of installing jax, please check https://github.com/google/jax. +2. Define a custom function to be fit, and then create the function fitting problem. +```python +def pagie_polynomial(inputs): + x, y = inputs + res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4)) + # important! returns an array with one item, NOT a scalar + return jnp.array([res]) -2. Install `tensorneat` from the GitHub source code: +custom_problem = CustomFuncFit( + func=pagie_polynomial, + low_bounds=[-1, -1], + upper_bounds=[1, 1], + method="sample", + num_samples=100, +) ``` -pip install git+https://github.com/EMI-Group/tensorneat.git + +3. Define custom activation function for the NEAT algorithm: +```python +def square(x): + return x ** 2 +ACT.add_func("square", square) ``` +4. Define the NEAT algorithm: +```python +algorithm=NEAT( + pop_size=10000, + species_size=20, + survival_threshold=0.01, + genome=DefaultGenome( + num_inputs=2, + num_outputs=1, + init_hidden_layers=(), + node_gene=BiasNode( + # using (identity, inversion, squre) + # as possible activation funcion + activation_options=[ACT.identity, ACT.inv, ACT.square], + # using (sum, product) as possible aggregation funcion + aggregation_options=[AGG.sum, AGG.product], + ), + output_transform=ACT.identity, + ), +) +``` + +5. Define the Pipeline and then run it! +```python +pipeline = Pipeline( + algorithm=algorithm, + problem=custom_problem, + generation_limit=50, + fitness_target=-1e-4, + seed=42, +) + +# initialize state +state = pipeline.setup() +# run until terminate +state, best = pipeline.auto_run(state) +# show result +pipeline.show(state, best) +``` +More example about function fitting tasks in TensorNEAT are shown in `./examples/func_fit`. + ## Basic API Usage Start your journey with TensorNEAT in a few simple steps: @@ -189,6 +329,26 @@ Using this code, you can run the NEAT algorithm within EvoX and leverage EvoX's For a complete example, see `./example/with_evox/walker2d_evox.py`, which demonstrates EvoX's multi-device functionality. +## Installation + +1. Install the correct version of [JAX](https://github.com/google/jax). We recommend `jax >= 0.4.28`. + +For cpu version only, you may use: +``` +pip install -U jax +``` + +For nvidia gpus, you may use: +``` +pip install -U "jax[cuda12]" +``` +For details of installing jax, please check https://github.com/google/jax. + + +2. Install `tensorneat` from the GitHub source code: +``` +pip install git+https://github.com/EMI-Group/tensorneat.git +``` ## Community & Support diff --git a/examples/brax/hopper.py b/examples/brax/hopper.py new file mode 100644 index 0000000..3a4fe1e --- /dev/null +++ b/examples/brax/hopper.py @@ -0,0 +1,39 @@ +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, BiasNode + +from tensorneat.problem.rl import BraxEnv +from tensorneat.common import ACT, AGG + +if __name__ == "__main__": + pipeline = Pipeline( + algorithm=NEAT( + pop_size=1000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=1.0, + genome=DefaultGenome( + num_inputs=11, + num_outputs=3, + init_hidden_layers=(), + node_gene=BiasNode( + activation_options=ACT.tanh, + aggregation_options=AGG.sum, + ), + output_transform=ACT.tanh, + ), + ), + problem=BraxEnv( + env_name="hopper", + max_step=1000, + ), + seed=42, + generation_limit=100, + fitness_target=5000, + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) diff --git a/examples/func_fit/custom_func_fit.py b/examples/func_fit/custom_func_fit.py index 9bbea1d..7b6b437 100644 --- a/examples/func_fit/custom_func_fit.py +++ b/examples/func_fit/custom_func_fit.py @@ -2,7 +2,7 @@ import jax.numpy as jnp from tensorneat.pipeline import Pipeline from tensorneat.algorithm.neat import NEAT -from tensorneat.genome import DefaultGenome, DefaultNode, DefaultMutation, BiasNode +from tensorneat.genome import DefaultGenome, BiasNode from tensorneat.problem.func_fit import CustomFuncFit from tensorneat.common import ACT, AGG @@ -55,5 +55,4 @@ if __name__ == "__main__": # run until terminate state, best = pipeline.auto_run(state) # show result - # pipeline.show(state, best) - print(pipeline.algorithm.genome.repr(state, *best)) + pipeline.show(state, best) diff --git a/imgs/halfcheetah_animation_200.gif b/imgs/halfcheetah_animation_200.gif new file mode 100644 index 0000000..152f8e7 Binary files /dev/null and b/imgs/halfcheetah_animation_200.gif differ diff --git a/imgs/halfcheetah_network.svg b/imgs/halfcheetah_network.svg new file mode 100644 index 0000000..41d55df --- /dev/null +++ b/imgs/halfcheetah_network.svg @@ -0,0 +1,623 @@ + + + + + + + + 2024-07-15T20:16:30.367839 + image/svg+xml + + + Matplotlib v3.9.0, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/imgs/hopper_animation_200.gif b/imgs/hopper_animation_200.gif new file mode 100644 index 0000000..622df83 Binary files /dev/null and b/imgs/hopper_animation_200.gif differ diff --git a/imgs/hopper_network.svg b/imgs/hopper_network.svg new file mode 100644 index 0000000..4c33874 --- /dev/null +++ b/imgs/hopper_network.svg @@ -0,0 +1,439 @@ + + + + + + + + 2024-07-15T19:39:34.749207 + image/svg+xml + + + Matplotlib v3.9.0, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/imgs/walker2d_animation_200.gif b/imgs/walker2d_animation_200.gif new file mode 100644 index 0000000..a804e0b Binary files /dev/null and b/imgs/walker2d_animation_200.gif differ diff --git a/imgs/walker2d_network.svg b/imgs/walker2d_network.svg new file mode 100644 index 0000000..a6c4929 --- /dev/null +++ b/imgs/walker2d_network.svg @@ -0,0 +1,808 @@ + + + + + + + + 2024-07-15T20:07:34.485627 + image/svg+xml + + + Matplotlib v3.9.0, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +