From 0e44b132910408427d71dfbadc2659dcc08dd6db Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 4 Aug 2023 17:29:36 +0800 Subject: [PATCH] remove create_func.... --- algorithm/hyperneat/hyperneat.py | 19 ++-- algorithm/hyperneat/substrate/normal.py | 6 +- algorithm/neat/__init__.py | 1 + algorithm/neat/gene/normal.py | 4 +- algorithm/neat/gene/recurrent.py | 2 +- algorithm/neat/neat.py | 8 +- config/config.py | 14 +-- core/gene.py | 2 +- core/genome.py | 6 +- core/problem.py | 20 +++- core/substrate.py | 4 +- examples/__init__.py | 0 examples/func_fit/xor.py | 36 ++++++++ examples/func_fit/xor_hyperneat.py | 40 ++++++++ examples/func_fit/xor_recurrent.py | 40 ++++++++ examples/gymnax/cartpole.py | 84 +++++++++++++++++ examples/test.py | 31 ------- examples/xor.py | 40 -------- examples/xor_hyperneat.py | 49 ---------- examples/xor_recurrent.py | 42 --------- pipeline.py | 116 +++++++++++++++--------- problem/func_fit/__init__.py | 3 + problem/func_fit/func_fit.py | 69 ++++++++++++++ problem/func_fit/func_fitting.py | 21 ----- problem/func_fit/xor.py | 36 ++++++++ problem/func_fit/xor3d.py | 44 +++++++++ problem/rl_env/__init__.py | 1 + problem/rl_env/gymnax_env.py | 42 +++++++++ problem/rl_env/rl_env.py | 70 ++++++++++++++ 29 files changed, 591 insertions(+), 259 deletions(-) delete mode 100644 examples/__init__.py create mode 100644 examples/func_fit/xor.py create mode 100644 examples/func_fit/xor_hyperneat.py create mode 100644 examples/func_fit/xor_recurrent.py create mode 100644 examples/gymnax/cartpole.py delete mode 100644 examples/test.py delete mode 100644 examples/xor.py delete mode 100644 examples/xor_hyperneat.py delete mode 100644 examples/xor_recurrent.py create mode 100644 problem/func_fit/func_fit.py delete mode 100644 problem/func_fit/func_fitting.py create mode 100644 problem/func_fit/xor.py create mode 100644 problem/func_fit/xor3d.py create mode 100644 problem/rl_env/__init__.py create mode 100644 problem/rl_env/gymnax_env.py create mode 100644 problem/rl_env/rl_env.py diff --git a/algorithm/hyperneat/hyperneat.py b/algorithm/hyperneat/hyperneat.py index 5695edc..ac48bad 100644 --- a/algorithm/hyperneat/hyperneat.py +++ b/algorithm/hyperneat/hyperneat.py @@ -5,30 +5,30 @@ from jax import numpy as jnp, Array, vmap import numpy as np from config import Config, HyperNeatConfig -from core import Algorithm, Substrate, State, Genome +from core import Algorithm, Substrate, State, Genome, Gene from utils import Activation, Aggregation -from algorithm.neat import NEAT from .substrate import analysis_substrate +from algorithm import NEAT class HyperNEAT(Algorithm): - def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]): + def __init__(self, config: Config, gene: Type[Gene], substrate: Type[Substrate]): self.config = config - self.neat = neat + self.neat = NEAT(config, gene) self.substrate = substrate def setup(self, randkey, state=State()): neat_key, randkey = jax.random.split(randkey) state = state.update( - below_threshold=self.config.hyper_neat.below_threshold, - max_weight=self.config.hyper_neat.max_weight, + below_threshold=self.config.hyperneat.below_threshold, + max_weight=self.config.hyperneat.max_weight, ) state = self.neat.setup(neat_key, state) state = self.substrate.setup(self.config.substrate, state) - assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias - assert self.config.hyper_neat.outputs == state.output_coors.shape[0] + assert self.config.hyperneat.inputs + 1 == state.input_coors.shape[0] # +1 for bias + assert self.config.hyperneat.outputs == state.output_coors.shape[0] h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state) h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis] @@ -53,7 +53,7 @@ class HyperNEAT(Algorithm): return self.neat.tell(state, fitness) def forward(self, state, inputs: Array, transformed: Array): - return HyperNEATGene.forward(self.config.hyper_neat, state, inputs, transformed) + return HyperNEATGene.forward(self.config.hyperneat, state, inputs, transformed) def forward_transform(self, state: State, genome: Genome): t = self.neat.forward_transform(state, genome) @@ -68,6 +68,7 @@ class HyperNEAT(Algorithm): query_res = query_res / (1 - state.below_threshold) * state.max_weight h_conns = state.h_conns.at[:, 2:].set(query_res) + return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns)) diff --git a/algorithm/hyperneat/substrate/normal.py b/algorithm/hyperneat/substrate/normal.py index 7484fcb..c06edc3 100644 --- a/algorithm/hyperneat/substrate/normal.py +++ b/algorithm/hyperneat/substrate/normal.py @@ -9,9 +9,9 @@ from config import SubstrateConfig @dataclass(frozen=True) class NormalSubstrateConfig(SubstrateConfig): - input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1)) - hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0)) - output_coors: Tuple[Tuple[float]] = ((0, 1),) + input_coors: Tuple = ((-1, -1), (0, -1), (1, -1)) + hidden_coors: Tuple = ((-1, 0), (0, 0), (1, 0)) + output_coors: Tuple = ((0, 1),) class NormalSubstrate(Substrate): diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py index 6fe56c9..d6bb53c 100644 --- a/algorithm/neat/__init__.py +++ b/algorithm/neat/__init__.py @@ -1 +1,2 @@ from .neat import NEAT +from .gene import * diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py index 84973c9..2d9caf0 100644 --- a/algorithm/neat/gene/normal.py +++ b/algorithm/neat/gene/normal.py @@ -66,7 +66,7 @@ class NormalGene(Gene): node_attrs = ['bias', 'response', 'aggregation', 'activation'] conn_attrs = ['weight'] - def __init__(self, config: NormalGeneConfig): + def __init__(self, config: NormalGeneConfig = NormalGeneConfig()): self.config = config self.act_funcs = [Activation.name2func[name] for name in config.activation_options] self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options] @@ -101,7 +101,7 @@ class NormalGene(Gene): ) def update(self, state): - pass + return state def new_node_attrs(self, state): return jnp.array([state.bias_init_mean, state.response_init_mean, diff --git a/algorithm/neat/gene/recurrent.py b/algorithm/neat/gene/recurrent.py index a3dc7ce..d82d7bf 100644 --- a/algorithm/neat/gene/recurrent.py +++ b/algorithm/neat/gene/recurrent.py @@ -19,7 +19,7 @@ class RecurrentGeneConfig(NormalGeneConfig): class RecurrentGene(NormalGene): - def __init__(self, config: RecurrentGeneConfig): + def __init__(self, config: RecurrentGeneConfig = RecurrentGeneConfig()): self.config = config super().__init__(config) diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py index 818e9dd..330d819 100644 --- a/algorithm/neat/neat.py +++ b/algorithm/neat/neat.py @@ -28,9 +28,9 @@ class NEAT(Algorithm): state = state.update( P=self.config.basic.pop_size, - N=self.config.neat.maximum_nodes, - C=self.config.neat.maximum_conns, - S=self.config.neat.maximum_species, + N=self.config.neat.max_nodes, + C=self.config.neat.max_conns, + S=self.config.neat.max_species, NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes max_stagnation=self.config.neat.max_stagnation, @@ -80,6 +80,8 @@ class NEAT(Algorithm): return state.pop_genomes def tell_algorithm(self, state: State, fitness): + state = self.gene.update(state) + k1, k2, randkey = jax.random.split(state.randkey, 3) state = state.update( diff --git a/config/config.py b/config/config.py index c147d8a..87c55c2 100644 --- a/config/config.py +++ b/config/config.py @@ -17,9 +17,9 @@ class NeatConfig: network_type: str = "feedforward" inputs: int = 2 outputs: int = 1 - maximum_nodes: int = 50 - maximum_conns: int = 100 - maximum_species: int = 10 + max_nodes: int = 50 + max_conns: int = 100 + max_species: int = 10 # genome config compatibility_disjoint: float = 1 @@ -44,9 +44,9 @@ class NeatConfig: assert self.inputs > 0, "the inputs number of neat must be greater than 0" assert self.outputs > 0, "the outputs number of neat must be greater than 0" - assert self.maximum_nodes > 0, "the maximum nodes must be greater than 0" - assert self.maximum_conns > 0, "the maximum connections must be greater than 0" - assert self.maximum_species > 0, "the maximum species must be greater than 0" + assert self.max_nodes > 0, "the maximum nodes must be greater than 0" + assert self.max_conns > 0, "the maximum connections must be greater than 0" + assert self.max_species > 0, "the maximum species must be greater than 0" assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0" assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0" @@ -101,7 +101,7 @@ class ProblemConfig: class Config: basic: BasicConfig = BasicConfig() neat: NeatConfig = NeatConfig() - hyper_neat: HyperNeatConfig = HyperNeatConfig() + hyperneat: HyperNeatConfig = HyperNeatConfig() gene: GeneConfig = GeneConfig() substrate: SubstrateConfig = SubstrateConfig() problem: ProblemConfig = ProblemConfig() diff --git a/core/gene.py b/core/gene.py index 2643c69..7d58b07 100644 --- a/core/gene.py +++ b/core/gene.py @@ -6,7 +6,7 @@ class Gene: node_attrs = [] conn_attrs = [] - def __init__(self, config: GeneConfig): + def __init__(self, config: GeneConfig = GeneConfig()): raise NotImplementedError def setup(self, state=State()): diff --git a/core/genome.py b/core/genome.py index 0153bca..bf42d42 100644 --- a/core/genome.py +++ b/core/genome.py @@ -19,6 +19,11 @@ class Genome: def __getitem__(self, idx): return self.__class__(self.nodes[idx], self.conns[idx]) + def __eq__(self, other): + nodes_eq = jnp.alltrue((self.nodes == other.nodes) | (jnp.isnan(self.nodes) & jnp.isnan(other.nodes))) + conns_eq = jnp.alltrue((self.conns == other.conns) | (jnp.isnan(self.conns) & jnp.isnan(other.conns))) + return nodes_eq & conns_eq + def set(self, idx, value: Genome): return self.__class__(self.nodes.at[idx].set(value.nodes), self.conns.at[idx].set(value.conns)) @@ -83,4 +88,3 @@ class Genome: @classmethod def tree_unflatten(cls, aux_data, children): return cls(*children) - diff --git a/core/problem.py b/core/problem.py index 3c97d73..c70c409 100644 --- a/core/problem.py +++ b/core/problem.py @@ -1,15 +1,27 @@ from typing import Callable + from config import ProblemConfig -from state import State +from .state import State class Problem: - def __init__(self, config: ProblemConfig): + def __init__(self, problem_config: ProblemConfig = ProblemConfig()): + self.config = problem_config + + def evaluate(self, randkey, state: State, act_func: Callable, params): raise NotImplementedError - def setup(self, state=State()): + @property + def input_shape(self): raise NotImplementedError - def evaluate(self, state: State, act_func: Callable, params): + @property + def output_shape(self): + raise NotImplementedError + + def show(self, randkey, state: State, act_func: Callable, params): + """ + show how a genome perform in this problem + """ raise NotImplementedError diff --git a/core/substrate.py b/core/substrate.py index 0de89c6..03faa86 100644 --- a/core/substrate.py +++ b/core/substrate.py @@ -4,7 +4,5 @@ from config import SubstrateConfig class Substrate: @staticmethod - def setup(state, config: SubstrateConfig): + def setup(state, config: SubstrateConfig = SubstrateConfig()): return state - - diff --git a/examples/__init__.py b/examples/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/examples/func_fit/xor.py b/examples/func_fit/xor.py new file mode 100644 index 0000000..a2d45ee --- /dev/null +++ b/examples/func_fit/xor.py @@ -0,0 +1,36 @@ +from config import * +from pipeline import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.func_fit import XOR, FuncFitConfig + +if __name__ == '__main__': + config = Config( + basic=BasicConfig( + seed=42, + fitness_target=-1e-2, + pop_size=10000 + ), + neat=NeatConfig( + max_nodes=50, + max_conns=100, + max_species=30, + conn_add=0.8, + conn_delete=0, + node_add=0.4, + node_delete=0, + inputs=2, + outputs=1 + ), + gene=NormalGeneConfig(), + problem=FuncFitConfig( + error_method='rmse' + ) + ) + + algorithm = NEAT(config, NormalGene) + pipeline = Pipeline(config, algorithm, XOR) + state = pipeline.setup() + pipeline.pre_compile(state) + state, best = pipeline.auto_run(state) + pipeline.show(state, best) diff --git a/examples/func_fit/xor_hyperneat.py b/examples/func_fit/xor_hyperneat.py new file mode 100644 index 0000000..23fc389 --- /dev/null +++ b/examples/func_fit/xor_hyperneat.py @@ -0,0 +1,40 @@ +from config import * +from pipeline import Pipeline +from algorithm.neat import NormalGene, NormalGeneConfig +from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig +from problem.func_fit import XOR3d, FuncFitConfig + + +if __name__ == '__main__': + config = Config( + basic=BasicConfig( + seed=42, + fitness_target=0, + pop_size=1000 + ), + neat=NeatConfig( + max_nodes=50, + max_conns=100, + max_species=30, + inputs=4, + outputs=1 + ), + hyperneat=HyperNeatConfig( + inputs=3, + outputs=1 + ), + substrate=NormalSubstrateConfig( + input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)), + ), + gene=NormalGeneConfig( + activation_default='tanh', + activation_options=('tanh', ), + ), + problem=FuncFitConfig() + ) + + algorithm = HyperNEAT(config, NormalGene, NormalSubstrate) + pipeline = Pipeline(config, algorithm, XOR3d) + state = pipeline.setup() + state, best = pipeline.auto_run(state) + pipeline.show(state, best) diff --git a/examples/func_fit/xor_recurrent.py b/examples/func_fit/xor_recurrent.py new file mode 100644 index 0000000..2eed951 --- /dev/null +++ b/examples/func_fit/xor_recurrent.py @@ -0,0 +1,40 @@ +from config import * +from pipeline import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig +from problem.func_fit import XOR3d, FuncFitConfig + + +if __name__ == '__main__': + config = Config( + basic=BasicConfig( + seed=42, + fitness_target=-1e-2, + generation_limit=300, + pop_size=1000 + ), + neat=NeatConfig( + network_type="recurrent", + max_nodes=50, + max_conns=100, + max_species=30, + conn_add=0.5, + conn_delete=0.5, + node_add=0.4, + node_delete=0.4, + inputs=3, + outputs=1 + ), + gene=RecurrentGeneConfig( + activate_times=10 + ), + problem=FuncFitConfig( + error_method='rmse' + ) + ) + + algorithm = NEAT(config, RecurrentGene) + pipeline = Pipeline(config, algorithm, XOR3d) + state = pipeline.setup() + state, best = pipeline.auto_run(state) + pipeline.show(state, best) diff --git a/examples/gymnax/cartpole.py b/examples/gymnax/cartpole.py new file mode 100644 index 0000000..3ce61c5 --- /dev/null +++ b/examples/gymnax/cartpole.py @@ -0,0 +1,84 @@ +import jax.numpy as jnp + +from config import * +from pipeline import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from problem.rl_env import GymNaxConfig, GymNaxEnv + + +def example_conf1(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=500, + pop_size=10000 + ), + neat=NeatConfig( + inputs=4, + outputs=1, + ), + gene=NormalGeneConfig( + activation_default='sigmoid', + activation_options=('sigmoid',), + ), + problem=GymNaxConfig( + env_name='CartPole-v1', + output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1} + ) + ) + + +def example_conf2(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=500, + pop_size=10000 + ), + neat=NeatConfig( + inputs=4, + outputs=1, + ), + gene=NormalGeneConfig( + activation_default='tanh', + activation_options=('tanh',), + ), + problem=GymNaxConfig( + env_name='CartPole-v1', + output_transform=lambda out: jnp.where(out[0] > 0, 1, 0) # the action of cartpole is {0, 1} + ) + ) + + +def example_conf3(): + return Config( + basic=BasicConfig( + seed=42, + fitness_target=500, + pop_size=10000 + ), + neat=NeatConfig( + inputs=4, + outputs=2, + ), + gene=NormalGeneConfig( + activation_default='tanh', + activation_options=('tanh',), + ), + problem=GymNaxConfig( + env_name='CartPole-v1', + output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1} + ) + ) + + +if __name__ == '__main__': + # all config files above can solve cartpole + conf = example_conf3() + + algorithm = NEAT(conf, NormalGene) + pipeline = Pipeline(conf, algorithm, GymNaxEnv) + state = pipeline.setup() + state, best = pipeline.auto_run(state) + pipeline.show(state, best) diff --git a/examples/test.py b/examples/test.py deleted file mode 100644 index 8761cc2..0000000 --- a/examples/test.py +++ /dev/null @@ -1,31 +0,0 @@ -from functools import partial -import jax - -from utils import unflatten_conns, act, agg, Activation, Aggregation -from algorithm.neat.gene import RecurrentGeneConfig - -config = RecurrentGeneConfig( - activation_options=("tanh", "sigmoid"), - activation_default="tanh", -) - - -class A: - def __init__(self): - self.act_funcs = [Activation.name2func[name] for name in config.activation_options] - self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options] - self.isTrue = False - - @partial(jax.jit, static_argnums=(0,)) - def step(self): - i = jax.numpy.array([0, 1]) - z = jax.numpy.array([ - [1, 1], - [2, 2] - ]) - print(self.act_funcs) - return jax.vmap(act, in_axes=(0, 0, None))(i, z, self.act_funcs) - - -AA = A() -print(AA.step()) diff --git a/examples/xor.py b/examples/xor.py deleted file mode 100644 index 920c6b1..0000000 --- a/examples/xor.py +++ /dev/null @@ -1,40 +0,0 @@ -import jax -import numpy as np - -from config import Config, BasicConfig, NeatConfig -from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) - - -def evaluate(forward_func): - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - outs = jax.device_get(outs) - fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return fitnesses - - -if __name__ == '__main__': - config = Config( - basic=BasicConfig( - fitness_target=3.9999999, - pop_size=10000 - ), - neat=NeatConfig( - maximum_nodes=50, - maximum_conns=100, - compatibility_threshold=4 - ), - gene=NormalGeneConfig() - ) - - algorithm = NEAT(config, NormalGene) - pipeline = Pipeline(config, algorithm) - pipeline.auto_run(evaluate) diff --git a/examples/xor_hyperneat.py b/examples/xor_hyperneat.py deleted file mode 100644 index 9cf8245..0000000 --- a/examples/xor_hyperneat.py +++ /dev/null @@ -1,49 +0,0 @@ -import jax -import numpy as np - -from config import Config, BasicConfig, NeatConfig -from pipeline import Pipeline -from algorithm import NEAT, HyperNEAT -from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig -from algorithm.hyperneat.substrate import NormalSubstrate, NormalSubstrateConfig - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) - - -def evaluate(forward_func): - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - outs = jax.device_get(outs) - fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return fitnesses - - -if __name__ == '__main__': - config = Config( - basic=BasicConfig( - fitness_target=3.99999, - pop_size=10000 - ), - neat=NeatConfig( - network_type="recurrent", - maximum_nodes=50, - maximum_conns=100, - inputs=4, - outputs=1 - - ), - gene=RecurrentGeneConfig( - activation_default="tanh", - activation_options=("tanh",), - ), - substrate=NormalSubstrateConfig(), - ) - neat = NEAT(config, RecurrentGene) - hyperNEAT = HyperNEAT(config, neat, NormalSubstrate) - - pipeline = Pipeline(config, hyperNEAT) - pipeline.auto_run(evaluate) diff --git a/examples/xor_recurrent.py b/examples/xor_recurrent.py deleted file mode 100644 index b9ae0bc..0000000 --- a/examples/xor_recurrent.py +++ /dev/null @@ -1,42 +0,0 @@ -import jax -import numpy as np - -from config import Config, BasicConfig, NeatConfig -from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig - - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) - - -def evaluate(forward_func): - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - outs = jax.device_get(outs) - fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return fitnesses - - -if __name__ == '__main__': - config = Config( - basic=BasicConfig( - fitness_target=3.99999, - pop_size=10000 - ), - neat=NeatConfig( - network_type="recurrent", - maximum_nodes=50, - maximum_conns=100 - ), - gene=RecurrentGeneConfig( - activate_times=3 - ) - ) - algorithm = NEAT(config, RecurrentGene) - pipeline = Pipeline(config, algorithm) - pipeline.auto_run(evaluate) diff --git a/pipeline.py b/pipeline.py index a4f9c9f..0e4ca06 100644 --- a/pipeline.py +++ b/pipeline.py @@ -1,83 +1,115 @@ -import time -from typing import Union, Callable +from functools import partial +from typing import Type import jax -from jax import vmap, jit +import time import numpy as np +from algorithm import NEAT, HyperNEAT from config import Config -from core import Algorithm, Genome +from core import State, Algorithm, Problem class Pipeline: - """ - Simple pipeline. - """ - def __init__(self, config: Config, algorithm: Algorithm): + def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]): self.config = config self.algorithm = algorithm + self.problem = problem_type(config.problem) - randkey = jax.random.PRNGKey(config.basic.seed) - self.state = algorithm.setup(randkey) + if isinstance(algorithm, NEAT): + assert config.neat.inputs == self.problem.input_shape[-1] + + elif isinstance(algorithm, HyperNEAT): + assert config.hyperneat.inputs == self.problem.input_shape[-1] + + else: + raise NotImplementedError + + self.act_func = self.algorithm.act + + for _ in range(len(self.problem.input_shape) - 1): + self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None)) self.best_genome = None self.best_fitness = float('-inf') - self.generation_timestamp = time.time() + self.generation_timestamp = None - self.evaluate_time = 0 + def setup(self): + key = jax.random.PRNGKey(self.config.basic.seed) + algorithm_key, evaluate_key = jax.random.split(key, 2) + state = State() + state = self.algorithm.setup(algorithm_key, state) + return state.update( + evaluate_key=evaluate_key + ) - self.act_func = jit(self.algorithm.act) - self.batch_act_func = jit(vmap(self.act_func, in_axes=(None, 0, None))) - self.pop_batch_act_func = jit(vmap(self.batch_act_func, in_axes=(None, None, 0))) - self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0))) - self.tell_func = jit(self.algorithm.tell) + @partial(jax.jit, static_argnums=(0,)) + def step(self, state): - def ask(self): - pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes) - return lambda inputs: self.pop_batch_act_func(self.state, inputs, pop_transforms) + key, sub_key = jax.random.split(state.evaluate_key) + keys = jax.random.split(key, self.config.basic.pop_size) - def tell(self, fitness): - # self.state = self.tell_func(self.state, fitness) - new_state = self.tell_func(self.state, fitness) - self.state = new_state + pop = self.algorithm.ask(state) - def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): + pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop) + + fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func, + pop_transformed) + + state = self.algorithm.tell(state, fitnesses) + + return state.update(evaluate_key=sub_key), fitnesses + + def auto_run(self, ini_state): + state = ini_state for _ in range(self.config.basic.generation_limit): - forward_func = self.ask() - fitnesses = fitness_func(forward_func) + self.generation_timestamp = time.time() - if analysis is not None: - if analysis == "default": - self.default_analysis(fitnesses) - else: - assert callable(analysis), f"What the fuck you passed in? A {analysis}?" - analysis(fitnesses) + previous_pop = self.algorithm.ask(state) + + state, fitnesses = self.step(state) + + fitnesses = jax.device_get(fitnesses) + + self.analysis(state, previous_pop, fitnesses) if max(fitnesses) >= self.config.basic.fitness_target: print("Fitness limit reached!") - return self.best_genome + return state, self.best_genome - self.tell(fitnesses) print("Generation limit reached!") - return self.best_genome + return state, self.best_genome + + def analysis(self, state, pop, fitnesses): - def default_analysis(self, fitnesses): max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) new_timestamp = time.time() + cost_time = new_timestamp - self.generation_timestamp - self.generation_timestamp = new_timestamp max_idx = np.argmax(fitnesses) if fitnesses[max_idx] > self.best_fitness: self.best_fitness = fitnesses[max_idx] - self.best_genome = Genome(self.state.pop_genomes.nodes[max_idx], self.state.pop_genomes.conns[max_idx]) + self.best_genome = pop[max_idx] - member_count = jax.device_get(self.state.species_info.member_count) + member_count = jax.device_get(state.species_info.member_count) species_sizes = [int(i) for i in member_count if i > 0] - print(f"Generation: {self.state.generation}", + print(f"Generation: {state.generation}", f"species: {len(species_sizes)}, {species_sizes}", - f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") \ No newline at end of file + f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") + + def show(self, state, genome): + transformed = self.algorithm.transform(state, genome) + self.problem.show(state.evaluate_key, state, self.act_func, transformed) + + def pre_compile(self, state): + tic = time.time() + print("start compile") + self.step.lower(self, state).compile() + # compiled_step = jax.jit(self.step, static_argnums=(0,)).lower(state).compile() + # self.__dict__['step'] = compiled_step + print(f"compile finished, cost time: {time.time() - tic}s") diff --git a/problem/func_fit/__init__.py b/problem/func_fit/__init__.py index e69de29..9304943 100644 --- a/problem/func_fit/__init__.py +++ b/problem/func_fit/__init__.py @@ -0,0 +1,3 @@ +from .func_fit import FuncFit, FuncFitConfig +from .xor import XOR +from .xor3d import XOR3d diff --git a/problem/func_fit/func_fit.py b/problem/func_fit/func_fit.py new file mode 100644 index 0000000..3e904fd --- /dev/null +++ b/problem/func_fit/func_fit.py @@ -0,0 +1,69 @@ +from typing import Callable +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +from config import ProblemConfig +from core import Problem, State + + +@dataclass(frozen=True) +class FuncFitConfig(ProblemConfig): + error_method: str = 'mse' + + def __post_init__(self): + assert self.error_method in {'mse', 'rmse', 'mae', 'mape'} + + +class FuncFit(Problem): + + def __init__(self, config: FuncFitConfig = FuncFitConfig()): + self.config = config + super().__init__(config) + + def evaluate(self, randkey, state: State, act_func: Callable, params): + + predict = act_func(state, self.inputs, params) + + if self.config.error_method == 'mse': + loss = jnp.mean((predict - self.targets) ** 2) + + elif self.config.error_method == 'rmse': + loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2)) + + elif self.config.error_method == 'mae': + loss = jnp.mean(jnp.abs(predict - self.targets)) + + elif self.config.error_method == 'mape': + loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets)) + + else: + raise NotImplementedError + + return -loss + + def show(self, randkey, state: State, act_func: Callable, params): + predict = act_func(state, self.inputs, params) + inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) + loss = -self.evaluate(randkey, state, act_func, params) + msg = "" + for i in range(inputs.shape[0]): + msg += f"input: {inputs[i]}, target: {target[i]}, predict: {predict[i]}\n" + msg += f"loss: {loss}\n" + print(msg) + + @property + def inputs(self): + raise NotImplementedError + + @property + def targets(self): + raise NotImplementedError + + @property + def input_shape(self): + raise NotImplementedError + + @property + def output_shape(self): + raise NotImplementedError diff --git a/problem/func_fit/func_fitting.py b/problem/func_fit/func_fitting.py deleted file mode 100644 index a60be5e..0000000 --- a/problem/func_fit/func_fitting.py +++ /dev/null @@ -1,21 +0,0 @@ -from dataclasses import dataclass -from typing import Callable - -from config import ProblemConfig -from core import Problem, State - - -@dataclass(frozen=True) -class FuncFitConfig: - pass - - -class FuncFit(Problem): - def __init__(self, config: ProblemConfig): - self.config = ProblemConfig - - def setup(self, state=State()): - pass - - def evaluate(self, state: State, act_func: Callable, params): - pass \ No newline at end of file diff --git a/problem/func_fit/xor.py b/problem/func_fit/xor.py new file mode 100644 index 0000000..c41cc65 --- /dev/null +++ b/problem/func_fit/xor.py @@ -0,0 +1,36 @@ +import numpy as np + +from .func_fit import FuncFit, FuncFitConfig + + +class XOR(FuncFit): + + def __init__(self, config: FuncFitConfig = FuncFitConfig()): + self.config = config + super().__init__(config) + + @property + def inputs(self): + return np.array([ + [0, 0], + [0, 1], + [1, 0], + [1, 1] + ]) + + @property + def targets(self): + return np.array([ + [0], + [1], + [1], + [0] + ]) + + @property + def input_shape(self): + return (4, 2) + + @property + def output_shape(self): + return (4, 1) diff --git a/problem/func_fit/xor3d.py b/problem/func_fit/xor3d.py new file mode 100644 index 0000000..2f070f8 --- /dev/null +++ b/problem/func_fit/xor3d.py @@ -0,0 +1,44 @@ +import numpy as np + +from .func_fit import FuncFit, FuncFitConfig + + +class XOR3d(FuncFit): + + def __init__(self, config: FuncFitConfig = FuncFitConfig()): + self.config = config + super().__init__(config) + + @property + def inputs(self): + return np.array([ + [0, 0, 0], + [0, 0, 1], + [0, 1, 0], + [0, 1, 1], + [1, 0, 0], + [1, 0, 1], + [1, 1, 0], + [1, 1, 1], + ]) + + @property + def targets(self): + return np.array([ + [0], + [1], + [1], + [0], + [1], + [0], + [0], + [1] + ]) + + @property + def input_shape(self): + return (8, 3) + + @property + def output_shape(self): + return (8, 1) diff --git a/problem/rl_env/__init__.py b/problem/rl_env/__init__.py new file mode 100644 index 0000000..63e273a --- /dev/null +++ b/problem/rl_env/__init__.py @@ -0,0 +1 @@ +from .gymnax_env import GymNaxEnv, GymNaxConfig diff --git a/problem/rl_env/gymnax_env.py b/problem/rl_env/gymnax_env.py new file mode 100644 index 0000000..f945e1d --- /dev/null +++ b/problem/rl_env/gymnax_env.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass +from typing import Callable + +import jax +import jax.numpy as jnp +import gymnax + +from core import State +from .rl_env import RLEnv, RLEnvConfig + + +@dataclass(frozen=True) +class GymNaxConfig(RLEnvConfig): + env_name: str = "CartPole-v1" + + def __post_init__(self): + assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered" + + +class GymNaxEnv(RLEnv): + + def __init__(self, config: GymNaxConfig = GymNaxConfig()): + super().__init__(config) + self.config = config + self.env, self.env_params = gymnax.make(config.env_name) + + def env_step(self, randkey, env_state, action): + return self.env.step(randkey, env_state, action, self.env_params) + + def env_reset(self, randkey): + return self.env.reset(randkey, self.env_params) + + @property + def input_shape(self): + return self.env.observation_space(self.env_params).shape + + @property + def output_shape(self): + return self.env.action_space(self.env_params).shape + + def show(self, randkey, state: State, act_func: Callable, params): + raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).") diff --git a/problem/rl_env/rl_env.py b/problem/rl_env/rl_env.py new file mode 100644 index 0000000..2eaa9a0 --- /dev/null +++ b/problem/rl_env/rl_env.py @@ -0,0 +1,70 @@ +from dataclasses import dataclass +from typing import Callable +from functools import partial + +import jax + +from config import ProblemConfig + +from core import Problem, State + + +@dataclass(frozen=True) +class RLEnvConfig(ProblemConfig): + output_transform: Callable = lambda x: x + + +class RLEnv(Problem): + + def __init__(self, config: RLEnvConfig = RLEnvConfig()): + super().__init__(config) + self.config = config + + def evaluate(self, randkey, state: State, act_func: Callable, params): + rng_reset, rng_episode = jax.random.split(randkey) + init_obs, init_env_state = self.reset(rng_reset) + + def cond_func(carry): + _, _, _, done, _ = carry + return ~done + + def body_func(carry): + obs, env_state, rng, _, tr = carry # total reward + net_out = act_func(state, obs, params) + action = self.config.output_transform(net_out) + next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) + next_rng, _ = jax.random.split(rng) + return next_obs, next_env_state, next_rng, done, tr + reward + + _, _, _, _, total_reward = jax.lax.while_loop( + cond_func, + body_func, + (init_obs, init_env_state, rng_episode, False, 0.0) + ) + + return total_reward + + @partial(jax.jit, static_argnums=(0,)) + def step(self, randkey, env_state, action): + return self.env_step(randkey, env_state, action) + + @partial(jax.jit, static_argnums=(0,)) + def reset(self, randkey): + return self.env_reset(randkey) + + def env_step(self, randkey, env_state, action): + raise NotImplementedError + + def env_reset(self, randkey): + raise NotImplementedError + + @property + def input_shape(self): + raise NotImplementedError + + @property + def output_shape(self): + raise NotImplementedError + + def show(self, randkey, state: State, act_func: Callable, params): + raise NotImplementedError