delete useless;

append readme
This commit is contained in:
wls2002
2023-09-15 23:50:10 +08:00
parent 4efa9445d5
commit f217d87ac6
7 changed files with 137 additions and 123 deletions

View File

@@ -1,3 +1,88 @@
# TensorNEAT: Tensorized NeuroEvolution of Augmenting Topologies for GPU Acceleration
# TensorNEAT: Tensorized NEAT implementation in JAX
TensorNEAT is a powerful tool that utilizes JAX to implement the NEAT (NeuroEvolution of Augmenting Topologies) algorithm. It provides support for parallel execution of tasks such as forward network computation, mutation, and crossover at the population level.
TensorNEAT is a powerful tool that utilizes JAX to implement the NEAT (NeuroEvolution of Augmenting Topologies)
algorithm. It provides support for parallel execution of tasks such as network forward computation, mutation,
and crossover at the population level.
## Requirements
* A working [JAX](https://github.com/google/jax#installation) environment;
* [gymnax](https://github.com/RobertTLange/gymnax) (optional).
## Example
Simple example for the XOR problem:
```python
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig
if __name__ == '__main__':
# running config
config = Config(
basic=BasicConfig(
seed=42,
fitness_target=-1e-2,
pop_size=10000
),
neat=NeatConfig(
inputs=2,
outputs=1
),
gene=NormalGeneConfig(),
problem=FuncFitConfig(
error_method='rmse'
)
)
# define algorithm: NEAT with NormalGene
algorithm = NEAT(config, NormalGene)
# full pipeline
pipeline = Pipeline(config, algorithm, XOR)
# initialize state
state = pipeline.setup()
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)
```
Simple example for an RL environment from gymnax (CartPole-v1):
```python
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
if __name__ == '__main__':
conf = Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
pop_size=10000
),
neat=NeatConfig(
inputs=4,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1}
)
)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup()
state, best = pipeline.auto_run(state)
```
`/examples` folder contains more examples.
## TO BE COMPLETED...

View File

@@ -6,7 +6,7 @@ from .state import State
class Problem:
jitable: bool
jitable = None
def __init__(self, problem_config: ProblemConfig = ProblemConfig()):
self.config = problem_config

View File

@@ -5,6 +5,7 @@ from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig
if __name__ == '__main__':
# running config
config = Config(
basic=BasicConfig(
seed=42,
@@ -12,13 +13,6 @@ if __name__ == '__main__':
pop_size=10000
),
neat=NeatConfig(
max_nodes=50,
max_conns=100,
max_species=30,
conn_add=0.8,
conn_delete=0,
node_add=0.4,
node_delete=0,
inputs=2,
outputs=1
),
@@ -27,10 +21,13 @@ if __name__ == '__main__':
error_method='rmse'
)
)
# define algorithm: NEAT with NormalGene
algorithm = NEAT(config, NormalGene)
# full pipeline
pipeline = Pipeline(config, algorithm, XOR)
# initialize state
state = pipeline.setup()
pipeline.pre_compile(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)

41
examples/general_xor.py Normal file
View File

@@ -0,0 +1,41 @@
"""Run NEAT on the XOR problem with an explicitly tuned NeatConfig."""
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig


def evaluate():
    # Placeholder kept for interface compatibility; not implemented yet.
    pass


if __name__ == '__main__':
    # Assemble the run configuration: population basics, NEAT topology
    # limits and mutation rates, gene defaults, and the fitness problem.
    cfg = Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=-1e-2,
            pop_size=10000,
        ),
        neat=NeatConfig(
            max_nodes=50,
            max_conns=100,
            max_species=30,
            conn_add=0.8,
            conn_delete=0,
            node_add=0.4,
            node_delete=0,
            inputs=2,
            outputs=1,
        ),
        gene=NormalGeneConfig(),
        problem=FuncFitConfig(error_method='rmse'),
    )

    neat_algo = NEAT(cfg, NormalGene)
    pipe = Pipeline(cfg, neat_algo, XOR)

    # Build state, trigger ahead-of-time compilation, then evolve to target.
    run_state = pipe.setup()
    pipe.pre_compile(run_state)
    run_state, best = pipe.auto_run(run_state)
    pipe.show(run_state, best)

View File

@@ -29,8 +29,8 @@ class XOR(FuncFit):
@property
def input_shape(self):
return (4, 2)
return 4, 2
@property
def output_shape(self):
return (4, 1)
return 4, 1

