From a3b9bca8665d71ef828c247932e03376b9c75964 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sun, 7 May 2023 16:03:52 +0800 Subject: [PATCH] =?UTF-8?q?bug=20down=EF=BC=81=20Here=20it=20can=20solve?= =?UTF-8?q?=20xor=20successfully!?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- algorithms/neat/genome/__init__.py | 2 +- algorithms/neat/genome/crossover.py | 60 +----------------- algorithms/neat/genome/distance.py | 78 +++++++++++++----------- algorithms/neat/genome/forward.py | 33 ++-------- algorithms/neat/genome/genome.py | 3 +- algorithms/neat/genome/graph.py | 24 +------- algorithms/neat/genome/mutate.py | 44 +++++++------ algorithms/neat/genome/numpy/distance.py | 38 ------------ algorithms/neat/genome/utils.py | 3 +- algorithms/neat/pipeline.py | 58 +++++++++--------- algorithms/neat/species.py | 26 +++----- examples/xor.py | 5 +- 12 files changed, 120 insertions(+), 254 deletions(-) diff --git a/algorithms/neat/genome/__init__.py b/algorithms/neat/genome/__init__.py index bb905b5..232ac8c 100644 --- a/algorithms/neat/genome/__init__.py +++ b/algorithms/neat/genome/__init__.py @@ -1,4 +1,4 @@ -from .genome import create_initialize_function, expand, expand_single +from .genome import create_initialize_function, expand, expand_single, pop_analysis from .distance import distance from .mutate import create_mutate_function from .forward import create_forward_function diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py index 2560491..f147c58 100644 --- a/algorithms/neat/genome/crossover.py +++ b/algorithms/neat/genome/crossover.py @@ -5,8 +5,7 @@ import jax from jax import jit, vmap, Array from jax import numpy as jnp -# from .utils import flatten_connections, unflatten_connections -from algorithms.neat.genome.utils import flatten_connections, unflatten_connections +from .utils import flatten_connections, unflatten_connections @vmap @@ -93,59 +92,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: only gene with the same key will be crossover, thus don't need to consider change key """ r = jax.random.uniform(rand_key, shape=g1.shape) - return jnp.where(r > 0.5, g1, g2) - - -if __name__ == '__main__': - import numpy as np - - randkey = jax.random.PRNGKey(40) - nodes1 = np.array([ - [4, 1, 1, 1, 1], - [6, 2, 2, 2, 2], - [1, 3, 3, 3, 3], - [5, 4, 4, 4, 4], - [np.nan, np.nan, np.nan, np.nan, np.nan] - ]) - nodes2 = np.array([ - [4, 1.5, 1.5, 1.5, 1.5], - [7, 3.5, 3.5, 3.5, 3.5], - [5, 4.5, 4.5, 4.5, 4.5], - [np.nan, np.nan, np.nan, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan], - ]) - weights1 = np.array([ - [ - [1, 2, 3, 4., np.nan], - [5, np.nan, 7, 8, np.nan], - [9, 10, 11, 12, np.nan], - [13, 14, 15, 16, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan] - ], - [ - [0, 1, 0, 1, np.nan], - [0, np.nan, 0, 1, np.nan], - [0, 1, 0, 1, np.nan], - [0, 1, 0, 1, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan] - ] - ]) - weights2 = np.array([ - [ - [1.5, 2.5, 3.5, np.nan, np.nan], - [3.5, 4.5, 5.5, np.nan, np.nan], - [6.5, 7.5, 8.5, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan] - ], - [ - [1, 0, 1, np.nan, np.nan], - [1, 0, 1, np.nan, np.nan], - [1, 0, 1, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan] - ] - ]) - - res = crossover(randkey, nodes1, weights1, nodes2, weights2) - print(*res, sep='\n') + return jnp.where(r > 0.5, g1, g2) \ No newline at end of file diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py index 898d867..b9dd00d 100644 --- a/algorithms/neat/genome/distance.py +++ b/algorithms/neat/genome/distance.py @@ -1,9 +1,7 @@ -from functools import partial - from jax import jit, vmap, Array from jax import numpy as jnp -from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis +from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON @jit @@ -14,55 +12,65 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable] """ - node_distance = gene_distance(nodes1, nodes2, 'node') + nd = node_distance(nodes1, nodes2) # node distance # refactor connections keys1, keys2 = nodes1[:, 0], nodes2[:, 0] cons1 = flatten_connections(keys1, connections1) cons2 = flatten_connections(keys2, connections2) - - connection_distance = gene_distance(cons1, cons2, 'connection') - return node_distance + connection_distance + cd = connection_distance(cons1, cons2) # connection distance + return nd + cd -@partial(jit, static_argnames=["gene_type"]) -def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): - if gene_type == 'node': - keys1, keys2 = ar1[:, :1], ar2[:, :1] - else: # connection - keys1, keys2 = ar1[:, :2], ar2[:, :2] +@jit +def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5): + node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) + node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) + max_cnt = jnp.maximum(node_cnt1, node_cnt2) - n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2) - nodes = jnp.concatenate((ar1, ar2), axis=0) - sorted_nodes = nodes[n_sorted_indices] + nodes = jnp.concatenate((nodes1, nodes2), axis=0) + keys = nodes[:, 0] + sorted_indices = jnp.argsort(keys, axis=0) + nodes = nodes[sorted_indices] + nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end + fr, sr = nodes[:-1], nodes[1:] # first row, second row + nan_mask = jnp.isnan(nodes[:, 0]) - if gene_type == 'node': - node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 1:]), axis=1) - else: - node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 2:]), axis=1) + intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1] - n_intersect_mask = n_intersect_mask & node_exist_mask - n_union_mask = n_union_mask & node_exist_mask + non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) + nd = batch_homologous_node_distance(fr, sr) + nd = jnp.where(jnp.isnan(nd), 0, nd) + homologous_distance = jnp.sum(nd * intersect_mask) - fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:] + val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe + return jnp.where(max_cnt == 0, 0, val / max_cnt) - non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask) - if gene_type == 'node': - node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes) - else: # connection - node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes) - node_distance = jnp.where(jnp.isnan(node_distance), 0, node_distance) - homologous_distance = jnp.sum(node_distance * n_intersect_mask[:-1]) +@jit +def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5): + con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 2])) # weight is not nan, means the connection exists + con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 2])) + max_cnt = jnp.maximum(con_cnt1, con_cnt2) - gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1)) - gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1)) - max_cnt = jnp.maximum(gene_cnt1, gene_cnt2) + cons = jnp.concatenate((cons1, cons2), axis=0) + keys = cons[:, :2] + sorted_indices = jnp.lexsort(keys.T[::-1]) + cons = cons[sorted_indices] + cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end + fr, sr = cons[:-1], cons[1:] # first row, second row + + # both genome has such connection + intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 2]) & ~jnp.isnan(sr[:, 2]) + + non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) + cd = batch_homologous_connection_distance(fr, sr) + cd = jnp.where(jnp.isnan(cd), 0, cd) + homologous_distance = jnp.sum(cd * intersect_mask) val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe - return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene - + return jnp.where(max_cnt == 0, 0, val / max_cnt) @vmap def batch_homologous_node_distance(b_n1, b_n2): diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py index 0783201..013d556 100644 --- a/algorithms/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -7,12 +7,12 @@ from numpy.typing import NDArray from .aggregations import agg from .activations import act -from .graph import topological_sort, batch_topological_sort, topological_sort_debug +from .graph import topological_sort, batch_topological_sort from .utils import I_INT def create_forward_function(nodes: NDArray, connections: NDArray, - N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False): + N: int, input_idx: NDArray, output_idx: NDArray, batch: bool): """ create forward function for different situations @@ -26,11 +26,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray, :return: """ - if debug: - cal_seqs = topological_sort_debug(nodes, connections) - return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx, - cal_seqs, nodes, connections) - if nodes.ndim == 2: # single genome cal_seqs = topological_sort(nodes, connections) if not batch: @@ -51,7 +46,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray, raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}") -# @partial(jit, static_argnames=['N', 'input_idx', 'output_idx']) @partial(jit, static_argnames=['N']) def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, cal_seqs: Array, nodes: Array, connections: Array) -> Array: @@ -79,38 +73,19 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, z = z * nodes[i, 2] + nodes[i, 1] z = act(nodes[i, 3], z) - # for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals - new_vals = jnp.where(jnp.isnan(z), carry, carry.at[i].set(z)) + new_vals = carry.at[i].set(z) return new_vals def miss(): return carry - return jax.lax.cond(i == I_INT, miss, hit), None + return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs) return vals[output_idx] -def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections): - ini_vals = jnp.full((N,), jnp.nan) - ini_vals = ini_vals.at[input_idx].set(inputs) - vals = ini_vals - for i in cal_seqs: - if i == I_INT: - break - ins = vals * connections[0, :, i] - z = agg(nodes[i, 4], ins) - z = z * nodes[i, 2] + nodes[i, 1] - z = act(nodes[i, 3], z) - - # for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals - vals = jnp.where(jnp.isnan(z), vals, vals.at[i].set(z)) - - return vals[output_idx] - - @partial(vmap, in_axes=(0, None, None, None, None, None, None)) def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, cal_seqs: Array, nodes: Array, connections: Array) -> Array: diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py index 36fb8fa..8479b07 100644 --- a/algorithms/neat/genome/genome.py +++ b/algorithms/neat/genome/genome.py @@ -208,7 +208,6 @@ def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys): return res - @jit def add_node(new_node_key: int, nodes: Array, connections: Array, bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]: @@ -247,7 +246,7 @@ def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Arra @jit def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array, - weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]: + weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]: """ add a new connection to the genome. """ diff --git a/algorithms/neat/genome/graph.py b/algorithms/neat/genome/graph.py index 44752a3..85a4410 100644 --- a/algorithms/neat/genome/graph.py +++ b/algorithms/neat/genome/graph.py @@ -74,26 +74,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array: return res -# @jit -def topological_sort_debug(nodes: Array, connections: Array) -> Array: - connections_enable = connections[1, :, :] == 1 - in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0)) - res = jnp.full(in_degree.shape, I_INT) - idx = 0 - - for _ in range(in_degree.shape[0]): - i = fetch_first(in_degree == 0.) - if i == I_INT: - break - res = res.at[idx].set(i) - idx += 1 - in_degree = in_degree.at[i].set(-1) - children = connections_enable[i, :] - in_degree = jnp.where(children, in_degree - 1, in_degree) - - return res - - @vmap def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array: """ @@ -102,7 +82,7 @@ def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array: :param pop_connections: :return: """ - return topological_sort(nodes, connections) + return topological_sort(pop_nodes, pop_connections) @jit @@ -148,7 +128,6 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra check_cycles(nodes, connections, 0, 3) -> False check_cycles(nodes, connections, 1, 0) -> False """ - # connections_enable = connections[0, :, :] == 1 connections_enable = ~jnp.isnan(connections[0, :, :]) connections_enable = connections_enable.at[from_idx, to_idx].set(True) @@ -191,7 +170,6 @@ if __name__ == '__main__': ] ) - print(topological_sort_debug(nodes, connections)) print(topological_sort(nodes, connections)) print(check_cycles(nodes, connections, 3, 2)) diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index 12227c9..d92df28 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -403,13 +403,13 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection # add two new connections weight = new_connections[0, from_idx, to_idx] new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx, - new_nodes, new_connections, weight=0, enabled=True) + new_nodes, new_connections, weight=1., enabled=True) new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx, new_nodes, new_connections, weight=weight, enabled=True) return new_nodes, new_connections # if from_idx == I_INT, that means no connection exist, do nothing - nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successful_add_node) + nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node) return nodes, connections @@ -430,16 +430,20 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys, allow_input_keys=False, allow_output_keys=False) - # delete the node - aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections) + def nothing(): + return nodes, connections - # delete connections - aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan) - aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan) + def successful_delete_node(): + # delete the node + aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections) - # check node_key valid - nodes = jnp.where(jnp.isnan(node_key), nodes, aux_nodes) # if node_key is nan, do not delete the node - connections = jnp.where(jnp.isnan(node_key), connections, aux_connections) + # delete connections + aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan) + aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan) + + return aux_nodes, aux_connections + + nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node) return nodes, connections @@ -501,7 +505,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): def successfully_delete_connection(): return delete_connection_by_idx(from_idx, to_idx, nodes, connections) - nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successfully_delete_connection) + nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection) return nodes, connections @@ -544,16 +548,22 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T :param connection: :return: from_key, to_key, from_idx, to_idx """ + k1, k2 = jax.random.split(rand_key, num=2) + has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1) - from_idx = fetch_random(k1, has_connections_row) - col = connection[0, from_idx, :] - to_idx = fetch_random(k2, ~jnp.isnan(col)) - from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0] - from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan) - to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan) + def nothing(): + return jnp.nan, jnp.nan, I_INT, I_INT + def has_connection(): + f_idx = fetch_random(k1, has_connections_row) + col = connection[0, f_idx, :] + t_idx = fetch_random(k2, ~jnp.isnan(col)) + f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0] + return f_key, t_key, f_idx, t_idx + + from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing) return from_key, to_key, from_idx, to_idx diff --git a/algorithms/neat/genome/numpy/distance.py b/algorithms/neat/genome/numpy/distance.py index 8677aff..9c9bc4f 100644 --- a/algorithms/neat/genome/numpy/distance.py +++ b/algorithms/neat/genome/numpy/distance.py @@ -82,44 +82,6 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5): return val / max_cnt -def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): - if gene_type == 'node': - keys1, keys2 = ar1[:, :1], ar2[:, :1] - else: # connection - keys1, keys2 = ar1[:, :2], ar2[:, :2] - - n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2) - nodes = np.concatenate((ar1, ar2), axis=0) - sorted_nodes = nodes[n_sorted_indices] - - if gene_type == 'node': - node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 1:]), axis=1) - else: - node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 2:]), axis=1) - - n_intersect_mask = n_intersect_mask & node_exist_mask - n_union_mask = n_union_mask & node_exist_mask - - fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:] - - non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask) - if gene_type == 'node': - node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes) - else: # connection - node_distance = batch_homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes) - - node_distance = np.where(np.isnan(node_distance), 0, node_distance) - homologous_distance = np.sum(node_distance * n_intersect_mask[:-1]) - - gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1)) - gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1)) - max_cnt = np.maximum(gene_cnt1, gene_cnt2) - - val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe - - return np.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene - - def batch_homologous_node_distance(b_n1, b_n2): res = [] for n1, n2 in zip(b_n1, b_n2): diff --git a/algorithms/neat/genome/utils.py b/algorithms/neat/genome/utils.py index ee883e0..9a72536 100644 --- a/algorithms/neat/genome/utils.py +++ b/algorithms/neat/genome/utils.py @@ -6,7 +6,8 @@ from jax import numpy as jnp, Array from jax import jit I_INT = jnp.iinfo(jnp.int32).max # infinite int - +EMPTY_NODE = jnp.full((1, 5), jnp.nan) +EMPTY_CON = jnp.full((1, 4), jnp.nan) @jit def flatten_connections(keys, connections): diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 6af06b0..157344e 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -1,12 +1,12 @@ from typing import List, Union, Tuple, Callable import time -import numpy as np +import jax from .species import SpeciesController -from .genome.numpy import create_initialize_function, create_mutate_function, create_forward_function -from .genome.numpy import batch_crossover -from .genome.numpy import expand, expand_single, pop_analysis +from .genome import create_initialize_function, create_mutate_function, create_forward_function +from .genome import batch_crossover +from .genome import expand, expand_single, pop_analysis from .genome.origin_neat import * @@ -19,7 +19,8 @@ class Pipeline: Neat algorithm pipeline. """ - def __init__(self, config): + def __init__(self, config, seed=42): + self.randkey = jax.random.PRNGKey(seed) self.config = config self.N = config.basic.init_maximum_nodes @@ -69,23 +70,23 @@ class Pipeline: self.update_next_generation(crossover_pair) - analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx) + # analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx) - try: - for nodes, connections in zip(self.pop_nodes, self.pop_connections): - g = array2object(self.config, nodes, connections) - print(g) - net = FeedForwardNetwork.create(g) - real_out = [net.activate(x) for x in xor_inputs] - func = create_forward_function(nodes, connections, self.N, self.input_idx, self.output_idx, batch=True) - out = func(xor_inputs) - real_out = np.array(real_out) - out = np.array(out) - print(real_out, out) - assert np.allclose(real_out, out) - except AssertionError: - np.save("err_nodes.npy", self.pop_nodes) - np.save("err_connections.npy", self.pop_connections) + # try: + # for nodes, connections in zip(self.pop_nodes, self.pop_connections): + # g = array2object(self.config, nodes, connections) + # print(g) + # net = FeedForwardNetwork.create(g) + # real_out = [net.activate(x) for x in xor_inputs] + # func = create_forward_function(nodes, connections, self.N, self.input_idx, self.output_idx, batch=True) + # out = func(xor_inputs) + # real_out = np.array(real_out) + # out = np.array(out) + # print(real_out, out) + # assert np.allclose(real_out, out) + # except AssertionError: + # np.save("err_nodes.npy", self.pop_nodes) + # np.save("err_connections.npy", self.pop_connections) # print(g) @@ -93,7 +94,6 @@ class Pipeline: self.expand() - def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): for _ in range(self.config.neat.population.generation_limit): forward_func = self.ask(batch=True) @@ -109,7 +109,6 @@ class Pipeline: self.tell(fitnesses) print("Generation limit reached!") - def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None: """ create the next generation @@ -117,6 +116,7 @@ class Pipeline: """ assert self.pop_nodes.shape[0] == self.pop_size + k1, k2, self.randkey = jax.random.split(self.randkey, 3) # crossover # prepare elitism mask and crossover pair @@ -127,19 +127,20 @@ class Pipeline: crossover_pair[i] = (pair, pair) crossover_pair = np.array(crossover_pair) + crossover_rand_keys = jax.random.split(k1, self.pop_size) + # batch crossover wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections - # npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections - npn, npc = batch_crossover(wpn, wpc, lpn, lpc) - # print(pop_analysis(npn, npc, self.input_idx, self.output_idx)) + npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections # mutate + mutate_rand_keys = jax.random.split(k2, self.pop_size) new_node_keys = np.array(self.fetch_new_node_keys()) - m_npn, m_npc = self.mutate_func(npn, npc, new_node_keys) # mutate_new_pop_nodes + m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes # elitism don't mutate # (pop_size, ) to (pop_size, 1, 1) @@ -156,7 +157,6 @@ class Pipeline: unused.append(key) self.new_node_keys_pool = unused + self.new_node_keys_pool - def expand(self): """ Expand the population if needed. @@ -176,7 +176,6 @@ class Pipeline: for s in self.species_controller.species.values(): s.representative = expand_single(*s.representative, self.N) - def fetch_new_node_keys(self): # if remain unused keys are not enough, create new keys if len(self.new_node_keys_pool) < self.pop_size: @@ -189,7 +188,6 @@ class Pipeline: self.new_node_keys_pool = self.new_node_keys_pool[self.pop_size:] return res - def default_analysis(self, fitnesses): max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) species_sizes = [len(s.members) for s in self.species_controller.species.values()] diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index c58654b..b224406 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -1,9 +1,11 @@ from typing import List, Tuple, Dict, Union from itertools import count +import jax import numpy as np from numpy.typing import NDArray -from .genome.numpy import distance + +from .genome import distance class Species(object): @@ -45,6 +47,9 @@ class SpeciesController: self.species_idxer = count(0) self.species: Dict[int, Species] = {} # species_id -> species + self.distance = distance + self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0)) + def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None: """ :param pop_nodes: @@ -62,7 +67,7 @@ class SpeciesController: for sid, species in self.species.items(): # calculate the distance between the representative and the population r_nodes, r_connections = species.representative - distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections) + distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections) min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance new_representatives[sid] = min_idx @@ -75,7 +80,7 @@ class SpeciesController: if previous_species_list: # exist previous species rid_list = [new_representatives[sid] for sid in previous_species_list] res_pop_distance = [ - o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) + self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) for rid in rid_list ] @@ -102,7 +107,7 @@ class SpeciesController: sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) distances = [ - distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) + self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) for r in rid ] distances = np.array(distances) @@ -314,16 +319,3 @@ def sort_element_with_fitnesses(members: List[int], fitnesses: List[float]) \ sorted_members = [item[0] for item in sorted_combined] sorted_fitnesses = [item[1] for item in sorted_combined] return sorted_members, sorted_fitnesses - - -def o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections): - distances = [] - for nodes, connections in zip(pop_nodes, pop_connections): - d = distance(r_nodes, r_connections, nodes, connections) - if d < 0: - d = distance(r_nodes, r_connections, nodes, connections) - print(d) - assert False - distances.append(d) - distances = np.stack(distances, axis=0) - return distances diff --git a/examples/xor.py b/examples/xor.py index 9f04f8f..cd4875c 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -17,7 +17,7 @@ def evaluate(forward_func: Callable) -> List[float]: :return: """ outs = forward_func(xor_inputs) - fitnesses = 4 - np.sum(np.abs(outs - xor_outputs), axis=(1, 2)) + fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) # print(fitnesses) return fitnesses.tolist() # returns a list @@ -26,7 +26,7 @@ def evaluate(forward_func: Callable) -> List[float]: @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() - pipeline = Pipeline(config) + pipeline = Pipeline(config, seed=123123) pipeline.auto_run(evaluate) # for _ in range(100): @@ -38,5 +38,4 @@ def main(): if __name__ == '__main__': - np.random.seed(63124326) main()