From f217d87ac66133b4d09123bbf3c15ac2ea424489 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 15 Sep 2023 23:50:10 +0800 Subject: [PATCH] delete useless; append readme --- README.md | 89 +++++++++++++++++++++++++++++++++++++- core/problem.py | 2 +- examples/func_fit/xor.py | 15 +++---- examples/general_xor.py | 41 ++++++++++++++++++ problem/func_fit/xor.py | 4 +- problem/rl_env/gym_env.py | 40 ----------------- problem/rl_env/rl_unjit.py | 69 ----------------------------- 7 files changed, 137 insertions(+), 123 deletions(-) create mode 100644 examples/general_xor.py delete mode 100644 problem/rl_env/gym_env.py delete mode 100644 problem/rl_env/rl_unjit.py diff --git a/README.md b/README.md index 9344b38..67561bf 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,88 @@ -# TensorNEAT: Tensorized NeuroEvolution of Augmenting Topologies for GPU Acceleration +# TensorNEAT: Tensorized NEAT implementation in JAX -TensorNEAT is a powerful tool that utilizes JAX to implement the NEAT (NeuroEvolution of Augmenting Topologies) algorithm. It provides support for parallel execution of tasks such as forward network computation, mutation, and crossover at the population level. +TensorNEAT is a powerful tool that utilizes JAX to implement the NEAT (NeuroEvolution of Augmenting Topologies) +algorithm. It provides support for parallel execution of tasks such as network forward computation, mutation, +and crossover at the population level. + +## Requirements +* available [JAX](https://github.com/google/jax#installation) environment; +* [gymnax](https://github.com/RobertTLange/gymnax) (optional). 
+ +## Example +Simple Example for XOR problem: +```python +from config import * +from pipeline import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.func_fit import XOR, FuncFitConfig + +if __name__ == '__main__': + # running config + config = Config( + basic=BasicConfig( + seed=42, + fitness_target=-1e-2, + pop_size=10000 + ), + neat=NeatConfig( + inputs=2, + outputs=1 + ), + gene=NormalGeneConfig(), + problem=FuncFitConfig( + error_method='rmse' + ) + ) + # define algorithm: NEAT with NormalGene + algorithm = NEAT(config, NormalGene) + # full pipeline + pipeline = Pipeline(config, algorithm, XOR) + # initialize state + state = pipeline.setup() + # run until terminate + state, best = pipeline.auto_run(state) + # show result + pipeline.show(state, best) +``` + +Simple Example for RL envs in gymnax (CartPole-v1): +```python +import jax.numpy as jnp + +from config import * +from pipeline import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.rl_env import GymNaxConfig, GymNaxEnv + +if __name__ == '__main__': + conf = Config( + basic=BasicConfig( + seed=42, + fitness_target=500, + pop_size=10000 + ), + neat=NeatConfig( + inputs=4, + outputs=1, + ), + gene=NormalGeneConfig( + activation_default=Act.sigmoid, + activation_options=(Act.sigmoid,), + ), + problem=GymNaxConfig( + env_name='CartPole-v1', + output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1} + ) + ) + + algorithm = NEAT(conf, NormalGene) + pipeline = Pipeline(conf, algorithm, GymNaxEnv) + state = pipeline.setup() + state, best = pipeline.auto_run(state) +``` + +`/examples` folder contains more examples. + +## TO BE COMPLETED... 
\ No newline at end of file diff --git a/core/problem.py b/core/problem.py index 87d4396..f65af85 100644 --- a/core/problem.py +++ b/core/problem.py @@ -6,7 +6,7 @@ from .state import State class Problem: - jitable: bool + jitable = None def __init__(self, problem_config: ProblemConfig = ProblemConfig()): self.config = problem_config diff --git a/examples/func_fit/xor.py b/examples/func_fit/xor.py index a2d45ee..11429ae 100644 --- a/examples/func_fit/xor.py +++ b/examples/func_fit/xor.py @@ -5,6 +5,7 @@ from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.func_fit import XOR, FuncFitConfig if __name__ == '__main__': + # running config config = Config( basic=BasicConfig( seed=42, @@ -12,13 +13,6 @@ if __name__ == '__main__': pop_size=10000 ), neat=NeatConfig( - max_nodes=50, - max_conns=100, - max_species=30, - conn_add=0.8, - conn_delete=0, - node_add=0.4, - node_delete=0, inputs=2, outputs=1 ), @@ -27,10 +21,13 @@ if __name__ == '__main__': error_method='rmse' ) ) - + # define algorithm: NEAT with NormalGene algorithm = NEAT(config, NormalGene) + # full pipeline pipeline = Pipeline(config, algorithm, XOR) + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) + # run until terminate state, best = pipeline.auto_run(state) + # show result pipeline.show(state, best) diff --git a/examples/general_xor.py b/examples/general_xor.py new file mode 100644 index 0000000..edcb994 --- /dev/null +++ b/examples/general_xor.py @@ -0,0 +1,41 @@ +from config import * +from pipeline import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.func_fit import XOR, FuncFitConfig + +def evaluate(): + pass + + + +if __name__ == '__main__': + config = Config( + basic=BasicConfig( + seed=42, + fitness_target=-1e-2, + pop_size=10000 + ), + neat=NeatConfig( + max_nodes=50, + max_conns=100, + max_species=30, + conn_add=0.8, + conn_delete=0, + node_add=0.4, + node_delete=0, + inputs=2, + 
outputs=1 + ), + gene=NormalGeneConfig(), + problem=FuncFitConfig( + error_method='rmse' + ) + ) + + algorithm = NEAT(config, NormalGene) + pipeline = Pipeline(config, algorithm, XOR) + state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state) + pipeline.show(state, best) diff --git a/problem/func_fit/xor.py b/problem/func_fit/xor.py index c41cc65..fed943b 100644 --- a/problem/func_fit/xor.py +++ b/problem/func_fit/xor.py @@ -29,8 +29,8 @@ class XOR(FuncFit): @property def input_shape(self): - return (4, 2) + return 4, 2 @property def output_shape(self): - return (4, 1) + return 4, 1 diff --git a/problem/rl_env/gym_env.py b/problem/rl_env/gym_env.py deleted file mode 100644 index 080152a..0000000 --- a/problem/rl_env/gym_env.py +++ /dev/null @@ -1,40 +0,0 @@ -from dataclasses import dataclass -from typing import Callable - -import gym - -from core import State -from .rl_unjit import RLEnv, RLEnvConfig - - -@dataclass(frozen=True) -class GymConfig(RLEnvConfig): - env_name: str = "CartPole-v1" - - def __post_init__(self): - assert self.env_name in gym.registered_envs, f"Env {self.env_name} not registered" - - -class GymNaxEnv(RLEnv): - - def __init__(self, config: GymConfig = GymConfig()): - super().__init__(config) - self.config = config - self.env, self.env_params = gym.make(config.env_name) - - def env_step(self, randkey, env_state, action): - return self.env.step(randkey, env_state, action, self.env_params) - - def env_reset(self, randkey): - return self.env.reset(randkey, self.env_params) - - @property - def input_shape(self): - return self.env.observation_space(self.env_params).shape - - @property - def output_shape(self): - return self.env.action_space(self.env_params).shape - - def show(self, randkey, state: State, act_func: Callable, params): - raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).") diff --git a/problem/rl_env/rl_unjit.py b/problem/rl_env/rl_unjit.py deleted file mode 100644 
index 0ab7c3f..0000000 --- a/problem/rl_env/rl_unjit.py +++ /dev/null @@ -1,69 +0,0 @@ -from dataclasses import dataclass -from typing import Callable - -import jax - -from config import ProblemConfig - -from core import Problem, State - - -@dataclass(frozen=True) -class RLEnvConfig(ProblemConfig): - output_transform: Callable = lambda x: x - - -class RLEnv(Problem): - - jitable = False - - def __init__(self, config: RLEnvConfig = RLEnvConfig()): - super().__init__(config) - self.config = config - - def evaluate(self, randkey, state: State, act_func: Callable, params): - rng_reset, rng_episode = jax.random.split(randkey) - init_obs, init_env_state = self.reset(rng_reset) - - def cond_func(carry): - _, _, _, done, _ = carry - return ~done - - def body_func(carry): - obs, env_state, rng, _, tr = carry # total reward - net_out = act_func(state, obs, params) - action = self.config.output_transform(net_out) - next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) - next_rng, _ = jax.random.split(rng) - return next_obs, next_env_state, next_rng, done, tr + reward - - _, _, _, _, total_reward = jax.lax.while_loop( - cond_func, - body_func, - (init_obs, init_env_state, rng_episode, False, 0.0) - ) - - return total_reward - - def step(self, randkey, env_state, action): - return self.env_step(randkey, env_state, action) - - def reset(self, randkey): - return self.env_reset(randkey) - - def env_step(self, randkey, env_state, action): - raise NotImplementedError - - def env_reset(self, randkey): - raise NotImplementedError - - @property - def input_shape(self): - raise NotImplementedError - - @property - def output_shape(self): - raise NotImplementedError - - def show(self, randkey, state: State, act_func: Callable, params): - raise NotImplementedError