delete useless;
append readme
This commit is contained in:
89
README.md
89
README.md
@@ -1,3 +1,88 @@
|
||||
# TensorNEAT: Tensorized NeuroEvolution of Augmenting Topologies for GPU Acceleration
|
||||
# TensorNEAT: Tensorized NEAT implementation in JAX
|
||||
|
||||
TensorNEAT is a powerful tool that utilizes JAX to implement the NEAT (NeuroEvolution of Augmenting Topologies) algorithm. It provides support for parallel execution of tasks such as forward network computation, mutation, and crossover at the population level.
|
||||
TensorNEAT is a powerful tool that utilizes JAX to implement the NEAT (NeuroEvolution of Augmenting Topologies)
|
||||
algorithm. It provides support for parallel execution of tasks such as network forward computation, mutation,
|
||||
and crossover at the population level.
|
||||
|
||||
## Requirements
|
||||
* available [JAX](https://github.com/google/jax#installation) environment;
|
||||
* [gymnax](https://github.com/RobertTLange/gymnax) (optional).
|
||||
|
||||
## Example
|
||||
Simple Example for XOR problem:
|
||||
```python
|
||||
from config import *
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.func_fit import XOR, FuncFitConfig
|
||||
|
||||
if __name__ == '__main__':
|
||||
# running config
|
||||
config = Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=-1e-2,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=2,
|
||||
outputs=1
|
||||
),
|
||||
gene=NormalGeneConfig(),
|
||||
problem=FuncFitConfig(
|
||||
error_method='rmse'
|
||||
)
|
||||
)
|
||||
# define algorithm: NEAT with NormalGene
|
||||
algorithm = NEAT(config, NormalGene)
|
||||
# full pipeline
|
||||
pipeline = Pipeline(config, algorithm, XOR)
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show result
|
||||
pipeline.show(state, best)
|
||||
```
|
||||
|
||||
Simple Example for RL envs in gymnax(CartPole-v0):
|
||||
```python
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
conf = Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=500,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=4,
|
||||
outputs=1,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default=Act.sigmoid,
|
||||
activation_options=(Act.sigmoid,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1}
|
||||
)
|
||||
)
|
||||
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||
state = pipeline.setup()
|
||||
state, best = pipeline.auto_run(state)
|
||||
```
|
||||
|
||||
`/examples` folder contains more examples.
|
||||
|
||||
## TO BE COMPLETE...
|
||||
@@ -6,7 +6,7 @@ from .state import State
|
||||
|
||||
class Problem:
|
||||
|
||||
jitable: bool
|
||||
jitable = None
|
||||
|
||||
def __init__(self, problem_config: ProblemConfig = ProblemConfig()):
|
||||
self.config = problem_config
|
||||
|
||||
@@ -5,6 +5,7 @@ from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.func_fit import XOR, FuncFitConfig
|
||||
|
||||
if __name__ == '__main__':
|
||||
# running config
|
||||
config = Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
@@ -12,13 +13,6 @@ if __name__ == '__main__':
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
max_species=30,
|
||||
conn_add=0.8,
|
||||
conn_delete=0,
|
||||
node_add=0.4,
|
||||
node_delete=0,
|
||||
inputs=2,
|
||||
outputs=1
|
||||
),
|
||||
@@ -27,10 +21,13 @@ if __name__ == '__main__':
|
||||
error_method='rmse'
|
||||
)
|
||||
)
|
||||
|
||||
# define algorithm: NEAT with NormalGene
|
||||
algorithm = NEAT(config, NormalGene)
|
||||
# full pipeline
|
||||
pipeline = Pipeline(config, algorithm, XOR)
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show result
|
||||
pipeline.show(state, best)
|
||||
|
||||
41
examples/general_xor.py
Normal file
41
examples/general_xor.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from config import *
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.func_fit import XOR, FuncFitConfig
|
||||
|
||||
def evaluate():
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=-1e-2,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
max_species=30,
|
||||
conn_add=0.8,
|
||||
conn_delete=0,
|
||||
node_add=0.4,
|
||||
node_delete=0,
|
||||
inputs=2,
|
||||
outputs=1
|
||||
),
|
||||
gene=NormalGeneConfig(),
|
||||
problem=FuncFitConfig(
|
||||
error_method='rmse'
|
||||
)
|
||||
)
|
||||
|
||||
algorithm = NEAT(config, NormalGene)
|
||||
pipeline = Pipeline(config, algorithm, XOR)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
pipeline.show(state, best)
|
||||
@@ -29,8 +29,8 @@ class XOR(FuncFit):
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 2)
|
||||
return 4, 2
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1)
|
||||
return 4, 1
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import gym
|
||||
|
||||
from core import State
|
||||
from .rl_unjit import RLEnv, RLEnvConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GymConfig(RLEnvConfig):
|
||||
env_name: str = "CartPole-v1"
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.env_name in gym.registered_envs, f"Env {self.env_name} not registered"
|
||||
|
||||
|
||||
class GymNaxEnv(RLEnv):
|
||||
|
||||
def __init__(self, config: GymConfig = GymConfig()):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.env, self.env_params = gym.make(config.env_name)
|
||||
|
||||
def env_step(self, randkey, env_state, action):
|
||||
return self.env.step(randkey, env_state, action, self.env_params)
|
||||
|
||||
def env_reset(self, randkey):
|
||||
return self.env.reset(randkey, self.env_params)
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return self.env.observation_space(self.env_params).shape
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return self.env.action_space(self.env_params).shape
|
||||
|
||||
def show(self, randkey, state: State, act_func: Callable, params):
|
||||
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
|
||||
@@ -1,69 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
|
||||
from config import ProblemConfig
|
||||
|
||||
from core import Problem, State
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RLEnvConfig(ProblemConfig):
|
||||
output_transform: Callable = lambda x: x
|
||||
|
||||
|
||||
class RLEnv(Problem):
|
||||
|
||||
jitable = False
|
||||
|
||||
def __init__(self, config: RLEnvConfig = RLEnvConfig()):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
def evaluate(self, randkey, state: State, act_func: Callable, params):
|
||||
rng_reset, rng_episode = jax.random.split(randkey)
|
||||
init_obs, init_env_state = self.reset(rng_reset)
|
||||
|
||||
def cond_func(carry):
|
||||
_, _, _, done, _ = carry
|
||||
return ~done
|
||||
|
||||
def body_func(carry):
|
||||
obs, env_state, rng, _, tr = carry # total reward
|
||||
net_out = act_func(state, obs, params)
|
||||
action = self.config.output_transform(net_out)
|
||||
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||
next_rng, _ = jax.random.split(rng)
|
||||
return next_obs, next_env_state, next_rng, done, tr + reward
|
||||
|
||||
_, _, _, _, total_reward = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(init_obs, init_env_state, rng_episode, False, 0.0)
|
||||
)
|
||||
|
||||
return total_reward
|
||||
|
||||
def step(self, randkey, env_state, action):
|
||||
return self.env_step(randkey, env_state, action)
|
||||
|
||||
def reset(self, randkey):
|
||||
return self.env_reset(randkey)
|
||||
|
||||
def env_step(self, randkey, env_state, action):
|
||||
raise NotImplementedError
|
||||
|
||||
def env_reset(self, randkey):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def show(self, randkey, state: State, act_func: Callable, params):
|
||||
raise NotImplementedError
|
||||
Reference in New Issue
Block a user