diff --git a/algorithms/neat/__init__.py b/algorithms/neat/__init__.py index 71ad225..445a48a 100644 --- a/algorithms/neat/__init__.py +++ b/algorithms/neat/__init__.py @@ -6,3 +6,5 @@ from .population import update_species, create_next_generation, speciate from .genome.activations import act_name2func from .genome.aggregations import agg_name2func + +from .visualize import Genome diff --git a/algorithms/neat/genome/debug/__init__.py b/algorithms/neat/genome/debug/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/algorithms/neat/genome/debug/tools.py b/algorithms/neat/genome/debug/tools.py deleted file mode 100644 index 34de771..0000000 --- a/algorithms/neat/genome/debug/tools.py +++ /dev/null @@ -1,88 +0,0 @@ -from collections import defaultdict - -import numpy as np - - -def check_array_valid(nodes, cons, input_keys, output_keys): - nodes_dict, cons_dict = array2object(nodes, cons, input_keys, output_keys) - # assert is_DAG(cons_dict.keys()), "The genome is not a DAG!" - - -def array2object(nodes, cons, input_keys, output_keys): - """ - Convert a genome from array to dict. - :param nodes: (N, 5) - :param cons: (C, 4) - :param output_keys: - :param input_keys: - :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)] - """ - # update nodes_dict - nodes_dict = {} - for i, node in enumerate(nodes): - if np.isnan(node[0]): - continue - key = int(node[0]) - assert key not in nodes_dict, f"Duplicate node key: {key}!" - - if key in input_keys: - assert np.all(np.isnan(node[1:])), f"Input node {key} must has None bias, response, act, or agg!" - nodes_dict[key] = (None,) * 4 - else: - assert np.all(~np.isnan(node[1:])), f"Normal node {key} must has non-None bias, response, act, or agg!" - bias = node[1] - response = node[2] - act = node[3] - agg = node[4] - nodes_dict[key] = (bias, response, act, agg) - - # check nodes_dict - for i in input_keys: - assert i in nodes_dict, f"Input node {i} not found in nodes_dict!" - - for o in output_keys: - assert o in nodes_dict, f"Output node {o} not found in nodes_dict!" - - # update connections - cons_dict = {} - for i, con in enumerate(cons): - if np.all(np.isnan(con)): - pass - elif np.all(~np.isnan(con)): - i_key = int(con[0]) - o_key = int(con[1]) - if (i_key, o_key) in cons_dict: - assert False, f"Duplicate connection: {(i_key, o_key)}!" - assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!" - assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!" - weight = con[2] - enabled = (con[3] == 1) - cons_dict[(i_key, o_key)] = (weight, enabled) - else: - assert False, f"Connection {i} must has all None or all non-None!" - - return nodes_dict, cons_dict - - -def is_DAG(edges): - all_nodes = set() - for a, b in edges: - if a == b: # cycle - return False - all_nodes.union({a, b}) - - for node in all_nodes: - visited = {n: False for n in all_nodes} - def dfs(n): - if visited[n]: - return False - visited[n] = True - for a, b in edges: - if a == n: - if not dfs(b): - return False - return True - - if not dfs(node): - return False - return True diff --git a/algorithms/neat/visualize.py b/algorithms/neat/visualize.py new file mode 100644 index 0000000..15e5dec --- /dev/null +++ b/algorithms/neat/visualize.py @@ -0,0 +1,112 @@ +import jax +import numpy as np + + +class Genome: + def __init__(self, nodes, cons, config): + self.config = config + self.nodes, self.cons = array2object(nodes, cons, config) + if config['renumber_nodes']: + self.renumber() + + def __repr__(self): + return f'Genome(\n' \ + f'\tinput_keys: {self.config["input_idx"]}, \n' \ + f'\toutput_keys: {self.config["output_idx"]}, \n' \ + f'\tnodes: \n\t\t' \ + f'{self.repr_nodes()} \n' \ + f'\tconnections: \n\t\t' \ + f'{self.repr_conns()} \n)' + + def repr_nodes(self): + nodes_info = [] + for key, value in self.nodes.items(): + bias, response, act, agg = value + act_func = self.config['activation_option_names'][int(act)] if act is not None else None + agg_func = self.config['aggregation_option_names'][int(agg)] if agg is not None else None + s = f"{key}: (bias: {bias}, response: {response}, act: {act_func}, agg: {agg_func})" + nodes_info.append(s) + return ',\n\t\t'.join(nodes_info) + + def repr_conns(self): + conns_info = [] + for key, value in self.cons.items(): + weight, enabled = value + s = f"{key}: (weight: {weight}, enabled: {enabled})" + conns_info.append(s) + return ',\n\t\t'.join(conns_info) + + def renumber(self): + nodes2new_nodes = {} + new_id = len(self.config['input_idx']) + len(self.config['output_idx']) + for key in self.nodes.keys(): + if key in self.config['input_idx'] or key in self.config['output_idx']: + nodes2new_nodes[key] = key + else: + nodes2new_nodes[key] = new_id + new_id += 1 + + new_nodes, new_cons = {}, {} + for key, value in self.nodes.items(): + new_nodes[nodes2new_nodes[key]] = value + for key, value in self.cons.items(): + i_key, o_key = key + new_cons[(nodes2new_nodes[i_key], nodes2new_nodes[o_key])] = value + self.nodes = new_nodes + self.cons = new_cons + + +def array2object(nodes, cons, config): + """ + Convert a genome from array to dict. + :param nodes: (N, 5) + :param cons: (C, 4) + :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)] + """ + nodes, cons = jax.device_get((nodes, cons)) + # update nodes_dict + nodes_dict = {} + for i, node in enumerate(nodes): + if np.isnan(node[0]): + continue + key = int(node[0]) + assert key not in nodes_dict, f"Duplicate node key: {key}!" + + if key in config['input_idx']: + assert np.all(np.isnan(node[1:])), f"Input node {key} must has None bias, response, act, or agg!" + nodes_dict[key] = (None,) * 4 + else: + assert np.all( + ~np.isnan(node[1:])), f"Normal node {key} must has non-None bias, response, act, or agg!" + bias = node[1] + response = node[2] + act = node[3] + agg = node[4] + nodes_dict[key] = (bias, response, act, agg) + + # check nodes_dict + for i in config['input_idx']: + assert i in nodes_dict, f"Input node {i} not found in nodes_dict!" + + for o in config['output_idx']: + assert o in nodes_dict, f"Output node {o} not found in nodes_dict!" + + # update connections + cons_dict = {} + for i, con in enumerate(cons): + if np.all(np.isnan(con)): + pass + elif np.all(~np.isnan(con)): + i_key = int(con[0]) + o_key = int(con[1]) + if (i_key, o_key) in cons_dict: + assert False, f"Duplicate connection: {(i_key, o_key)}!" + assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!" + assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!" + weight = con[2] + enabled = (con[3] == 1) + cons_dict[(i_key, o_key)] = (weight, enabled) + else: + assert False, f"Connection {i} must has all None or all non-None!" + + return nodes_dict, cons_dict diff --git a/configs/default_config.ini b/configs/default_config.ini index 25c3ac3..7bb4244 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -65,3 +65,6 @@ weight_replace_rate = 0.1 [gene-enable] enable_mutate_rate = 0.01 + +[visualize] +renumber_nodes = True \ No newline at end of file diff --git a/examples/xor.py b/examples/xor.py index 0bded6d..fc5369e 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,8 +1,8 @@ import numpy as np from configs import Configer -from algorithms.neat.pipeline import Pipeline - +from algorithms.neat import Genome +from pipeline import Pipeline xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) @@ -22,8 +22,8 @@ def main(): config = Configer.load_config("xor.ini") pipeline = Pipeline(config, seed=6) nodes, cons = pipeline.auto_run(evaluate) - print(nodes) - print(cons) + g = Genome(nodes, cons, config) + print(g) if __name__ == '__main__': diff --git a/algorithms/neat/pipeline.py b/pipeline.py similarity index 100% rename from algorithms/neat/pipeline.py rename to pipeline.py