From 14fed83193c7648c2c513d78bc4edd368e458b4c Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sat, 6 May 2023 18:33:30 +0800 Subject: [PATCH] debugging --- algorithms/neat/genome/crossover.py | 4 +- algorithms/neat/genome/distance.py | 13 +++- algorithms/neat/genome/forward.py | 2 +- algorithms/neat/genome/genome.py | 93 ++++++++++++++++++++++++++--- algorithms/neat/genome/graph.py | 4 +- algorithms/neat/pipeline.py | 37 +++++++++++- algorithms/neat/species.py | 36 ++++++++++- examples/error_forward_fix.py | 24 ++++++++ examples/fix_too_large_distance.py | 11 ++++ utils/default_config.json | 4 +- 10 files changed, 206 insertions(+), 22 deletions(-) create mode 100644 examples/error_forward_fix.py create mode 100644 examples/fix_too_large_distance.py diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py index 5e130d9..2560491 100644 --- a/algorithms/neat/genome/crossover.py +++ b/algorithms/neat/genome/crossover.py @@ -42,14 +42,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] nodes2 = align_array(keys1, keys2, nodes2, 'node') - new_nodes = jnp.where(jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) + new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) # crossover connections cons1 = flatten_connections(keys1, connections1) cons2 = flatten_connections(keys2, connections2) con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] cons2 = align_array(con_keys1, con_keys2, cons2, 'connection') - new_cons = jnp.where(jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2)) + new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2)) new_cons = unflatten_connections(len(keys1), new_cons) return new_nodes, new_cons diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py index fd65a71..2cf8cdb
100644 --- a/algorithms/neat/genome/distance.py +++ b/algorithms/neat/genome/distance.py @@ -35,6 +35,15 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2) nodes = jnp.concatenate((ar1, ar2), axis=0) sorted_nodes = nodes[n_sorted_indices] + + if gene_type == 'node': + node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 1:]), axis=1) + else: + node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 2:]), axis=1) + + n_intersect_mask = n_intersect_mask & node_exist_mask + n_union_mask = n_union_mask & node_exist_mask + fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:] non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask) @@ -48,9 +57,11 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1)) gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1)) + max_cnt = jnp.maximum(gene_cnt1, gene_cnt2) val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe - return val / jnp.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2) + + return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case where both genomes have no genes @partial(vmap, in_axes=(0, 0)) diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py index 86e9116..0783201 100644 --- a/algorithms/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -27,7 +27,7 @@ def create_forward_function(nodes: NDArray, connections: NDArray, """ if debug: - cal_seqs = topological_sort(nodes, connections) + cal_seqs = topological_sort_debug(nodes, connections) return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx, cal_seqs, nodes, connections) diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py index 5e7674b..36fb8fa 100644 --- a/algorithms/neat/genome/genome.py +++
b/algorithms/neat/genome/genome.py @@ -12,9 +12,10 @@ status. Empty nodes or connections are represented using np.nan. """ -from typing import Tuple +from typing import Tuple, Dict from functools import partial +import jax import numpy as np from numpy.typing import NDArray from jax import numpy as jnp @@ -131,6 +132,83 @@ def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDA return new_nodes, new_connections + +def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \ + Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]: + """ + Convert a genome from array to dict. + :param nodes: (N, 5) + :param connections: (2, N, N) + :param output_keys: + :param input_keys: + :return: nodes_dict[key: (bias, response, act, agg)], connections_dict[(f_key, t_key): (weight, enabled)] + """ + # update nodes_dict + try: + nodes_dict = {} + idx2key = {} + for i, node in enumerate(nodes): + if np.isnan(node[0]): + continue + key = int(node[0]) + assert key not in nodes_dict, f"Duplicate node key: {key}!" + + bias = node[1] if not np.isnan(node[1]) else None + response = node[2] if not np.isnan(node[2]) else None + act = node[3] if not np.isnan(node[3]) else None + agg = node[4] if not np.isnan(node[4]) else None + nodes_dict[key] = (bias, response, act, agg) + idx2key[i] = key + + # check nodes_dict + for i in input_keys: + assert i in nodes_dict, f"Input node {i} not found in nodes_dict!" + bias, response, act, agg = nodes_dict[i] + assert bias is None and response is None and act is None and agg is None, \ + f"Input node {i} must has None bias, response, act, or agg!" + + for o in output_keys: + assert o in nodes_dict, f"Output node {o} not found in nodes_dict!" 
+ + for k, v in nodes_dict.items(): + if k not in input_keys: + bias, response, act, agg = v + assert bias is not None and response is not None and act is not None and agg is not None, \ + f"Normal node {k} must has non-None bias, response, act, or agg!" + + # update connections + connections_dict = {} + for i in range(connections.shape[1]): + for j in range(connections.shape[2]): + if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]): + continue + assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!" + assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!" + key = (idx2key[i], idx2key[j]) + + weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None + enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None + + assert weight is not None, f"Connection {key} must has non-None weight!" + assert enabled is not None, f"Connection {key} must has non-None enabled!" + connections_dict[key] = (weight, enabled) + + return nodes_dict, connections_dict + except AssertionError: + print(nodes) + print(connections) + raise AssertionError + + +def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys): + pop_nodes, pop_connections = jax.device_get((pop_nodes, pop_connections)) + res = [] + for nodes, connections in zip(pop_nodes, pop_connections): + res.append(analysis(nodes, connections, input_keys, output_keys)) + return res + + + @jit def add_node(new_node_key: int, nodes: Array, connections: Array, bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]: @@ -158,11 +236,12 @@ def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Arra """ delete a node from the genome. only delete the node, regardless of connections. 
""" - node_keys = nodes[:, 0] + # node_keys = nodes[:, 0] + nodes = nodes.at[idx].set(EMPTY_NODE) # move the last node to the deleted node's position - last_real_idx = fetch_last(~jnp.isnan(node_keys)) - nodes = nodes.at[idx].set(nodes[last_real_idx]) - nodes = nodes.at[last_real_idx].set(EMPTY_NODE) + # last_real_idx = fetch_last(~jnp.isnan(node_keys)) + # nodes = nodes.at[idx].set(nodes[last_real_idx]) + # nodes = nodes.at[last_real_idx].set(EMPTY_NODE) return nodes, connections @@ -206,7 +285,3 @@ def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connectio """ connections = connections.at[:, from_idx, to_idx].set(np.nan) return nodes, connections - -# if __name__ == '__main__': -# pop_nodes, pop_connections, input_keys, output_keys = initialize_genomes(100, 5, 2, 1) -# print(pop_nodes, pop_connections) diff --git a/algorithms/neat/genome/graph.py b/algorithms/neat/genome/graph.py index 55f8c4d..8f753dd 100644 --- a/algorithms/neat/genome/graph.py +++ b/algorithms/neat/genome/graph.py @@ -148,7 +148,9 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra check_cycles(nodes, connections, 0, 3) -> False check_cycles(nodes, connections, 1, 0) -> False """ - connections_enable = connections[1, :, :] == 1 + # connections_enable = connections[0, :, :] == 1 + connections_enable = ~jnp.isnan(connections[0, :, :]) + connections_enable = connections_enable.at[from_idx, to_idx].set(True) nodes_visited = jnp.full(nodes.shape[0], False) nodes_visited = nodes_visited.at[to_idx].set(True) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 07bb1bb..8dd0e21 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -7,7 +7,9 @@ import numpy as np from .species import SpeciesController from .genome import create_initialize_function, create_mutate_function, create_forward_function from .genome import batch_crossover +from .genome.crossover import crossover from .genome import expand, 
expand_single +from algorithms.neat.genome.genome import pop_analysis, analysis class Pipeline: @@ -51,12 +53,22 @@ class Pipeline: def tell(self, fitnesses): self.generation += 1 + for i, f in enumerate(fitnesses): + if np.isnan(f): + print("fuck!!!!!!!!!!!!!!") + error_nodes, error_connections = self.pop_nodes[i], self.pop_connections[i] + np.save('error_nodes.npy', error_nodes) + np.save('error_connections.npy', error_connections) + assert False + self.species_controller.update_species_fitnesses(fitnesses) crossover_pair = self.species_controller.reproduce(self.generation) self.update_next_generation(crossover_pair) + # print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)) + self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation) self.expand() @@ -103,16 +115,22 @@ class Pipeline: crossover_rand_keys = jax.random.split(k1, self.pop_size) # npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc) + # print(pop_analysis(npn, npc, self.input_idx, self.output_idx)) + # mutate new_node_keys = np.array(self.fetch_new_node_keys()) mutate_rand_keys = jax.random.split(k2, self.pop_size) - m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) + m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc) + + # print(pop_analysis(m_npn, m_npc, self.input_idx, self.output_idx)) + # elitism don't mutate # (pop_size, ) to (pop_size, 1, 1) self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) # (pop_size, ) to (pop_size, 1, 1, 1) self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) + # print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)) # recycle unused node keys unused = [] @@ -138,8 +156,8 @@ class 
Pipeline: self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N) # don't forget to expand representation genome in species - for s in self.species_controller.species: - s.representative = expand(*s.representative, self.N) + for s in self.species_controller.species.values(): + s.representative = expand_single(*s.representative, self.N) def fetch_new_node_keys(self): # if remain unused keys are not enough, create new keys @@ -164,6 +182,19 @@ class Pipeline: print(f"Generation: {self.generation}", f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}") + # def crossover_wrapper(self, crossover_rand_keys, wpn, wpc, lpn, lpc): + # pop_nodes, pop_connections = [], [] + # for randkey, wn, wc, ln, lc in zip(crossover_rand_keys, wpn, wpc, lpn, lpc): + # new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc) + # pop_nodes.append(new_nodes) + # pop_connections.append(new_connections) + # try: + # print(analysis(new_nodes, new_connections, self.input_idx, self.output_idx)) + # except AssertionError: + # new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc) + # return np.stack(pop_nodes), np.stack(pop_connections) + + # return batch_crossover(*args) def crossover_wrapper(*args): return batch_crossover(*args) diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index a56544e..d513c82 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -67,7 +67,7 @@ class SpeciesController: for sid, species in self.species.items(): # calculate the distance between the representative and the population r_nodes, r_connections = species.representative - distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections) + distances = self.o2m_distance_wrapper(r_nodes, r_connections, pop_nodes, pop_connections) distances = jax.device_get(distances) # fetch the data from gpu min_idx = find_min_with_mask(distances, unspeciated) # find the 
min un-specified distance @@ -82,7 +82,7 @@ class SpeciesController: rid_list = [new_representatives[sid] for sid in previous_species_list] res_pop_distance = [ jax.device_get( - self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) + self.o2m_distance_wrapper(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) ) for rid in rid_list ] @@ -110,7 +110,7 @@ class SpeciesController: sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) distances = [ - self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) + self.o2o_distance_wrapper(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) for r in rid ] distances = np.array(distances) @@ -267,6 +267,36 @@ class SpeciesController: return crossover_pair + def o2m_distance_wrapper(self, r_nodes, r_connections, pop_nodes, pop_connections): + # distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections) + # for d in distances: + # if np.isnan(d): + # print("fuck!!!!!!!!!!!!!!") + # print(distances) + # assert False + # return distances + distances = [] + for nodes, connections in zip(pop_nodes, pop_connections): + d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections) + if np.isnan(d) or d > 20: + np.save("too_large_distance_r_nodes.npy", r_nodes) + np.save("too_large_distance_r_connections.npy", r_connections) + np.save("too_large_distance_nodes", nodes) + np.save("too_large_distance_connections.npy", connections) + d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections) + assert False + distances.append(d) + distances = np.stack(distances, axis=0) + # print(distances) + return distances + + def o2o_distance_wrapper(self, *keys): + d = self.o2o_distance_func(*keys) + if np.isnan(d): + print("fuck!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + assert False + return d + def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size): """ diff --git 
a/examples/error_forward_fix.py b/examples/error_forward_fix.py new file mode 100644 index 0000000..4193a9f --- /dev/null +++ b/examples/error_forward_fix.py @@ -0,0 +1,24 @@ +import numpy as np +from jax import numpy as jnp + +from algorithms.neat.genome.genome import analysis +from algorithms.neat.genome import create_forward_function + + +error_nodes = np.load('error_nodes.npy') +error_connections = np.load('error_connections.npy') + +node_dict, connection_dict = analysis(error_nodes, error_connections, np.array([0, 1]), np.array([2, ])) +print(node_dict, connection_dict, sep='\n') + +N = error_nodes.shape[0] + +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) + +func = create_forward_function(error_nodes, error_connections, N, jnp.array([0, 1]), jnp.array([2, ]), + batch=True, debug=True) +out = func(np.array([1, 0])) + +print(error_nodes) +print(error_connections) +print(out) \ No newline at end of file diff --git a/examples/fix_too_large_distance.py b/examples/fix_too_large_distance.py new file mode 100644 index 0000000..e1a4b65 --- /dev/null +++ b/examples/fix_too_large_distance.py @@ -0,0 +1,11 @@ +import numpy as np +from algorithms.neat.genome import distance + +r_nodes = np.load('too_large_distance_r_nodes.npy') +r_connections = np.load('too_large_distance_r_connections.npy') +nodes = np.load('too_large_distance_nodes.npy') +connections = np.load('too_large_distance_connections.npy') + +d1 = distance(r_nodes, r_connections, nodes, connections) +d2 = distance(nodes, connections, r_nodes, r_connections) +print(d1, d2) \ No newline at end of file diff --git a/utils/default_config.json b/utils/default_config.json index a05fbff..59ff14c 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -10,7 +10,7 @@ "fitness_criterion": "max", "fitness_threshold": 3, "generation_limit": 100, - "pop_size": 20, + "pop_size": 100, "reset_on_extinction": "False" }, "gene": { @@ -73,7 +73,7 @@ "node_delete_prob": 0.2 }, "species": { - 
"compatibility_threshold": 8, + "compatibility_threshold": 3.5, "species_fitness_func": "max", "max_stagnation": 20, "species_elitism": 2,