View File

@@ -1,40 +0,0 @@
from dataclasses import dataclass
from typing import Callable

import gymnax

from core import State
from .rl_unjit import RLEnv, RLEnvConfig


@dataclass(frozen=True)
class GymConfig(RLEnvConfig):
    """Configuration for a gymnax-backed RL environment problem."""

    env_name: str = "CartPole-v1"

    def __post_init__(self):
        # Fail fast on typos: gymnax exposes the names it supports.
        assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered"


class GymNaxEnv(RLEnv):
    """Non-jitted wrapper around a gymnax environment.

    Fix: the original imported ``gym`` but called the gymnax API everywhere
    (``make`` returning ``(env, env_params)``, ``registered_envs``, and the
    functional ``step``/``reset`` signatures) — it could never run against
    plain gym. Import gymnax, which actually provides that API.
    """

    def __init__(self, config: GymConfig = GymConfig()):
        super().__init__(config)
        self.config = config
        # gymnax.make returns the environment plus its default parameters.
        self.env, self.env_params = gymnax.make(config.env_name)

    def env_step(self, randkey, env_state, action):
        # Functional step -> (obs, env_state, reward, done, info).
        return self.env.step(randkey, env_state, action, self.env_params)

    def env_reset(self, randkey):
        # Functional reset -> (obs, env_state).
        return self.env.reset(randkey, self.env_params)

    @property
    def input_shape(self):
        return self.env.observation_space(self.env_params).shape

    @property
    def output_shape(self):
        return self.env.action_space(self.env_params).shape

    def show(self, randkey, state: State, act_func: Callable, params):
        raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")

View File

@@ -1,69 +0,0 @@
from dataclasses import dataclass
from typing import Callable

import jax

from config import ProblemConfig
from core import Problem, State


@dataclass(frozen=True)
class RLEnvConfig(ProblemConfig):
    # Maps the raw network output to an environment action; identity by default.
    output_transform: Callable = lambda x: x


class RLEnv(Problem):
    """Base class for reinforcement-learning environment problems.

    Subclasses implement ``env_step``/``env_reset`` and the shape
    properties; ``evaluate`` rolls out one full episode.
    """

    # NOTE(review): flag consumed by the pipeline; False appears to mean
    # this problem's evaluation is not jit-compiled as a whole — confirm
    # against the Problem/pipeline code.
    jitable = False

    def __init__(self, config: RLEnvConfig = RLEnvConfig()):
        super().__init__(config)
        self.config = config

    def evaluate(self, randkey, state: State, act_func: Callable, params):
        """Roll out one episode and return the accumulated reward.

        randkey: JAX PRNG key driving the environment's randomness.
        state: algorithm state, passed through unchanged to ``act_func``.
        act_func: callable mapping (state, obs, params) -> network output.
        params: network parameters forwarded to ``act_func``.
        """
        rng_reset, rng_episode = jax.random.split(randkey)
        init_obs, init_env_state = self.reset(rng_reset)

        def cond_func(carry):
            _, _, _, done, _ = carry
            # Keep looping until the environment reports the episode done.
            return ~done

        def body_func(carry):
            obs, env_state, rng, _, tr = carry  # total reward
            net_out = act_func(state, obs, params)
            # Map the raw network output to a valid env action.
            action = self.config.output_transform(net_out)
            next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
            # Split off a fresh key so each step sees different randomness.
            next_rng, _ = jax.random.split(rng)
            return next_obs, next_env_state, next_rng, done, tr + reward

        # Carry layout: (obs, env_state, rng, done, total_reward).
        _, _, _, _, total_reward = jax.lax.while_loop(
            cond_func,
            body_func,
            (init_obs, init_env_state, rng_episode, False, 0.0)
        )
        return total_reward

    def step(self, randkey, env_state, action):
        return self.env_step(randkey, env_state, action)

    def reset(self, randkey):
        return self.env_reset(randkey)

    def env_step(self, randkey, env_state, action):
        # Subclass hook: advance the environment by one step.
        raise NotImplementedError

    def env_reset(self, randkey):
        # Subclass hook: reset the environment to an initial state.
        raise NotImplementedError

    @property
    def input_shape(self):
        raise NotImplementedError

    @property
    def output_shape(self):
        raise NotImplementedError

    def show(self, randkey, state: State, act_func: Callable, params):
        # Subclass hook: visualize a rollout of the given policy.
        raise NotImplementedError
raise NotImplementedError