diff --git a/algorithm/__init__.py b/algorithm/__init__.py index a9899b1..1bfd121 100644 --- a/algorithm/__init__.py +++ b/algorithm/__init__.py @@ -1,3 +1,4 @@ +from .base import Algorithm from .state import State from .neat import NEAT -from .config import Configer \ No newline at end of file +from .hyperneat import HyperNEAT diff --git a/algorithm/base.py b/algorithm/base.py new file mode 100644 index 0000000..188e96d --- /dev/null +++ b/algorithm/base.py @@ -0,0 +1,17 @@ +from typing import Callable + +from .state import State + +EMPTY = lambda *args: args + + +class Algorithm: + + def __init__(self): + self.tell: Callable = EMPTY + self.ask: Callable = EMPTY + self.forward: Callable = EMPTY + self.forward_transform: Callable = EMPTY + + def setup(self, randkey, state=State()): + pass diff --git a/algorithm/hyperneat/__init__.py b/algorithm/hyperneat/__init__.py index e69de29..17af79f 100644 --- a/algorithm/hyperneat/__init__.py +++ b/algorithm/hyperneat/__init__.py @@ -0,0 +1,2 @@ +from .hyperneat import HyperNEAT +from .substrate import BaseSubstrate diff --git a/algorithm/hyperneat/hyperneat.py b/algorithm/hyperneat/hyperneat.py new file mode 100644 index 0000000..79ff6e0 --- /dev/null +++ b/algorithm/hyperneat/hyperneat.py @@ -0,0 +1,70 @@ +from typing import Type + +import jax +import numpy as np + +from .substrate import BaseSubstrate, analysis_substrate +from .hyperneat_gene import HyperNEATGene +from algorithm import State, Algorithm, neat + + +class HyperNEAT(Algorithm): + + def __init__(self, config, gene_type: Type[neat.BaseGene], substrate: Type[BaseSubstrate]): + super().__init__() + self.config = config + self.gene_type = gene_type + self.substrate = substrate + self.neat = neat.NEAT(config, gene_type) + + self.tell = create_tell(self.neat) + self.forward_transform = create_forward_transform(config, self.neat) + self.forward = HyperNEATGene.create_forward(config) + + def setup(self, randkey, state=State()): + state = state.update( + below_threshold=self.config['below_threshold'], + max_weight=self.config['max_weight'] + ) + + state = self.substrate.setup(state, self.config) + h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state) + h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis] + h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32) + h_conns[:, 0:2] = correspond_keys + + state = state.update( + # h is short for hyperneat + h_input_idx=h_input_idx, + h_output_idx=h_output_idx, + h_hidden_idx=h_hidden_idx, + query_coors=query_coors, + correspond_keys=correspond_keys, + h_nodes=h_nodes, + h_conns=h_conns + ) + state = self.neat.setup(randkey, state=state) + + self.config['h_input_idx'] = h_input_idx + self.config['h_output_idx'] = h_output_idx + + return state + + +def create_tell(neat_instance): + def tell(state, fitness): + return neat_instance.tell(state, fitness) + + return tell + + +def create_forward_transform(config, neat_instance): + def forward_transform(state, nodes, conns): + t = neat_instance.forward_transform(state, nodes, conns) + batch_forward_func = jax.vmap(neat_instance.forward, in_axes=(0, None)) + query_res = batch_forward_func(state.query_coors, t) # hyperneat connections + h_nodes = state.h_nodes + h_conns = state.h_conns.at[:, 2:].set(query_res) + return HyperNEATGene.forward_transform(state, h_nodes, h_conns) + + return forward_transform diff --git a/algorithm/hyperneat/hyperneat_gene.py b/algorithm/hyperneat/hyperneat_gene.py new file mode 100644 index 0000000..247c95c --- /dev/null +++ b/algorithm/hyperneat/hyperneat_gene.py @@ -0,0 +1,54 @@ +import jax +from jax import numpy as jnp, vmap + +from algorithm.neat import BaseGene +from algorithm.neat.gene import Activation +from algorithm.neat.gene import Aggregation + + +class HyperNEATGene(BaseGene): + node_attrs = [] # no node attributes + conn_attrs = ['weight'] + + @staticmethod + def forward_transform(state, nodes, conns): + N = nodes.shape[0] + u_conns = jnp.zeros((N, N), dtype=jnp.float32) + + in_keys = jnp.asarray(conns[:, 0], jnp.int32) + out_keys = jnp.asarray(conns[:, 1], jnp.int32) + weights = conns[:, 2] + + u_conns = u_conns.at[in_keys, out_keys].set(weights) + return nodes, u_conns + + @staticmethod + def create_forward(config): + act = Activation.name2func[config['h_activation']] + agg = Aggregation.name2func[config['h_aggregation']] + + batch_act, batch_agg = vmap(act), vmap(agg) + + def forward(inputs, transform): + + inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0) + nodes, weights = transform + + input_idx = config['h_input_idx'] + output_idx = config['h_output_idx'] + + N = nodes.shape[0] + vals = jnp.full((N,), 0.) + + def body_func(i, values): + values = values.at[input_idx].set(inputs_with_bias) + nodes_ins = values * weights.T + values = batch_agg(nodes_ins) # z = agg(ins) + values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias + values = batch_act(values) # z = act(z) + return values + + vals = jax.lax.fori_loop(0, config['h_activate_times'], body_func, vals) + return vals[output_idx] + + return forward diff --git a/algorithm/hyperneat/substrate/__init__.py b/algorithm/hyperneat/substrate/__init__.py new file mode 100644 index 0000000..366d01b --- /dev/null +++ b/algorithm/hyperneat/substrate/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseSubstrate +from .tools import analysis_substrate diff --git a/algorithm/hyperneat/substrate/base.py b/algorithm/hyperneat/substrate/base.py new file mode 100644 index 0000000..586ce20 --- /dev/null +++ b/algorithm/hyperneat/substrate/base.py @@ -0,0 +1,12 @@ +import numpy as np + + +class BaseSubstrate: + + @staticmethod + def setup(state, config): + return state.update( + input_coors=np.asarray(config['input_coors'], dtype=np.float32), + output_coors=np.asarray(config['output_coors'], dtype=np.float32), + hidden_coors=np.asarray(config['hidden_coors'], dtype=np.float32), + ) diff --git a/algorithm/hyperneat/substrate/tools.py b/algorithm/hyperneat/substrate/tools.py new file mode 100644 index 0000000..9eb4720 --- /dev/null +++ b/algorithm/hyperneat/substrate/tools.py @@ -0,0 +1,53 @@ +from typing import Type + +import numpy as np + +from .base import BaseSubstrate + + +def analysis_substrate(state): + cd = state.input_coors.shape[1] # coordinate dimensions + si = state.input_coors.shape[0] # input coordinate size + so = state.output_coors.shape[0] # output coordinate size + sh = state.hidden_coors.shape[0] # hidden coordinate size + + input_idx = np.arange(si) + output_idx = np.arange(si, si + so) + hidden_idx = np.arange(si + so, si + so + sh) + + total_conns = si * sh + sh * sh + sh * so + query_coors = np.zeros((total_conns, cd * 2)) + correspond_keys = np.zeros((total_conns, 2)) + + # connect input to hidden + aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors) + query_coors[0: si * sh, :] = aux_coors + correspond_keys[0: si * sh, :] = aux_keys + + # connect hidden to hidden + aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors) + query_coors[si * sh: si * sh + sh * sh, :] = aux_coors + correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys + + # connect hidden to output + aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors) + query_coors[si * sh + sh * sh:, :] = aux_coors + correspond_keys[si * sh + sh * sh:, :] = aux_keys + + return input_idx, output_idx, hidden_idx, query_coors, correspond_keys + + +def cartesian_product(keys1, keys2, coors1, coors2): + len1 = keys1.shape[0] + len2 = keys2.shape[0] + + repeated_coors1 = np.repeat(coors1, len2, axis=0) + repeated_keys1 = np.repeat(keys1, len2) + + tiled_coors2 = np.tile(coors2, (len1, 1)) + tiled_keys2 = np.tile(keys2, len1) + + new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1) + correspond_keys = np.column_stack((repeated_keys1, tiled_keys2)) + + return new_coors, correspond_keys diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py index bff8b1c..dc30798 100644 --- a/algorithm/neat/__init__.py +++ b/algorithm/neat/__init__.py @@ -1,3 +1,2 @@ from .neat import NEAT -from .gene import NormalGene, RecurrentGene -from .pipeline import Pipeline +from .gene import BaseGene, NormalGene, RecurrentGene diff --git a/algorithm/neat/gene/base.py b/algorithm/neat/gene/base.py index 9036e65..4f2e43d 100644 --- a/algorithm/neat/gene/base.py +++ b/algorithm/neat/gene/base.py @@ -33,12 +33,10 @@ class BaseGene: def distance_conn(state, conn1: Array, conn2: Array): return conn1 - @staticmethod - def forward_transform(nodes, conns): + def forward_transform(state, nodes, conns): return nodes, conns - @staticmethod def create_forward(config): - return None \ No newline at end of file + return None diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py index d7fda5f..0cc1c80 100644 --- a/algorithm/neat/gene/normal.py +++ b/algorithm/neat/gene/normal.py @@ -4,7 +4,7 @@ from jax import Array, numpy as jnp from .base import BaseGene from .activation import Activation from .aggregation import Aggregation -from ..utils import unflatten_connections, I_INT +from algorithm.utils import unflatten_connections, I_INT from ..genome import topological_sort @@ -84,7 +84,7 @@ class NormalGene(BaseGene): return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight @staticmethod - def forward_transform(nodes, conns): + def forward_transform(state, nodes, conns): u_conns = unflatten_connections(nodes, conns) conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) diff --git a/algorithm/neat/gene/recurrent.py b/algorithm/neat/gene/recurrent.py index 9d73b96..f723748 100644 --- a/algorithm/neat/gene/recurrent.py +++ b/algorithm/neat/gene/recurrent.py @@ -4,13 +4,13 @@ from jax import Array, numpy as jnp, vmap from .normal import NormalGene from .activation import Activation from .aggregation import Aggregation -from ..utils import unflatten_connections, I_INT +from algorithm.utils import unflatten_connections class RecurrentGene(NormalGene): @staticmethod - def forward_transform(nodes, conns): + def forward_transform(state, nodes, conns): u_conns = unflatten_connections(nodes, conns) # remove un-enable connections and remove enable attr diff --git a/algorithm/neat/genome/basic.py b/algorithm/neat/genome/basic.py index 76b7022..5635280 100644 --- a/algorithm/neat/genome/basic.py +++ b/algorithm/neat/genome/basic.py @@ -6,7 +6,7 @@ from jax import Array, numpy as jnp from algorithm import State from ..gene import BaseGene -from ..utils import fetch_first +from algorithm.utils import fetch_first def initialize_genomes(state: State, gene_type: Type[BaseGene]): @@ -48,6 +48,7 @@ def count(nodes: Array, cons: Array): cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0])) return node_cnt, cons_cnt + def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]: """ Add a new node to the genome. diff --git a/algorithm/neat/genome/graph.py b/algorithm/neat/genome/graph.py index b813666..72da89f 100644 --- a/algorithm/neat/genome/graph.py +++ b/algorithm/neat/genome/graph.py @@ -6,7 +6,7 @@ Only used in feed-forward networks. import jax from jax import jit, Array, numpy as jnp -from ..utils import fetch_first, I_INT +from algorithm.utils import fetch_first, I_INT @jit diff --git a/algorithm/neat/genome/mutate.py b/algorithm/neat/genome/mutate.py index 50db98a..47849fd 100644 --- a/algorithm/neat/genome/mutate.py +++ b/algorithm/neat/genome/mutate.py @@ -4,9 +4,9 @@ import jax from jax import Array, numpy as jnp, vmap from algorithm import State -from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx, count +from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx from .graph import check_cycles -from ..utils import fetch_random, fetch_first, I_INT, unflatten_connections +from algorithm.utils import fetch_random, fetch_first, I_INT, unflatten_connections from ..gene import BaseGene diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py index 6ad078a..1f15745 100644 --- a/algorithm/neat/neat.py +++ b/algorithm/neat/neat.py @@ -3,22 +3,25 @@ from typing import Type import jax import jax.numpy as jnp -from algorithm.state import State +from algorithm import Algorithm, State from .gene import BaseGene from .genome import initialize_genomes from .population import create_tell -class NEAT: +class NEAT(Algorithm): def __init__(self, config, gene_type: Type[BaseGene]): + super().__init__() self.config = config self.gene_type = gene_type - self.tell_func = jax.jit(create_tell(config, self.gene_type)) + self.tell = create_tell(config, self.gene_type) + self.ask = None + self.forward = self.gene_type.create_forward(config) + self.forward_transform = self.gene_type.forward_transform - def setup(self, randkey): - - state = State( + def setup(self, randkey, state=State()): + state = state.update( P=self.config['pop_size'], N=self.config['maximum_nodes'], C=self.config['maximum_conns'], @@ -69,7 +72,4 @@ class NEAT: # move to device state = jax.device_put(state) - return state - - def step(self, state, fitness): - return self.tell_func(state, fitness) + return state \ No newline at end of file diff --git a/algorithm/neat/population.py b/algorithm/neat/population.py index 462e44f..a89d178 100644 --- a/algorithm/neat/population.py +++ b/algorithm/neat/population.py @@ -3,7 +3,7 @@ from typing import Type import jax from jax import numpy as jnp, vmap -from .utils import rank_elements, fetch_first +from algorithm.utils import rank_elements, fetch_first from .genome import create_mutate, create_distance, crossover from .gene import BaseGene diff --git a/algorithm/neat/utils.py b/algorithm/utils.py similarity index 100% rename from algorithm/neat/utils.py rename to algorithm/utils.py diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..dfb91b6 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1 @@ +from .config import Configer diff --git a/algorithm/config.py b/config/config.py similarity index 96% rename from algorithm/config.py rename to config/config.py index d23db46..3ee9c7d 100644 --- a/algorithm/config.py +++ b/config/config.py @@ -4,6 +4,7 @@ import configparser import numpy as np + class Configer: @classmethod @@ -28,7 +29,7 @@ class Configer: def __check_redundant_config(cls, default_config, config): for key in config: if key not in default_config: - warnings.warn(f"Redundant config: {key} in {config.name}") + warnings.warn(f"Redundant config: {key} in config!") @classmethod def __complete_config(cls, default_config, config): diff --git a/algorithm/default_config.ini b/config/default_config.ini similarity index 72% rename from algorithm/default_config.ini rename to config/default_config.ini index 97e2803..e8be118 100644 --- a/algorithm/default_config.ini +++ b/config/default_config.ini @@ -1,26 +1,38 @@ [basic] +random_seed = 0 +generation_limit = 1000 + +[problem] +fitness_threshold = 3.9999 num_inputs = 2 num_outputs = 1 -maximum_nodes = 50 -maximum_conns = 100 -maximum_species = 10 -forward_way = "pop" -batch_size = 4 -random_seed = 0 + +[neat] network_type = "feedforward" -activate_times = 10 +activate_times = 5 +maximum_nodes = 50 +maximum_conns = 50 +maximum_species = 10 + +[hyperneat] +below_threshold = 0.2 +max_weight = 3 +h_activation = "sigmoid" +h_aggregation = "sum" +h_activate_times = 5 + +[substrate] +input_coors = [[-1, 1], [0, 1], [1, 1]] +hidden_coors = [[-1, 0], [0, 0], [1, 0]] +output_coors = [[0, -1]] [population] -fitness_threshold = 3.9999 -generation_limit = 1000 -fitness_criterion = "max" -pop_size = 50000 +pop_size = 10 [genome] compatibility_disjoint = 1.0 compatibility_weight = 0.5 conn_add_prob = 0.4 -conn_add_trials = 1 conn_delete_prob = 0 node_add_prob = 0.2 node_delete_prob = 0 @@ -34,39 +46,37 @@ survival_threshold = 0.2 min_species_size = 1 spawn_number_change_rate = 0.5 -[gene-bias] +[gene] +# bias bias_init_mean = 0.0 bias_init_std = 1.0 bias_mutate_power = 0.5 bias_mutate_rate = 0.7 bias_replace_rate = 0.1 -[gene-response] +# response response_init_mean = 1.0 response_init_std = 0.0 response_mutate_power = 0.0 response_mutate_rate = 0.0 response_replace_rate = 0.0 -[gene-activation] +# activation activation_default = "sigmoid" -activation_option_names = ["sigmoid"] +activation_option_names = ["tanh"] activation_replace_rate = 0.0 -[gene-aggregation] +# aggregation aggregation_default = "sum" aggregation_option_names = ["sum"] aggregation_replace_rate = 0.0 -[gene-weight] +# weight weight_init_mean = 0.0 weight_init_std = 1.0 weight_mutate_power = 0.5 weight_mutate_rate = 0.8 weight_replace_rate = 0.1 -[gene-enable] -enable_mutate_rate = 0.01 - [visualize] renumber_nodes = True \ No newline at end of file diff --git a/examples/a.py b/examples/a.py new file mode 100644 index 0000000..6ffb01d --- /dev/null +++ b/examples/a.py @@ -0,0 +1,11 @@ +import numpy as np +import jax.numpy as jnp + +a = jnp.zeros((5, 5)) +k1 = jnp.array([1, 2, 3]) +k2 = jnp.array([2, 3, 4]) +v = jnp.array([1, 1, 1]) + +a = a.at[k1, k2].set(v) + +print(a) diff --git a/examples/rnn_forward_test.py b/examples/rnn_forward_test.py index 0351e1b..0d33f77 100644 --- a/examples/rnn_forward_test.py +++ b/examples/rnn_forward_test.py @@ -1,13 +1,44 @@ -import numpy as np +import jax +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class -vals = np.array([1, 2]) -weights = np.array([[0, 4], [5, 0]]) +@register_pytree_node_class +class Genome: + def __init__(self, nodes, conns): + self.nodes = nodes + self.conns = conns -ins1 = vals * weights[:, 0] -ins2 = vals * weights[:, 1] -ins_all = vals * weights.T + def update_nodes(self, nodes): + return Genome(nodes, self.conns) -print(ins1) -print(ins2) -print(ins_all) \ No newline at end of file + def update_conns(self, conns): + return Genome(self.nodes, conns) + + def tree_flatten(self): + children = self.nodes, self.conns + aux_data = None + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children) + + def __repr__(self): + return f"Genome ({self.nodes}, \n\t{self.conns})" + + @jax.jit + def add_node(self, a: int): + nodes = self.nodes.at[0, :].set(a) + return self.update_nodes(nodes) + + +nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]]) +g = Genome(nodes, conns) +print(g) + +g = g.add_node(1) +print(g) + +g = jax.jit(g.add_node)(2) +print(g) diff --git a/examples/state_test.py b/examples/state_test.py deleted file mode 100644 index ef2fddf..0000000 --- a/examples/state_test.py +++ /dev/null @@ -1,15 +0,0 @@ -import jax -from jax import numpy as jnp -from algorithm.state import State - - -@jax.jit -def func(state: State, a): - return state.update(a=a) - - -state = State(c=1, b=2) -print(state) - -vmap_func = jax.vmap(func, in_axes=(None, 0)) -print(vmap_func(state, jnp.array([1, 2, 3]))) \ No newline at end of file diff --git a/examples/xor.ini b/examples/xor.ini index af2d8b9..9752677 100644 --- a/examples/xor.ini +++ b/examples/xor.ini @@ -1,7 +1,12 @@ [basic] -forward_way = "common" -network_type = "recurrent" activate_times = 5 +fitness_threshold = 4 [population] -fitness_threshold = 4 \ No newline at end of file +pop_size = 1000 + +[neat] +network_type = "recurrent" +num_inputs = 4 +num_outputs = 1 + diff --git a/examples/xor.py b/examples/xor.py index f3a3f67..73c7228 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,8 +1,10 @@ import jax import numpy as np -from algorithm import Configer, NEAT -from algorithm.neat import NormalGene, RecurrentGene, Pipeline +from pipeline import Pipeline +from config import Configer +from algorithm import NEAT +from algorithm.neat import RecurrentGene xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) @@ -21,7 +23,6 @@ def evaluate(forward_func): def main(): config = Configer.load_config("xor.ini") - # algorithm = NEAT(config, NormalGene) algorithm = NEAT(config, RecurrentGene) pipeline = Pipeline(config, algorithm) best = pipeline.auto_run(evaluate) diff --git a/examples/xor_hyperneat.py b/examples/xor_hyperneat.py new file mode 100644 index 0000000..d0d70d6 --- /dev/null +++ b/examples/xor_hyperneat.py @@ -0,0 +1,33 @@ +import jax +import numpy as np + +from pipeline import Pipeline +from config import Configer +from algorithm import NEAT, HyperNEAT +from algorithm.neat import RecurrentGene +from algorithm.hyperneat import BaseSubstrate + +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) +xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) + + +def evaluate(forward_func): + """ + :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) + :return: + """ + outs = forward_func(xor_inputs) + outs = jax.device_get(outs) + fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) + return fitnesses + + +def main(): + config = Configer.load_config("xor.ini") + algorithm = HyperNEAT(config, RecurrentGene, BaseSubstrate) + pipeline = Pipeline(config, algorithm) + pipeline.auto_run(evaluate) + + +if __name__ == '__main__': + main() diff --git a/examples/xor_test.py b/examples/xor_test.py deleted file mode 100644 index 33da7bf..0000000 --- a/examples/xor_test.py +++ /dev/null @@ -1,50 +0,0 @@ -import jax -import numpy as np - -from algorithm.config import Configer -from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline -from algorithm.neat.genome import create_mutate - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) - - -def single_genome(func, nodes, conns): - t = RecurrentGene.forward_transform(nodes, conns) - out1 = func(xor_inputs[0], t) - out2 = func(xor_inputs[1], t) - out3 = func(xor_inputs[2], t) - out4 = func(xor_inputs[3], t) - print(out1, out2, out3, out4) - - -def batch_genome(func, nodes, conns): - t = NormalGene.forward_transform(nodes, conns) - out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t) - print(out) - - -def pop_batch_genome(func, pop_nodes, pop_conns): - t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns) - func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0)) - out = func(xor_inputs, t) - print(out) - - -if __name__ == '__main__': - config = Configer.load_config("xor.ini") - # neat = NEAT(config, NormalGene) - neat = NEAT(config, RecurrentGene) - randkey = jax.random.PRNGKey(42) - state = neat.setup(randkey) - forward_func = RecurrentGene.create_forward(config) - mutate_func = create_mutate(config, RecurrentGene) - - nodes, conns = state.pop_nodes[0], state.pop_conns[0] - single_genome(forward_func, nodes, conns) - # batch_genome(forward_func, nodes, conns) - - nodes, conns = mutate_func(state, randkey, nodes, conns, 10000) - single_genome(forward_func, nodes, conns) - - # batch_genome(forward_func, nodes, conns) - # diff --git a/algorithm/neat/pipeline.py b/pipeline.py similarity index 84% rename from algorithm/neat/pipeline.py rename to pipeline.py index 612b5e9..da57566 100644 --- a/algorithm/neat/pipeline.py +++ b/pipeline.py @@ -5,15 +5,18 @@ import jax from jax import vmap, jit import numpy as np +from algorithm import Algorithm + class Pipeline: """ Neat algorithm pipeline. """ - def __init__(self, config, algorithm): + def __init__(self, config, algorithm: Algorithm): self.config = config self.algorithm = algorithm + randkey = jax.random.PRNGKey(config['random_seed']) self.state = algorithm.setup(randkey) @@ -23,18 +26,18 @@ class Pipeline: self.evaluate_time = 0 - self.forward_func = algorithm.gene_type.create_forward(config) + self.forward_func = jit(self.algorithm.forward) self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None))) self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0))) - - self.pop_transform_func = jit(vmap(algorithm.gene_type.forward_transform)) + self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0, 0))) + self.tell_func = jit(self.algorithm.tell) def ask(self): - pop_transforms = self.pop_transform_func(self.state.pop_nodes, self.state.pop_conns) + pop_transforms = self.forward_transform_func(self.state, self.state.pop_nodes, self.state.pop_conns) return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms) def tell(self, fitness): - self.state = self.algorithm.step(self.state, fitness) + self.state = self.tell_func(self.state, fitness) def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): for _ in range(self.config['generation_limit']): diff --git a/test/unit/test_cartesian_product.py b/test/unit/test_cartesian_product.py new file mode 100644 index 0000000..488eea0 --- /dev/null +++ b/test/unit/test_cartesian_product.py @@ -0,0 +1,56 @@ +import numpy as np + +from algorithm.hyperneat.substrate.tools import cartesian_product + + +def test01(): + keys1 = np.array([1, 2, 3]) + keys2 = np.array([4, 5, 6, 7]) + + coors1 = np.array([ + [1, 1, 1], + [2, 2, 2], + [3, 3, 3] + ]) + + coors2 = np.array([ + [4, 4, 4], + [5, 5, 5], + [6, 6, 6], + [7, 7, 7] + ]) + + target_coors = np.array([ + [1, 1, 1, 4, 4, 4], + [1, 1, 1, 5, 5, 5], + [1, 1, 1, 6, 6, 6], + [1, 1, 1, 7, 7, 7], + [2, 2, 2, 4, 4, 4], + [2, 2, 2, 5, 5, 5], + [2, 2, 2, 6, 6, 6], + [2, 2, 2, 7, 7, 7], + [3, 3, 3, 4, 4, 4], + [3, 3, 3, 5, 5, 5], + [3, 3, 3, 6, 6, 6], + [3, 3, 3, 7, 7, 7] + ]) + + target_keys = np.array([ + [1, 4], + [1, 5], + [1, 6], + [1, 7], + [2, 4], + [2, 5], + [2, 6], + [2, 7], + [3, 4], + [3, 5], + [3, 6], + [3, 7] + ]) + + new_coors, correspond_keys = cartesian_product(keys1, keys2, coors1, coors2) + + assert np.array_equal(new_coors, target_coors) + assert np.array_equal(correspond_keys, target_keys) diff --git a/test/unit/test_graphs.py b/test/unit/test_graphs.py index 0b2ff17..5721d6e 100644 --- a/test/unit/test_graphs.py +++ b/test/unit/test_graphs.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from algorithm.neat.genome.graph import topological_sort, check_cycles -from algorithm.neat.utils import I_INT +from algorithm.utils import I_INT nodes = jnp.array([ [0], diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index d313402..81cb6c5 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -1,5 +1,5 @@ import jax.numpy as jnp -from algorithm.neat.utils import unflatten_connections +from algorithm.utils import unflatten_connections def test_unflatten():