diff --git a/algorithms/neat/function_factory.py b/algorithms/neat/function_factory.py index 70c0b71..f68194d 100644 --- a/algorithms/neat/function_factory.py +++ b/algorithms/neat/function_factory.py @@ -8,7 +8,7 @@ import numpy as np from jax import jit, vmap from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover -from .genome import topological_sort, forward_single +from .genome import topological_sort, forward_single, unflatten_connections class FunctionFactory: @@ -17,19 +17,18 @@ class FunctionFactory: self.debug = debug self.init_N = config.basic.init_maximum_nodes + self.init_C = config.basic.init_maximum_connections self.expand_coe = config.basic.expands_coe self.precompile_times = config.basic.pre_compile_times self.compiled_function = {} self.load_config_vals(config) self.precompile() - pass def load_config_vals(self, config): self.problem_batch = config.basic.problem_batch self.pop_size = config.neat.population.pop_size - self.init_N = config.basic.init_maximum_nodes self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient @@ -85,6 +84,7 @@ class FunctionFactory: initialize_genomes, pop_size=self.pop_size, N=self.init_N, + C=self.init_C, num_inputs=self.num_inputs, num_outputs=self.num_outputs, default_bias=self.bias_mean, @@ -107,24 +107,24 @@ class FunctionFactory: self.create_crossover_with_args() self.create_topological_sort_with_args() self.create_single_forward_with_args() - - n = self.init_N - print("start precompile") - for _ in range(self.precompile_times): - self.compile_mutate(n) - self.compile_distance(n) - self.compile_crossover(n) - self.compile_topological_sort_batch(n) - self.compile_pop_batch_forward(n) - n = int(self.expand_coe * n) - - # precompile other functions used in jax - key = jax.random.PRNGKey(0) - _ = jax.random.split(key, 3) - _ = jax.random.split(key, self.pop_size * 2) - _ = 
jax.random.split(key, self.pop_size) - - print("end precompile") + # + # n, c = self.init_N, self.init_C + # print("start precompile") + # for _ in range(self.precompile_times): + # self.compile_mutate(n) + # self.compile_distance(n) + # self.compile_crossover(n) + # self.compile_topological_sort_batch(n) + # self.compile_pop_batch_forward(n) + # n = int(self.expand_coe * n) + # + # # precompile other functions used in jax + # key = jax.random.PRNGKey(0) + # _ = jax.random.split(key, 3) + # _ = jax.random.split(key, self.pop_size * 2) + # _ = jax.random.split(key, self.pop_size) + # + # print("end precompile") def create_mutate_with_args(self): func = partial( @@ -161,20 +161,20 @@ class FunctionFactory: ) self.mutate_with_args = func - def compile_mutate(self, n): + def compile_mutate(self, n, c): func = self.mutate_with_args rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32) nodes_lower = np.zeros((self.pop_size, n, 5)) - connections_lower = np.zeros((self.pop_size, 2, n, n)) + connections_lower = np.zeros((self.pop_size, c, 4)) new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32) batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower, connections_lower, new_node_key_lower).compile() - self.compiled_function[('mutate', n)] = batched_mutate_func + self.compiled_function[('mutate', n, c)] = batched_mutate_func - def create_mutate(self, n): - key = ('mutate', n) + def create_mutate(self, n, c): + key = ('mutate', n, c) if key not in self.compiled_function: - self.compile_mutate(n) + self.compile_mutate(n, c) if self.debug: def debug_mutate(*args): res_nodes, res_connections = self.compiled_function[key](*args) @@ -192,28 +192,28 @@ class FunctionFactory: ) self.distance_with_args = func - def compile_distance(self, n): + def compile_distance(self, n, c): func = self.distance_with_args o2o_nodes1_lower = np.zeros((n, 5)) - o2o_connections1_lower = np.zeros((2, n, n)) + o2o_connections1_lower = np.zeros((c, 4)) o2o_nodes2_lower 
= np.zeros((n, 5)) - o2o_connections2_lower = np.zeros((2, n, n)) + o2o_connections2_lower = np.zeros((c, 4)) o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower, o2o_nodes2_lower, o2o_connections2_lower).compile() o2m_nodes2_lower = np.zeros((self.pop_size, n, 5)) - o2m_connections2_lower = np.zeros((self.pop_size, 2, n, n)) + o2m_connections2_lower = np.zeros((self.pop_size, c, 4)) o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower, o2m_nodes2_lower, o2m_connections2_lower).compile() - self.compiled_function[('o2o_distance', n)] = o2o_distance - self.compiled_function[('o2m_distance', n)] = o2m_distance + self.compiled_function[('o2o_distance', n, c)] = o2o_distance + self.compiled_function[('o2m_distance', n, c)] = o2m_distance - def create_distance(self, n): - key1, key2 = ('o2o_distance', n), ('o2m_distance', n) + def create_distance(self, n, c): + key1, key2 = ('o2o_distance', n, c), ('o2m_distance', n, c) if key1 not in self.compiled_function: - self.compile_distance(n) + self.compile_distance(n, c) if self.debug: def debug_o2o_distance(*args): @@ -229,21 +229,21 @@ class FunctionFactory: def create_crossover_with_args(self): self.crossover_with_args = crossover - def compile_crossover(self, n): + def compile_crossover(self, n, c): func = self.crossover_with_args randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32) nodes1_lower = np.zeros((self.pop_size, n, 5)) - connections1_lower = np.zeros((self.pop_size, 2, n, n)) + connections1_lower = np.zeros((self.pop_size, c, 4)) nodes2_lower = np.zeros((self.pop_size, n, 5)) - connections2_lower = np.zeros((self.pop_size, 2, n, n)) + connections2_lower = np.zeros((self.pop_size, c, 4)) func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile() - self.compiled_function[('crossover', n)] = func + self.compiled_function[('crossover', n, c)] = func - def 
create_crossover(self, n): - key = ('crossover', n) + def create_crossover(self, n, c): + key = ('crossover', n, c) if key not in self.compiled_function: - self.compile_crossover(n) + self.compile_crossover(n, c) if self.debug: def debug_crossover(*args): @@ -365,15 +365,17 @@ class FunctionFactory: else: return self.compiled_function[key] - def ask_pop_batch_forward(self, pop_nodes, pop_connections): - n = pop_nodes.shape[1] + def ask_pop_batch_forward(self, pop_nodes, pop_cons): + n, c = pop_nodes.shape[1], pop_cons.shape[1] + batch_unflatten_func = self.create_batch_unflatten_connections(n, c) + pop_cons = batch_unflatten_func(pop_nodes, pop_cons) ts = self.create_topological_sort_batch(n) - pop_cal_seqs = ts(pop_nodes, pop_connections) + pop_cal_seqs = ts(pop_nodes, pop_cons) forward_func = self.create_pop_batch_forward(n) def debug_forward(inputs): - return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections) + return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_cons) return debug_forward @@ -387,3 +389,23 @@ class FunctionFactory: return forward_func(inputs, cal_seqs, nodes, connections) return debug_forward + + def compile_batch_unflatten_connections(self, n, c): + func = unflatten_connections + func = vmap(func) + pop_nodes_lower = np.zeros((self.pop_size, n, 5)) + pop_connections_lower = np.zeros((self.pop_size, c, 4)) + func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile() + self.compiled_function[('batch_unflatten_connections', n, c)] = func + + def create_batch_unflatten_connections(self, n, c): + key = ('batch_unflatten_connections', n, c) + if key not in self.compiled_function: + self.compile_batch_unflatten_connections(n, c) + if self.debug: + def debug_batch_unflatten_connections(*args): + return self.compiled_function[key](*args).block_until_ready() + + return debug_batch_unflatten_connections + else: + return self.compiled_function[key] diff --git a/algorithms/neat/genome/__init__.py 
b/algorithms/neat/genome/__init__.py index 1bc92c6..d723d36 100644 --- a/algorithms/neat/genome/__init__.py +++ b/algorithms/neat/genome/__init__.py @@ -1,8 +1,9 @@ -from .genome import expand, expand_single, pop_analysis, initialize_genomes -from .forward import create_forward_function, forward_single +from .genome import expand, expand_single, initialize_genomes +from .forward import forward_single from .activations import act_name2key from .aggregations import agg_name2key from .crossover import crossover from .mutate import mutate from .distance import distance from .graph import topological_sort +from .utils import unflatten_connections \ No newline at end of file diff --git a/algorithms/neat/genome/activations.py b/algorithms/neat/genome/activations.py index 89e0f6a..db30a78 100644 --- a/algorithms/neat/genome/activations.py +++ b/algorithms/neat/genome/activations.py @@ -23,8 +23,8 @@ def sin_act(z): @jit def gauss_act(z): - z = jnp.clip(z, -3.4, 3.4) - return jnp.exp(-5 * z ** 2) + z = jnp.clip(z * 5, -3.4, 3.4) + return jnp.exp(-z ** 2) @jit diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py index b4e78fb..0873b98 100644 --- a/algorithms/neat/genome/crossover.py +++ b/algorithms/neat/genome/crossover.py @@ -7,16 +7,16 @@ from jax import numpy as jnp @jit -def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \ +def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \ -> Tuple[Array, Array]: """ use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) 
:param randkey: :param nodes1: - :param connections1: + :param cons1: :param nodes2: - :param connections2: + :param cons2: :return: """ randkey_1, randkey_2 = jax.random.split(randkey) @@ -27,15 +27,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) # crossover connections - con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2] - connections2 = align_array(con_keys1, con_keys2, connections2, 'connection') - new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections1), cons1, crossover_gene(randkey_2, cons1, cons2)) - new_cons = unflatten_connections(len(keys1), new_cons) + con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] + cons2 = align_array(con_keys1, con_keys2, cons2, 'connection') + new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2)) return new_nodes, new_cons -@partial(jit, static_argnames=['gene_type']) +# @partial(jit, static_argnames=['gene_type']) def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: """ After I review this code, I found that it is the most difficult part of the code. Please never change it! 
@@ -63,7 +62,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: return refactor_ar2 -@jit +# @jit def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: """ crossover two genes diff --git a/algorithms/neat/genome/debug/__init__.py b/algorithms/neat/genome/debug/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithms/neat/genome/debug/tools.py b/algorithms/neat/genome/debug/tools.py new file mode 100644 index 0000000..34de771 --- /dev/null +++ b/algorithms/neat/genome/debug/tools.py @@ -0,0 +1,88 @@ +from collections import defaultdict + +import numpy as np + + +def check_array_valid(nodes, cons, input_keys, output_keys): + nodes_dict, cons_dict = array2object(nodes, cons, input_keys, output_keys) + # assert is_DAG(cons_dict.keys()), "The genome is not a DAG!" + + +def array2object(nodes, cons, input_keys, output_keys): + """ + Convert a genome from array to dict. + :param nodes: (N, 5) + :param cons: (C, 4) + :param output_keys: + :param input_keys: + :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)] + """ + # update nodes_dict + nodes_dict = {} + for i, node in enumerate(nodes): + if np.isnan(node[0]): + continue + key = int(node[0]) + assert key not in nodes_dict, f"Duplicate node key: {key}!" + + if key in input_keys: + assert np.all(np.isnan(node[1:])), f"Input node {key} must has None bias, response, act, or agg!" + nodes_dict[key] = (None,) * 4 + else: + assert np.all(~np.isnan(node[1:])), f"Normal node {key} must has non-None bias, response, act, or agg!" + bias = node[1] + response = node[2] + act = node[3] + agg = node[4] + nodes_dict[key] = (bias, response, act, agg) + + # check nodes_dict + for i in input_keys: + assert i in nodes_dict, f"Input node {i} not found in nodes_dict!" + + for o in output_keys: + assert o in nodes_dict, f"Output node {o} not found in nodes_dict!" 
+ + # update connections + cons_dict = {} + for i, con in enumerate(cons): + if np.all(np.isnan(con)): + pass + elif np.all(~np.isnan(con)): + i_key = int(con[0]) + o_key = int(con[1]) + if (i_key, o_key) in cons_dict: + assert False, f"Duplicate connection: {(i_key, o_key)}!" + assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!" + assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!" + weight = con[2] + enabled = (con[3] == 1) + cons_dict[(i_key, o_key)] = (weight, enabled) + else: + assert False, f"Connection {i} must has all None or all non-None!" + + return nodes_dict, cons_dict + + +def is_DAG(edges): + all_nodes = set() + for a, b in edges: + if a == b: # cycle + return False + all_nodes.union({a, b}) + + for node in all_nodes: + visited = {n: False for n in all_nodes} + def dfs(n): + if visited[n]: + return False + visited[n] = True + for a, b in edges: + if a == n: + if not dfs(b): + return False + return True + + if not dfs(node): + return False + return True diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py index 8a5db3e..e314d77 100644 --- a/algorithms/neat/genome/distance.py +++ b/algorithms/neat/genome/distance.py @@ -1,11 +1,11 @@ from jax import jit, vmap, Array from jax import numpy as jnp -from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON +from .utils import EMPTY_NODE, EMPTY_CON @jit -def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1., +def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, disjoint_coe: float = 1., compatibility_coe: float = 0.5) -> Array: """ Calculate the distance between two genomes. 
@@ -15,10 +15,6 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance - # refactor connections - keys1, keys2 = nodes1[:, 0], nodes2[:, 0] - cons1 = flatten_connections(keys1, connections1) - cons2 = flatten_connections(keys2, connections2) cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance return nd + cd @@ -35,9 +31,8 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5): nodes = nodes[sorted_indices] nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end fr, sr = nodes[:-1], nodes[1:] # first row, second row - nan_mask = jnp.isnan(nodes[:, 0]) - intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1] + intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) nd = batch_homologous_node_distance(fr, sr) @@ -50,8 +45,8 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5): @jit def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5): - con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 2])) # weight is not nan, means the connection exists - con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 2])) + con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0])) + con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0])) max_cnt = jnp.maximum(con_cnt1, con_cnt2) cons = jnp.concatenate((cons1, cons2), axis=0) @@ -62,7 +57,7 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5): fr, sr = cons[:-1], cons[1:] # first row, second row # both genome has such connection - intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 2]) & ~jnp.isnan(sr[:, 2]) + intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) cd = batch_homologous_connection_distance(fr, sr) diff 
--git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py index 9aa22c4..07c1dca 100644 --- a/algorithms/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -1,51 +1,12 @@ -from functools import partial - import jax from jax import Array, numpy as jnp from jax import jit, vmap -from numpy.typing import NDArray from .aggregations import agg from .activations import act -from .graph import topological_sort, batch_topological_sort from .utils import I_INT - -def create_forward_function(nodes: NDArray, connections: NDArray, - N: int, input_idx: NDArray, output_idx: NDArray, batch: bool): - """ - create forward function for different situations - - :param nodes: shape (N, 5) or (pop_size, N, 5) - :param connections: shape (2, N, N) or (pop_size, 2, N, N) - :param N: - :param input_idx: - :param output_idx: - :param batch: using batch or not - :param debug: debug mode - :return: - """ - - if nodes.ndim == 2: # single genome - cal_seqs = topological_sort(nodes, connections) - if not batch: - return lambda inputs: forward_single(inputs, N, input_idx, output_idx, - cal_seqs, nodes, connections) - else: - return lambda batch_inputs: forward_batch(batch_inputs, N, input_idx, output_idx, - cal_seqs, nodes, connections) - elif nodes.ndim == 3: # pop genome - pop_cal_seqs = batch_topological_sort(nodes, connections) - if not batch: - return lambda inputs: pop_forward_single(inputs, N, input_idx, output_idx, - pop_cal_seqs, nodes, connections) - else: - return lambda batch_inputs: pop_forward_batch(batch_inputs, N, input_idx, output_idx, - pop_cal_seqs, nodes, connections) - else: - raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}") - - +# TODO: enabled information doesn't influence forward. That is wrong! 
@jit def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array, input_idx: Array, output_idx: Array) -> Array: @@ -84,66 +45,3 @@ def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Ar vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs) return vals[output_idx] - - -# @partial(jit, static_argnames=['N']) -# @partial(vmap, in_axes=(0, None, None, None, None, None, None)) -# def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, -# cal_seqs: Array, nodes: Array, connections: Array) -> Array: -# """ -# jax forward for batch_inputs shaped (batch_size, input_num) -# nodes, connections are single genome -# -# :argument batch_inputs: (batch_size, input_num) -# :argument N: int -# :argument input_idx: (input_num, ) -# :argument output_idx: (output_num, ) -# :argument cal_seqs: (N, ) -# :argument nodes: (N, 5) -# :argument connections: (2, N, N) -# -# :return (batch_size, output_num) -# """ -# return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections) -# -# -# @partial(jit, static_argnames=['N']) -# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) -# def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, -# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: -# """ -# jax forward for single input shaped (input_num, ) -# pop_nodes, pop_connections are population of genomes -# -# :argument inputs: (input_num, ) -# :argument N: int -# :argument input_idx: (input_num, ) -# :argument output_idx: (output_num, ) -# :argument pop_cal_seqs: (pop_size, N) -# :argument pop_nodes: (pop_size, N, 5) -# :argument pop_connections: (pop_size, 2, N, N) -# -# :return (pop_size, output_num) -# """ -# return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) -# -# -# @partial(jit, static_argnames=['N']) -# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) -# def 
pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, -# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: -# """ -# jax forward for batch input shaped (batch, input_num) -# pop_nodes, pop_connections are population of genomes -# -# :argument batch_inputs: (batch_size, input_num) -# :argument N: int -# :argument input_idx: (input_num, ) -# :argument output_idx: (output_num, ) -# :argument pop_cal_seqs: (pop_size, N) -# :argument pop_nodes: (pop_size, N, 5) -# :argument pop_connections: (pop_size, 2, N, N) -# -# :return (pop_size, batch_size, output_num) -# """ -# return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py index 74fea62..9162f7f 100644 --- a/algorithms/neat/genome/genome.py +++ b/algorithms/neat/genome/genome.py @@ -20,7 +20,7 @@ from jax import numpy as jnp from jax import jit from jax import Array -from .utils import fetch_first, EMPTY_NODE +from .utils import fetch_first def initialize_genomes(pop_size: int, @@ -124,79 +124,6 @@ def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tupl return new_nodes, new_cons -def analysis(nodes: NDArray, cons: NDArray, input_keys, output_keys) -> \ - Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]: - """ - Convert a genome from array to dict. - :param nodes: (N, 5) - :param cons: (C, 4) - :param output_keys: - :param input_keys: - :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)] - """ - # update nodes_dict - try: - nodes_dict = {} - for i, node in enumerate(nodes): - if np.isnan(node[0]): - continue - key = int(node[0]) - assert key not in nodes_dict, f"Duplicate node key: {key}!" 
- - bias = node[1] if not np.isnan(node[1]) else None - response = node[2] if not np.isnan(node[2]) else None - act = node[3] if not np.isnan(node[3]) else None - agg = node[4] if not np.isnan(node[4]) else None - nodes_dict[key] = (bias, response, act, agg) - - # check nodes_dict - for i in input_keys: - assert i in nodes_dict, f"Input node {i} not found in nodes_dict!" - bias, response, act, agg = nodes_dict[i] - assert bias is None and response is None and act is None and agg is None, \ - f"Input node {i} must has None bias, response, act, or agg!" - - for o in output_keys: - assert o in nodes_dict, f"Output node {o} not found in nodes_dict!" - - for k, v in nodes_dict.items(): - if k not in input_keys: - bias, response, act, agg = v - assert bias is not None and response is not None and act is not None and agg is not None, \ - f"Normal node {k} must has non-None bias, response, act, or agg!" - - # update connections - cons_dict = {} - for i, con in enumerate(cons): - if np.isnan(con[0]): - continue - assert ~np.isnan(con[1]), f"Connection {i} must has non-None o_key!" - i_key = int(con[0]) - o_key = int(con[1]) - assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!" - assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!" - key = (i_key, o_key) - weight = con[2] if not np.isnan(con[2]) else None - enabled = (con[3] == 1) if not np.isnan(con[3]) else None - assert weight is not None, f"Connection {key} must has non-None weight!" - assert enabled is not None, f"Connection {key} must has non-None enabled!" 
- - cons_dict[key] = (weight, enabled) - - return nodes_dict, cons_dict - except AssertionError: - print(nodes) - print(cons) - raise AssertionError - - -def pop_analysis(pop_nodes, pop_cons, input_keys, output_keys): - res = [] - for nodes, cons in zip(pop_nodes, pop_cons): - res.append(analysis(nodes, cons, input_keys, output_keys)) - return res - - @jit def count(nodes, cons): node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0])) @@ -231,7 +158,7 @@ def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Arra """ use idx to delete a node from the genome. only delete the node, regardless of connections. """ - nodes = nodes.at[idx].set(EMPTY_NODE) + nodes = nodes.at[idx].set(np.nan) return nodes, cons @@ -243,7 +170,7 @@ def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int, """ con_keys = cons[:, 0] idx = fetch_first(jnp.isnan(con_keys)) - return add_connection_by_idx(idx, nodes, cons, i_key, o_key, weight, enabled) + return add_connection_by_idx(nodes, cons, idx, i_key, o_key, weight, enabled) @jit diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index ae669d9..dd155be 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -6,11 +6,9 @@ import numpy as np from jax import numpy as jnp from jax import jit, vmap, Array -from .utils import fetch_random, fetch_first, I_INT -from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx +from .utils import fetch_random, fetch_first, I_INT, unflatten_connections +from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection from .graph import check_cycles -from .activations import act_name2key -from .aggregations import agg_name2key @partial(jit, static_argnames=('single_structure_mutate',)) @@ -89,7 +87,7 @@ def mutate(rand_key: Array, return n, c def m_add_node(rk, n, c): - return mutate_add_node(rk, new_node_key, n, c, bias_mean, response_mean, act_default, 
agg_default) + return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default) def m_delete_node(rk, n, c): return mutate_delete_node(rk, n, c, input_idx, output_idx) @@ -153,7 +151,7 @@ def mutate(rand_key: Array, @jit def mutate_values(rand_key: Array, nodes: Array, - connections: Array, + cons: Array, bias_mean: float = 0, bias_std: float = 1, bias_mutate_strength: float = 0.5, @@ -180,7 +178,7 @@ def mutate_values(rand_key: Array, Args: rand_key: A random key for generating random values. nodes: A 2D array representing nodes. - connections: A 3D array representing connections. + cons: A 2D array of shape (C, 4) representing connections. bias_mean: Mean of the bias values. bias_std: Standard deviation of the bias values. bias_mutate_strength: Strength of the bias mutation. @@ -211,24 +209,23 @@ def mutate_values(rand_key: Array, bias_mutate_strength, bias_mutate_rate, bias_replace_rate) response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std, response_mutate_strength, response_mutate_rate, response_replace_rate) - weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std, + weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate, weight_replace_rate) act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate) agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate) - # refactor enabled - r = jax.random.uniform(rand_key, connections[1, :, :].shape) - enabled_new = connections[1, :, :] == 1 - enabled_new = jnp.where(r < enabled_reverse_rate, ~enabled_new, enabled_new) - enabled_new = jnp.where(~jnp.isnan(connections[0, :, :]), enabled_new, jnp.nan) + # mutate enabled + r = jax.random.uniform(rand_key, cons[:, 3].shape) + enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3]) + enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan) nodes = nodes.at[:, 1].set(bias_new)
nodes = nodes.at[:, 2].set(response_new) nodes = nodes.at[:, 3].set(act_new) nodes = nodes.at[:, 4].set(agg_new) - connections = connections.at[0, :, :].set(weight_new) - connections = connections.at[1, :, :].set(enabled_new) - return nodes, connections + cons = cons.at[:, 2].set(weight_new) + cons = cons.at[:, 3].set(enabled_new) + return nodes, cons @jit @@ -288,7 +285,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace @jit -def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array, +def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int, default_bias: float = 0, default_response: float = 1, default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]: """ @@ -296,7 +293,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection :param rand_key: :param new_node_key: :param nodes: - :param connections: + :param cons: :param default_bias: :param default_response: :param default_act: @@ -304,44 +301,42 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection :return: """ # randomly choose a connection - from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections) + i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons) - def nothing(): - return nodes, connections + def nothing(): # there is no connection to split + return nodes, cons def successful_add_node(): # disable the connection - new_nodes, new_connections = nodes, connections - new_connections = new_connections.at[1, from_idx, to_idx].set(False) + new_nodes, new_cons = nodes, cons + new_cons = new_cons.at[idx, 3].set(False) # add a new node - new_nodes, new_connections = \ - add_node(new_node_key, new_nodes, new_connections, + new_nodes, new_cons = \ + add_node(new_nodes, new_cons, new_node_key, bias=default_bias, response=default_response, act=default_act, agg=default_agg) - new_idx = fetch_first(new_nodes[:, 0] == 
new_node_key) # add two new connections - weight = new_connections[0, from_idx, to_idx] - new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx, - new_nodes, new_connections, weight=1., enabled=True) - new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx, - new_nodes, new_connections, weight=weight, enabled=True) - return new_nodes, new_connections + w = new_cons[idx, 2] + new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True) + new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True) + return new_nodes, new_cons # if from_idx == I_INT, that means no connection exist, do nothing - nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node) + nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node) - return nodes, connections + return nodes, cons +# TODO: Do we really need to delete a node? @jit -def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, +def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: """ Randomly delete a node. Input and output nodes are not allowed to be deleted.
:param rand_key: :param nodes: - :param connections: + :param cons: :param input_keys: :param output_keys: :return: @@ -351,83 +346,86 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, allow_input_keys=False, allow_output_keys=False) def nothing(): - return nodes, connections + return nodes, cons def successful_delete_node(): # delete the node - aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections) + aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx) - # delete connections - aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan) - aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan) + # delete all connections + aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis], + jnp.nan, aux_cons) - return aux_nodes, aux_connections + return aux_nodes, aux_cons - nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node) + nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node) - return nodes, connections + return nodes, cons @jit -def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, +def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: """ Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks, cycles are not allowed. 
:param rand_key: :param nodes: - :param connections: + :param cons: :param input_keys: :param output_keys: :return: """ # randomly choose two nodes k1, k2 = jax.random.split(rand_key, num=2) - from_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys, + i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys, allow_input_keys=True, allow_output_keys=True) - to_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys, + o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys, allow_input_keys=False, allow_output_keys=True) + con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) + def successful(): - new_nodes, new_connections = add_connection_by_idx(from_idx, to_idx, nodes, connections) - return new_nodes, new_connections + new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True) + return new_nodes, new_cons def already_exist(): - new_connections = connections.at[1, from_idx, to_idx].set(True) - return nodes, new_connections + new_cons = cons.at[con_idx, 3].set(True) + return nodes, new_cons def cycle(): - return nodes, connections + return nodes, cons - is_already_exist = ~jnp.isnan(connections[0, from_idx, to_idx]) - is_cycle = check_cycles(nodes, connections, from_idx, to_idx) + is_already_exist = con_idx != I_INT + unflattened = unflatten_connections(nodes, cons) + is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx) choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) - nodes, connections = jax.lax.switch(choice, [already_exist, cycle, successful]) - return nodes, connections + nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful]) + return nodes, cons @jit -def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): +def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array): """ Randomly delete a connection. 
:param rand_key: :param nodes: - :param connections: + :param cons: :return: """ # randomly choose a connection - from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections) + i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons) def nothing(): - return nodes, connections + return nodes, cons def successfully_delete_connection(): - return delete_connection_by_idx(from_idx, to_idx, nodes, connections) + return delete_connection_by_idx(nodes, cons, idx) - nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection) + nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) - return nodes, connections + return nodes, cons @partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys')) @@ -460,31 +458,20 @@ def choice_node_key(rand_key: Array, nodes: Array, @jit -def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]: +def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]: """ Randomly choose a connection key from the given connections. 
:param rand_key: :param nodes: - :param connection: - :return: from_key, to_key, from_idx, to_idx + :param cons: + :return: i_key, o_key, idx """ - k1, k2 = jax.random.split(rand_key, num=2) + idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0])) + i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan) + o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan) - has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1) - - def nothing(): - return jnp.nan, jnp.nan, I_INT, I_INT - - def has_connection(): - f_idx = fetch_random(k1, has_connections_row) - col = connection[0, f_idx, :] - t_idx = fetch_random(k2, ~jnp.isnan(col)) - f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0] - return f_key, t_key, f_idx, t_idx - - from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing) - return from_key, to_key, from_idx, to_idx + return i_key, o_key, idx @jit diff --git a/algorithms/neat/genome/utils.py b/algorithms/neat/genome/utils.py index 9a72536..19703e3 100644 --- a/algorithms/neat/genome/utils.py +++ b/algorithms/neat/genome/utils.py @@ -3,84 +3,38 @@ from typing import Tuple import jax from jax import numpy as jnp, Array -from jax import jit +from jax import jit, vmap I_INT = jnp.iinfo(jnp.int32).max # infinite int EMPTY_NODE = jnp.full((1, 5), jnp.nan) EMPTY_CON = jnp.full((1, 4), jnp.nan) + @jit -def flatten_connections(keys, connections): +def unflatten_connections(nodes, cons): """ - flatten the (2, N, N) connections to (N * N, 4) - :param keys: - :param connections: - :return: - the first two columns are the index of the node - the 3rd column is the weight, and the 4th column is the enabled status - """ - indices_x, indices_y = jnp.meshgrid(keys, keys, indexing='ij') - indices = jnp.stack((indices_x, indices_y), axis=-1).reshape(-1, 2) - - # make (2, N, N) to (N, N, 2) - con = jnp.transpose(connections, (1, 2, 0)) - # make (N, N, 2) to (N * N, 2) - con = jnp.reshape(con, (-1, 2)) - - con = 
jnp.concatenate((indices, con), axis=1) - return con - - -@partial(jit, static_argnames=['N']) -def unflatten_connections(N, cons): - """ - restore the (N * N, 4) connections to (2, N, N) - :param N: + transform the (C, 4) connections to (2, N, N) :param cons: + :param nodes: :return: """ - cons = cons[:, 2:] # remove the indices - unflatten_cons = jnp.moveaxis(cons.reshape(N, N, 2), -1, 0) - return unflatten_cons + N = nodes.shape[0] + node_keys = nodes[:, 0] + i_keys, o_keys = cons[:, 0], cons[:, 1] + i_idxs = key_to_indices(i_keys, node_keys) + o_idxs = key_to_indices(o_keys, node_keys) + res = jnp.full((2, N, N), jnp.nan) + + # Interesting: JAX clamps out-of-bounds indices when reading from an array, + # but when setting values it silently drops out-of-bounds updates + res = res.at[0, i_idxs, o_idxs].set(cons[:, 2]) + res = res.at[1, i_idxs, o_idxs].set(cons[:, 3]) + return res -@jit -def set_operation_analysis(ar1: Array, ar2: Array) -> Tuple[Array, Array, Array]: - """ - Analyze the intersection and union of two arrays by returning their sorted concatenation indices, - intersection mask, and union mask. - - :param ar1: JAX array of shape (N, M) - First input array. Should have the same shape as ar2. - :param ar2: JAX array of shape (N, M) - Second input array. Should have the same shape as ar1. - :return: tuple of 3 arrays - - sorted_indices: Indices that would sort the concatenation of ar1 and ar2. - - intersect_mask: A boolean array indicating the positions of the common elements between ar1 and ar2 - in the sorted concatenation. - - union_mask: A boolean array indicating the positions of the unique elements in the union of ar1 and ar2 - in the sorted concatenation. 
- - Examples: - a = jnp.array([[1, 2], [3, 4], [5, 6]]) - b = jnp.array([[1, 2], [7, 8], [9, 10]]) - - sorted_indices, intersect_mask, union_mask = set_operation_analysis(a, b) - - sorted_indices -> array([0, 1, 2, 3, 4, 5]) - intersect_mask -> array([True, False, False, False, False, False]) - union_mask -> array([False, True, True, True, True, True]) - """ - ar = jnp.concatenate((ar1, ar2), axis=0) - sorted_indices = jnp.lexsort(ar.T[::-1]) - aux = ar[sorted_indices] - aux = jnp.concatenate((aux, jnp.full((1, ar1.shape[1]), jnp.nan)), axis=0) - nan_mask = jnp.any(jnp.isnan(aux), axis=1) - - fr, sr = aux[:-1], aux[1:] # first row, second row - intersect_mask = jnp.all(fr == sr, axis=1) & ~nan_mask[:-1] - union_mask = jnp.any(fr != sr, axis=1) & ~nan_mask[:-1] - return sorted_indices, intersect_mask, union_mask +@partial(vmap, in_axes=(0, None)) +def key_to_indices(key, keys): + return fetch_first(key == keys) @jit diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 9396556..f199cb7 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -8,6 +8,7 @@ from .species import SpeciesController from .genome import expand, expand_single from .function_factory import FunctionFactory from .genome.genome import count +from .genome.debug.tools import check_array_valid class Pipeline: """ @@ -23,6 +24,7 @@ class Pipeline: self.config = config self.N = config.basic.init_maximum_nodes + self.C = config.basic.init_maximum_connections self.expand_coe = config.basic.expands_coe self.pop_size = config.neat.population.pop_size @@ -57,6 +59,8 @@ class Pipeline: self.update_next_generation(winner_part, loser_part, elite_mask) + # pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx) + self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation, self.o2o_distance, self.o2m_distance) @@ -105,16 +109,25 @@ class Pipeline: npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, 
lpn, lpc) # new pop nodes, new pop connections + + # for i in range(self.pop_size): + # n, c = np.array(npn[i]), np.array(npc[i]) + # check_array_valid(n, c, self.input_idx, self.output_idx) + # mutate new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size) m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes + # for i in range(self.pop_size): + # n, c = np.array(m_npn[i]), np.array(m_npc[i]) + # check_array_valid(n, c, self.input_idx, self.output_idx) + # elitism don't mutate npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc]) self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn) - self.pop_connections = np.where(elite_mask[:, None, None, None], npc, m_npc) + self.pop_connections = np.where(elite_mask[:, None, None], npc, m_npc) def expand(self): """ @@ -128,20 +141,38 @@ class Pipeline: max_node_size = np.max(pop_node_sizes) if max_node_size >= self.N: self.N = int(self.N * self.expand_coe) - print(f"expand to {self.N}!") - self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N) + print(f"node expand to {self.N}!") + self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C) # don't forget to expand representation genome in species for s in self.species_controller.species.values(): - s.representative = expand_single(*s.representative, self.N) + s.representative = expand_single(*s.representative, self.N, self.C) # update functions self.compile_functions(debug=True) + + pop_con_keys = self.pop_connections[:, :, 0] + pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1) + max_con_size = np.max(pop_node_sizes) + if max_con_size >= self.C: + self.C = int(self.C * self.expand_coe) + print(f"connections expand to {self.C}!") + self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C) + + # don't forget to expand representation 
genome in species + for s in self.species_controller.species.values(): + s.representative = expand_single(*s.representative, self.N, self.C) + + # update functions + self.compile_functions(debug=True) + + + def compile_functions(self, debug=False): - self.mutate_func = self.function_factory.create_mutate(self.N) - self.crossover_func = self.function_factory.create_crossover(self.N) - self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N) + self.mutate_func = self.function_factory.create_mutate(self.N, self.C) + self.crossover_func = self.function_factory.create_crossover(self.N, self.C) + self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N, self.C) def default_analysis(self, fitnesses): max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) diff --git a/examples/function_tests.py b/examples/function_tests.py new file mode 100644 index 0000000..8de0e3b --- /dev/null +++ b/examples/function_tests.py @@ -0,0 +1,49 @@ +import jax +import numpy as np +from algorithms.neat.function_factory import FunctionFactory +from algorithms.neat.genome.debug.tools import check_array_valid +from utils import Configer + +from algorithms.neat.genome.crossover import crossover + + +if __name__ == '__main__': + config = Configer.load_config() + function_factory = FunctionFactory(config, debug=True) + initialize_func = function_factory.create_initialize() + pop_nodes, pop_connections, input_idx, output_idx = initialize_func() + mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1]) + crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1]) + key = jax.random.PRNGKey(0) + new_node_idx = 100 + while True: + key, subkey = jax.random.split(key) + mutate_keys = jax.random.split(subkey, len(pop_nodes)) + new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes)) + new_node_idx += len(pop_nodes) + pop_nodes, 
pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes) + pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections]) + # for i in range(len(pop_nodes)): + # check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx) + idx1 = np.random.permutation(len(pop_nodes)) + idx2 = np.random.permutation(len(pop_nodes)) + + n1, c1 = pop_nodes[idx1], pop_connections[idx1] + n2, c2 = pop_nodes[idx2], pop_connections[idx2] + crossover_keys = jax.random.split(subkey, len(pop_nodes)) + + # for idx, (zn1, zc1, zn2, zc2) in enumerate(zip(n1, c1, n2, c2)): + # n, c = crossover(crossover_keys[idx], zn1, zc1, zn2, zc2) + # try: + # check_array_valid(n, c, input_idx, output_idx) + # except AssertionError as e: + # crossover(crossover_keys[idx], zn1, zc1, zn2, zc2) + + pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2) + + for i in range(len(pop_nodes)): + check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx) + + print(new_node_idx) + + diff --git a/examples/jax_playground.py b/examples/jax_playground.py index f052efc..a3bbcbc 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -1,11 +1,3 @@ import numpy as np -# 输入 -a = np.array([1, 2, 3, 4]) -b = np.array([5, 6]) - -# 创建一个网格,其中包含所有可能的组合 -aa, bb = np.meshgrid(a, b) -aa = aa.flatten() -bb = bb.flatten() -print(aa, bb) \ No newline at end of file +print(np.random.permutation(10)) \ No newline at end of file diff --git a/examples/xor.py b/examples/xor.py index 8d629fb..2b65e8f 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -6,8 +6,8 @@ from time_utils import using_cprofile from problems import Sin, Xor, DIY -# @using_cprofile -@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") +@using_cprofile +# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() problem = Xor() diff --git 
a/utils/default_config.json b/utils/default_config.json index 993b080..23359ba 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -4,6 +4,7 @@ "num_outputs": 1, "problem_batch": 4, "init_maximum_nodes": 10, + "init_maximum_connections": 10, "expands_coe": 2, "pre_compile_times": 3, "forward_way": "pop_batch" @@ -13,7 +14,7 @@ "fitness_criterion": "max", "fitness_threshold": -0.001, "generation_limit": 1000, - "pop_size": 1000, + "pop_size": 5000, "reset_on_extinction": "False" }, "gene": { @@ -57,9 +58,9 @@ "compatibility_weight_coefficient": 0.5, "single_structural_mutation": "False", "conn_add_prob": 0.5, - "conn_delete_prob": 0, + "conn_delete_prob": 0.5, "node_add_prob": 0.2, - "node_delete_prob": 0 + "node_delete_prob": 0.2 }, "species": { "compatibility_threshold": 2.5,