From c7fb1ddabe8759f30306e689a2a0851795641e9f Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 2 Aug 2023 15:02:08 +0800 Subject: [PATCH] remove create_func.... --- algorithm/__init__.py | 1 + algorithm/hyperneat/__init__.py | 2 + algorithm/hyperneat/hyperneat.py | 116 ++++++++++++++++++++++ algorithm/hyperneat/substrate/__init__.py | 2 + algorithm/hyperneat/substrate/normal.py | 25 +++++ algorithm/hyperneat/substrate/tools.py | 49 +++++++++ algorithm/neat/gene/__init__.py | 2 + algorithm/neat/gene/normal.py | 4 +- algorithm/neat/gene/recurrent.py | 57 +++++++++++ algorithm/neat/neat.py | 6 +- config/config.py | 6 ++ core/__init__.py | 3 +- core/gene.py | 3 + core/problem.py | 15 +++ core/substrate.py | 2 + examples/test.py | 29 ++++-- examples/xor.py | 12 ++- examples/xor_hyperneat.py | 49 +++++++++ examples/xor_recurrent.py | 42 ++++++++ problem/__init__.py | 0 problem/func_fit/__init__.py | 0 problem/func_fit/func_fitting.py | 21 ++++ 22 files changed, 425 insertions(+), 21 deletions(-) create mode 100644 algorithm/hyperneat/__init__.py create mode 100644 algorithm/hyperneat/hyperneat.py create mode 100644 algorithm/hyperneat/substrate/__init__.py create mode 100644 algorithm/hyperneat/substrate/normal.py create mode 100644 algorithm/hyperneat/substrate/tools.py create mode 100644 algorithm/neat/gene/recurrent.py create mode 100644 core/problem.py create mode 100644 examples/xor_hyperneat.py create mode 100644 examples/xor_recurrent.py create mode 100644 problem/__init__.py create mode 100644 problem/func_fit/__init__.py create mode 100644 problem/func_fit/func_fitting.py diff --git a/algorithm/__init__.py b/algorithm/__init__.py index 6fe56c9..e2e54c0 100644 --- a/algorithm/__init__.py +++ b/algorithm/__init__.py @@ -1 +1,2 @@ from .neat import NEAT +from .hyperneat import HyperNEAT diff --git a/algorithm/hyperneat/__init__.py b/algorithm/hyperneat/__init__.py new file mode 100644 index 0000000..8d106fb --- /dev/null +++ b/algorithm/hyperneat/__init__.py @@ 
-0,0 +1,2 @@ +from .hyperneat import HyperNEAT +from .substrate import NormalSubstrate, NormalSubstrateConfig diff --git a/algorithm/hyperneat/hyperneat.py b/algorithm/hyperneat/hyperneat.py new file mode 100644 index 0000000..5695edc --- /dev/null +++ b/algorithm/hyperneat/hyperneat.py @@ -0,0 +1,116 @@ +from typing import Type + +import jax +from jax import numpy as jnp, Array, vmap +import numpy as np + +from config import Config, HyperNeatConfig +from core import Algorithm, Substrate, State, Genome +from utils import Activation, Aggregation +from algorithm.neat import NEAT +from .substrate import analysis_substrate + + +class HyperNEAT(Algorithm): + + def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]): + self.config = config + self.neat = neat + self.substrate = substrate + + def setup(self, randkey, state=State()): + neat_key, randkey = jax.random.split(randkey) + state = state.update( + below_threshold=self.config.hyper_neat.below_threshold, + max_weight=self.config.hyper_neat.max_weight, + ) + state = self.neat.setup(neat_key, state) + state = self.substrate.setup(self.config.substrate, state) + + assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias + assert self.config.hyper_neat.outputs == state.output_coors.shape[0] + + h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state) + h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis] + h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32) + h_conns[:, 0:2] = correspond_keys + + state = state.update( + h_input_idx=h_input_idx, + h_output_idx=h_output_idx, + h_hidden_idx=h_hidden_idx, + h_nodes=h_nodes, + h_conns=h_conns, + query_coors=query_coors, + ) + + return state + + def ask_algorithm(self, state: State): + return state.pop_genomes + + def tell_algorithm(self, state: State, fitness): + return self.neat.tell(state, fitness) + + def forward(self, state, inputs: 
Array, transformed: Array): + return HyperNEATGene.forward(self.config.hyper_neat, state, inputs, transformed) + + def forward_transform(self, state: State, genome: Genome): + t = self.neat.forward_transform(state, genome) + query_res = vmap(self.neat.forward, in_axes=(None, 0, None))(state, state.query_coors, t) + + # mute the connection with weight below threshold + query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res) + + # make query res in range [-max_weight, max_weight] + query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res) + query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res) + query_res = query_res / (1 - state.below_threshold) * state.max_weight + + h_conns = state.h_conns.at[:, 2:].set(query_res) + return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns)) + + +class HyperNEATGene: + node_attrs = [] # no node attributes + conn_attrs = ['weight'] + + @staticmethod + def forward_transform(genome: Genome): + N = genome.nodes.shape[0] + u_conns = jnp.zeros((N, N), dtype=jnp.float32) + + in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32) + out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32) + weights = genome.conns[:, 2] + + u_conns = u_conns.at[in_keys, out_keys].set(weights) + return genome.nodes, u_conns + + @staticmethod + def forward(config: HyperNeatConfig, state: State, inputs, transformed): + act = Activation.name2func[config.activation] + agg = Aggregation.name2func[config.aggregation] + + batch_act, batch_agg = jax.vmap(act), jax.vmap(agg) + + nodes, weights = transformed + + inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0) + + input_idx = state.h_input_idx + output_idx = state.h_output_idx + + N = nodes.shape[0] + vals = jnp.full((N,), 0.) 
+ + def body_func(i, values): + values = values.at[input_idx].set(inputs_with_bias) + nodes_ins = values * weights.T + values = batch_agg(nodes_ins) # z = agg(ins) + values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias + values = batch_act(values) # z = act(z) + return values + + vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals) + return vals[output_idx] diff --git a/algorithm/hyperneat/substrate/__init__.py b/algorithm/hyperneat/substrate/__init__.py new file mode 100644 index 0000000..035c39f --- /dev/null +++ b/algorithm/hyperneat/substrate/__init__.py @@ -0,0 +1,2 @@ +from .normal import NormalSubstrate, NormalSubstrateConfig +from .tools import analysis_substrate \ No newline at end of file diff --git a/algorithm/hyperneat/substrate/normal.py b/algorithm/hyperneat/substrate/normal.py new file mode 100644 index 0000000..7484fcb --- /dev/null +++ b/algorithm/hyperneat/substrate/normal.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from typing import Tuple + +import numpy as np + +from core import Substrate, State +from config import SubstrateConfig + + +@dataclass(frozen=True) +class NormalSubstrateConfig(SubstrateConfig): + input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1)) + hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0)) + output_coors: Tuple[Tuple[float]] = ((0, 1),) + + +class NormalSubstrate(Substrate): + + @staticmethod + def setup(config: NormalSubstrateConfig, state: State = State()): + return state.update( + input_coors=np.asarray(config.input_coors, dtype=np.float32), + output_coors=np.asarray(config.output_coors, dtype=np.float32), + hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32), + ) diff --git a/algorithm/hyperneat/substrate/tools.py b/algorithm/hyperneat/substrate/tools.py new file mode 100644 index 0000000..8bc4959 --- /dev/null +++ b/algorithm/hyperneat/substrate/tools.py @@ -0,0 +1,49 @@ +import numpy as np + + +def analysis_substrate(state): + cd = 
state.input_coors.shape[1] # coordinate dimensions + si = state.input_coors.shape[0] # input coordinate size + so = state.output_coors.shape[0] # output coordinate size + sh = state.hidden_coors.shape[0] # hidden coordinate size + + input_idx = np.arange(si) + output_idx = np.arange(si, si + so) + hidden_idx = np.arange(si + so, si + so + sh) + + total_conns = si * sh + sh * sh + sh * so + query_coors = np.zeros((total_conns, cd * 2)) + correspond_keys = np.zeros((total_conns, 2)) + + # connect input to hidden + aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors) + query_coors[0: si * sh, :] = aux_coors + correspond_keys[0: si * sh, :] = aux_keys + + # connect hidden to hidden + aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors) + query_coors[si * sh: si * sh + sh * sh, :] = aux_coors + correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys + + # connect hidden to output + aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors) + query_coors[si * sh + sh * sh:, :] = aux_coors + correspond_keys[si * sh + sh * sh:, :] = aux_keys + + return input_idx, output_idx, hidden_idx, query_coors, correspond_keys + + +def cartesian_product(keys1, keys2, coors1, coors2): + len1 = keys1.shape[0] + len2 = keys2.shape[0] + + repeated_coors1 = np.repeat(coors1, len2, axis=0) + repeated_keys1 = np.repeat(keys1, len2) + + tiled_coors2 = np.tile(coors2, (len1, 1)) + tiled_keys2 = np.tile(keys2, len1) + + new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1) + correspond_keys = np.column_stack((repeated_keys1, tiled_keys2)) + + return new_coors, correspond_keys diff --git a/algorithm/neat/gene/__init__.py b/algorithm/neat/gene/__init__.py index 02af6ce..49a05b9 100644 --- a/algorithm/neat/gene/__init__.py +++ b/algorithm/neat/gene/__init__.py @@ -1 +1,3 @@ from .normal import NormalGene, NormalGeneConfig +from .recurrent 
import RecurrentGene, RecurrentGeneConfig + diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py index 449eda2..84973c9 100644 --- a/algorithm/neat/gene/normal.py +++ b/algorithm/neat/gene/normal.py @@ -24,11 +24,11 @@ class NormalGeneConfig(GeneConfig): response_replace_rate: float = 0.1 activation_default: str = 'sigmoid' - activation_options: Tuple[str] = ('sigmoid',) + activation_options: Tuple = ('sigmoid',) activation_replace_rate: float = 0.1 aggregation_default: str = 'sum' - aggregation_options: Tuple[str] = ('sum',) + aggregation_options: Tuple = ('sum',) aggregation_replace_rate: float = 0.1 weight_init_mean: float = 0.0 diff --git a/algorithm/neat/gene/recurrent.py b/algorithm/neat/gene/recurrent.py new file mode 100644 index 0000000..a3dc7ce --- /dev/null +++ b/algorithm/neat/gene/recurrent.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass + +import jax +from jax import numpy as jnp, vmap + +from .normal import NormalGene, NormalGeneConfig +from core import State, Genome +from utils import unflatten_conns, act, agg + + +@dataclass(frozen=True) +class RecurrentGeneConfig(NormalGeneConfig): + activate_times: int = 10 + + def __post_init__(self): + super().__post_init__() + assert self.activate_times > 0 + + +class RecurrentGene(NormalGene): + + def __init__(self, config: RecurrentGeneConfig): + self.config = config + super().__init__(config) + + def forward_transform(self, state: State, genome: Genome): + u_conns = unflatten_conns(genome.nodes, genome.conns) + + # remove un-enable connections and remove enable attr + conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) + u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) + + return genome.nodes, u_conns + + def forward(self, state: State, inputs, transformed): + nodes, conns = transformed + + batch_act, batch_agg = vmap(act, in_axes=(0, 0, None)), vmap(agg, in_axes=(0, 0, None)) + + input_idx = state.input_idx + output_idx = state.output_idx + + N = 
nodes.shape[0] + vals = jnp.full((N,), 0.) + + weights = conns[0, :] + + def body_func(i, values): + values = values.at[input_idx].set(inputs) + nodes_ins = values * weights.T + values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins) + values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias + values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z) + return values + + vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals) + return vals[output_idx] diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py index f093ed4..818e9dd 100644 --- a/algorithm/neat/neat.py +++ b/algorithm/neat/neat.py @@ -1,3 +1,5 @@ +from typing import Type + import jax from jax import numpy as jnp import numpy as np @@ -10,9 +12,9 @@ from .species import SpeciesInfo, update_species, speciate class NEAT(Algorithm): - def __init__(self, config: Config, gene: Gene): + def __init__(self, config: Config, gene_type: Type[Gene]): self.config = config - self.gene = gene + self.gene = gene_type(config.gene) self.forward_func = None self.tell_func = None diff --git a/config/config.py b/config/config.py index d68accd..c147d8a 100644 --- a/config/config.py +++ b/config/config.py @@ -92,6 +92,11 @@ class SubstrateConfig: pass +@dataclass(frozen=True) +class ProblemConfig: + pass + + @dataclass(frozen=True) class Config: basic: BasicConfig = BasicConfig() @@ -99,3 +104,4 @@ class Config: hyper_neat: HyperNeatConfig = HyperNeatConfig() gene: GeneConfig = GeneConfig() substrate: SubstrateConfig = SubstrateConfig() + problem: ProblemConfig = ProblemConfig() diff --git a/core/__init__.py b/core/__init__.py index ad0ee9c..12c9675 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -2,4 +2,5 @@ from .algorithm import Algorithm from .state import State from .genome import Genome from .gene import Gene -from .substrate import Substrate \ No newline at end of file +from .substrate import Substrate +from .problem import Problem diff --git 
a/core/gene.py b/core/gene.py index 7c6d04e..2643c69 100644 --- a/core/gene.py +++ b/core/gene.py @@ -6,6 +6,9 @@ class Gene: node_attrs = [] conn_attrs = [] + def __init__(self, config: GeneConfig): + raise NotImplementedError + def setup(self, state=State()): raise NotImplementedError diff --git a/core/problem.py b/core/problem.py new file mode 100644 index 0000000..3c97d73 --- /dev/null +++ b/core/problem.py @@ -0,0 +1,15 @@ +from typing import Callable +from config import ProblemConfig +from state import State + + +class Problem: + + def __init__(self, config: ProblemConfig): + raise NotImplementedError + + def setup(self, state=State()): + raise NotImplementedError + + def evaluate(self, state: State, act_func: Callable, params): + raise NotImplementedError diff --git a/core/substrate.py b/core/substrate.py index e9694d1..0de89c6 100644 --- a/core/substrate.py +++ b/core/substrate.py @@ -6,3 +6,5 @@ class Substrate: @staticmethod def setup(state, config: SubstrateConfig): return state + + diff --git a/examples/test.py b/examples/test.py index 8eef82a..8761cc2 100644 --- a/examples/test.py +++ b/examples/test.py @@ -1,24 +1,31 @@ from functools import partial import jax +from utils import unflatten_conns, act, agg, Activation, Aggregation +from algorithm.neat.gene import RecurrentGeneConfig + +config = RecurrentGeneConfig( + activation_options=("tanh", "sigmoid"), + activation_default="tanh", +) + class A: def __init__(self): - self.a = 1 - self.b = 2 + self.act_funcs = [Activation.name2func[name] for name in config.activation_options] + self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options] self.isTrue = False @partial(jax.jit, static_argnums=(0,)) def step(self): - if self.isTrue: - return self.a + 1 - else: - return self.b + 1 + i = jax.numpy.array([0, 1]) + z = jax.numpy.array([ + [1, 1], + [2, 2] + ]) + print(self.act_funcs) + return jax.vmap(act, in_axes=(0, 0, None))(i, z, self.act_funcs) AA = A() -print(AA.step(), 
hash(AA)) -print(AA.step(), hash(AA)) -print(AA.step(), hash(AA)) -AA.a = (2, 3, 4) -print(AA.step(), hash(AA)) +print(AA.step()) diff --git a/examples/xor.py b/examples/xor.py index 6ee22b9..920c6b1 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -28,11 +28,13 @@ if __name__ == '__main__': pop_size=10000 ), neat=NeatConfig( - maximum_nodes=20, - maximum_conns=50, - ) + maximum_nodes=50, + maximum_conns=100, + compatibility_threshold=4 + ), + gene=NormalGeneConfig() ) - normal_gene = NormalGene(NormalGeneConfig()) - algorithm = NEAT(config, normal_gene) + + algorithm = NEAT(config, NormalGene) pipeline = Pipeline(config, algorithm) pipeline.auto_run(evaluate) diff --git a/examples/xor_hyperneat.py b/examples/xor_hyperneat.py new file mode 100644 index 0000000..9cf8245 --- /dev/null +++ b/examples/xor_hyperneat.py @@ -0,0 +1,49 @@ +import jax +import numpy as np + +from config import Config, BasicConfig, NeatConfig +from pipeline import Pipeline +from algorithm import NEAT, HyperNEAT +from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig +from algorithm.hyperneat.substrate import NormalSubstrate, NormalSubstrateConfig + +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) +xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) + + +def evaluate(forward_func): + """ + :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) + :return: + """ + outs = forward_func(xor_inputs) + outs = jax.device_get(outs) + fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) + return fitnesses + + +if __name__ == '__main__': + config = Config( + basic=BasicConfig( + fitness_target=3.99999, + pop_size=10000 + ), + neat=NeatConfig( + network_type="recurrent", + maximum_nodes=50, + maximum_conns=100, + inputs=4, + outputs=1 + + ), + gene=RecurrentGeneConfig( + activation_default="tanh", + activation_options=("tanh",), + ), + substrate=NormalSubstrateConfig(), + ) + neat = NEAT(config, RecurrentGene) 
+ hyperNEAT = HyperNEAT(config, neat, NormalSubstrate) + + pipeline = Pipeline(config, hyperNEAT) + pipeline.auto_run(evaluate) diff --git a/examples/xor_recurrent.py b/examples/xor_recurrent.py new file mode 100644 index 0000000..b9ae0bc --- /dev/null +++ b/examples/xor_recurrent.py @@ -0,0 +1,42 @@ +import jax +import numpy as np + +from config import Config, BasicConfig, NeatConfig +from pipeline import Pipeline +from algorithm import NEAT +from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig + + +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) +xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) + + +def evaluate(forward_func): + """ + :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) + :return: + """ + outs = forward_func(xor_inputs) + outs = jax.device_get(outs) + fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) + return fitnesses + + +if __name__ == '__main__': + config = Config( + basic=BasicConfig( + fitness_target=3.99999, + pop_size=10000 + ), + neat=NeatConfig( + network_type="recurrent", + maximum_nodes=50, + maximum_conns=100 + ), + gene=RecurrentGeneConfig( + activate_times=3 + ) + ) + algorithm = NEAT(config, RecurrentGene) + pipeline = Pipeline(config, algorithm) + pipeline.auto_run(evaluate) diff --git a/problem/__init__.py b/problem/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/problem/func_fit/__init__.py b/problem/func_fit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/problem/func_fit/func_fitting.py b/problem/func_fit/func_fitting.py new file mode 100644 index 0000000..a60be5e --- /dev/null +++ b/problem/func_fit/func_fitting.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import Callable + +from config import ProblemConfig +from core import Problem, State + + +@dataclass(frozen=True) +class FuncFitConfig: + pass + + +class FuncFit(Problem): + def __init__(self, config: 
ProblemConfig): + self.config = config + def setup(self, state=State()): + pass + def evaluate(self, state: State, act_func: Callable, params): + pass \ No newline at end of file