change repo structure; modify readme

This commit is contained in:
wls2002
2024-03-26 21:58:27 +08:00
parent 6970e6a6d5
commit 47dbcbea80
69 changed files with 74 additions and 60 deletions

128
README.md
View File

@@ -1,88 +1,106 @@
# NEATax: Tensorized NEAT implementation in JAX
<h1 align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="./imgs/evox_logo_dark.png">
<source media="(prefers-color-scheme: light)" srcset="./imgs/evox_logo_light.png">
<img alt="EvoX Logo" height="50" src="./imgs/evox_logo_light.png">
</picture>
<br>
</h1>
TensorNEAT is a powerful tool that utilizes JAX to implement the NEAT (NeuroEvolution of Augmenting Topologies)
algorithm. It provides support for parallel execution of tasks such as network forward computation, mutation,
and crossover at the population level.
# TensorNEAT: Tensorized NEAT implementation in JAX
<p align="center">
<a href="https://arxiv.org/">
<img src="https://img.shields.io/badge/paper-arxiv-red?style=for-the-badge" alt="TensorRVEA Paper on arXiv">
</a>
</p>
## Introduction
🚀TensorNEAT, a part of EvoX project, aims to enhance the NEAT (NeuroEvolution of Augmenting Topologies) algorithm by incorporating GPU acceleration. Utilizing JAX for parallel computations, it extends NEAT's capabilities to modern computational environments, making advanced neuroevolution accessible and fast.
## Requirements
* available [JAX](https://github.com/google/jax#installation) environment;
* [gymnax](https://github.com/RobertTLange/gymnax) (optional).
TensorNEAT requires:
- jax (version >= 0.4.16)
- jaxlib (version >= 0.3.0)
- brax [optional]
- gymnax [optional]
## Example
Simple Example for XOR problem:
```python
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig
from algorithm.neat import *
from problem.func_fit import XOR3d
if __name__ == '__main__':
# running config
config = Config(
basic=BasicConfig(
seed=42,
fitness_target=-1e-2,
pop_size=10000
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
),
),
neat=NeatConfig(
inputs=2,
outputs=1
),
gene=NormalGeneConfig(),
problem=FuncFitConfig(
error_method='rmse'
)
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8
)
# define algorithm: NEAT with NormalGene
algorithm = NEAT(config, NormalGene)
# full pipeline
pipeline = Pipeline(config, algorithm, XOR)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)
```
Simple Example for RL envs in gymnax(CartPole-v0):
Simple Example for RL envs in Brax (Ant):
```python
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
conf = Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
pop_size=10000
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=27,
num_outputs=8,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=1000,
species_size=10,
),
),
neat=NeatConfig(
inputs=4,
outputs=1,
problem=BraxEnv(
env_name='ant',
),
gene=NormalGeneConfig(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1}
)
generation_limit=10000,
fitness_target=5000
)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
```
`/examples` folder contains more examples.
more examples are in `tensorneat/examples`.
## TO BE COMPLETE...

BIN
imgs/evox_logo_dark.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 95 KiB

BIN
imgs/evox_logo_light.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

4
t.py
View File

@@ -1,4 +0,0 @@
import jax.numpy as jnp
a = jnp.zeros((0, 9, 9))
print(a)

View File

Before

Width:  |  Height:  |  Size: 25 MiB

After

Width:  |  Height:  |  Size: 25 MiB