From f4763ebceab6b3004c90f1287d410c95a00ba447 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Mon, 17 Jul 2023 17:39:12 +0800 Subject: [PATCH] change a lot --- algorithm/__init__.py | 2 + algorithm/config.py | 118 +++++++++++++++++ algorithm/default_config.ini | 74 +++++++++++ algorithm/neat/NEAT.py | 42 ++++++ algorithm/neat/__init__.py | 1 + algorithm/neat/gene/__init__.py | 2 + algorithm/neat/gene/activation.py | 108 +++++++++++++++ algorithm/neat/gene/aggregation.py | 60 +++++++++ algorithm/neat/gene/base.py | 38 ++++++ algorithm/neat/gene/normal.py | 40 ++++++ algorithm/neat/genome/__init__.py | 2 + algorithm/neat/genome/basic.py | 102 ++++++++++++++ algorithm/neat/genome/crossover.py | 0 algorithm/neat/genome/distance.py | 0 algorithm/neat/genome/graph.py | 167 +++++++++++++++++++++++ algorithm/neat/genome/mutate.py | 206 +++++++++++++++++++++++++++++ algorithm/neat/utils.py | 71 ++++++++++ algorithm/state.py | 4 +- examples/config_test.py | 4 + examples/state_test.py | 7 +- examples/xor_test.py | 16 +++ 21 files changed, 1060 insertions(+), 4 deletions(-) create mode 100644 algorithm/config.py create mode 100644 algorithm/default_config.ini create mode 100644 algorithm/neat/NEAT.py create mode 100644 algorithm/neat/gene/__init__.py create mode 100644 algorithm/neat/gene/activation.py create mode 100644 algorithm/neat/gene/aggregation.py create mode 100644 algorithm/neat/gene/base.py create mode 100644 algorithm/neat/gene/normal.py create mode 100644 algorithm/neat/genome/__init__.py create mode 100644 algorithm/neat/genome/basic.py create mode 100644 algorithm/neat/genome/crossover.py create mode 100644 algorithm/neat/genome/distance.py create mode 100644 algorithm/neat/genome/graph.py create mode 100644 algorithm/neat/genome/mutate.py create mode 100644 algorithm/neat/utils.py create mode 100644 examples/config_test.py create mode 100644 examples/xor_test.py diff --git a/algorithm/__init__.py b/algorithm/__init__.py index e69de29..1703af4 100644 --- a/algorithm/__init__.py +++ b/algorithm/__init__.py @@ -0,0 +1,2 @@ +from .state import State +from .neat import NEAT diff --git a/algorithm/config.py b/algorithm/config.py new file mode 100644 index 0000000..4b8946b --- /dev/null +++ b/algorithm/config.py @@ -0,0 +1,118 @@ +import os +import warnings +import configparser + +import numpy as np + +# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. +jit_config_keys = [ + "input_idx", + "output_idx", + "compatibility_disjoint", + "compatibility_weight", + "conn_add_prob", + "conn_add_trials", + "conn_delete_prob", + "node_add_prob", + "node_delete_prob", + "compatibility_threshold", + "bias_init_mean", + "bias_init_std", + "bias_mutate_power", + "bias_mutate_rate", + "bias_replace_rate", + "response_init_mean", + "response_init_std", + "response_mutate_power", + "response_mutate_rate", + "response_replace_rate", + "activation_default", + "activation_options", + "activation_replace_rate", + "aggregation_default", + "aggregation_options", + "aggregation_replace_rate", + "weight_init_mean", + "weight_init_std", + "weight_mutate_power", + "weight_mutate_rate", + "weight_replace_rate", + "enable_mutate_rate", + "max_stagnation", + "pop_size", + "genome_elitism", + "survival_threshold", + "species_elitism", + "spawn_number_move_rate" +] + + +class Configer: + + @classmethod + def __load_default_config(cls): + par_dir = os.path.dirname(os.path.abspath(__file__)) + default_config_path = os.path.join(par_dir, "default_config.ini") + return cls.__load_config(default_config_path) + + @classmethod + def __load_config(cls, config_path): + c = configparser.ConfigParser() + c.read(config_path) + config = {} + + for section in c.sections(): + for key, value in c.items(section): + config[key] = eval(value) + + return config + + @classmethod + def __check_redundant_config(cls, default_config, config): + for key in config: + if key not in default_config: + warnings.warn(f"Redundant config: {key} in {config.name}") + + @classmethod + def __complete_config(cls, default_config, config): + for key in default_config: + if key not in config: + config[key] = default_config[key] + + @classmethod + def load_config(cls, config_path=None): + default_config = cls.__load_default_config() + if config_path is None: + config = {} + elif not os.path.exists(config_path): + warnings.warn(f"config file {config_path} not exist!") + config = {} + else: + config = cls.__load_config(config_path) + + cls.__check_redundant_config(default_config, config) + cls.__complete_config(default_config, config) + + cls.refactor_activation(config) + cls.refactor_aggregation(config) + + config['input_idx'] = np.arange(config['num_inputs']) + config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs']) + + return config + + @classmethod + def refactor_activation(cls, config): + config['activation_default'] = 0 + config['activation_options'] = np.arange(len(config['activation_option_names'])) + + @classmethod + def refactor_aggregation(cls, config): + config['aggregation_default'] = 0 + config['aggregation_options'] = np.arange(len(config['aggregation_option_names'])) + + @classmethod + def create_jit_config(cls, config): + jit_config = {k: config[k] for k in jit_config_keys} + + return jit_config diff --git a/algorithm/default_config.ini b/algorithm/default_config.ini new file mode 100644 index 0000000..8913b75 --- /dev/null +++ b/algorithm/default_config.ini @@ -0,0 +1,74 @@ +[basic] +num_inputs = 2 +num_outputs = 1 +maximum_nodes = 5 +maximum_connections = 5 +maximum_species = 10 +forward_way = "pop" +batch_size = 4 +random_seed = 0 +network_type = 'feedforward' + +[population] +fitness_threshold = 3.99999 +generation_limit = 1000 +fitness_criterion = "max" +pop_size = 1000 + +[gene] +gene_type = "normal" + +[genome] +compatibility_disjoint = 1.0 +compatibility_weight = 0.5 +conn_add_prob = 0.4 +conn_add_trials = 1 +conn_delete_prob = 0.4 +node_add_prob = 0.2 +node_delete_prob = 0.2 + +[species] +compatibility_threshold = 3.0 +species_elitism = 2 +max_stagnation = 15 +genome_elitism = 2 +survival_threshold = 0.2 +min_species_size = 1 +spawn_number_move_rate = 0.5 + +[gene-bias] +bias_init_mean = 0.0 +bias_init_std = 1.0 +bias_mutate_power = 0.5 +bias_mutate_rate = 0.7 +bias_replace_rate = 0.1 + +[gene-response] +response_init_mean = 1.0 +response_init_std = 0.0 +response_mutate_power = 0.0 +response_mutate_rate = 0.0 +response_replace_rate = 0.0 + +[gene-activation] +activation_default = "sigmoid" +activation_option_names = ["sigmoid"] +activation_replace_rate = 0.0 + +[gene-aggregation] +aggregation_default = "sum" +aggregation_option_names = ["sum"] +aggregation_replace_rate = 0.0 + +[gene-weight] +weight_init_mean = 0.0 +weight_init_std = 1.0 +weight_mutate_power = 0.5 +weight_mutate_rate = 0.8 +weight_replace_rate = 0.1 + +[gene-enable] +enable_mutate_rate = 0.01 + +[visualize] +renumber_nodes = True \ No newline at end of file diff --git a/algorithm/neat/NEAT.py b/algorithm/neat/NEAT.py new file mode 100644 index 0000000..ff5baa9 --- /dev/null +++ b/algorithm/neat/NEAT.py @@ -0,0 +1,42 @@ +from algorithm.state import State +from .gene import * +from .genome import initialize_genomes + + +class NEAT: + def __init__(self, config): + self.config = config + if self.config['gene_type'] == 'normal': + self.gene_type = NormalGene + else: + raise NotImplementedError + + def setup(self, randkey): + + state = State( + randkey=randkey, + P=self.config['pop_size'], + N=self.config['maximum_nodes'], + C=self.config['maximum_connections'], + S=self.config['maximum_species'], + NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes + CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes + input_idx=self.config['input_idx'], + output_idx=self.config['output_idx'] + ) + + pop_nodes, pop_conns = initialize_genomes(state, self.gene_type) + next_node_key = max(*state.input_idx, *state.output_idx) + 2 + state = state.update( + pop_nodes=pop_nodes, + pop_conns=pop_conns, + next_node_key=next_node_key + ) + + return state + + def tell(self, state, fitness): + return State() + + def ask(self, state): + return State() diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py index e69de29..b4d1a48 100644 --- a/algorithm/neat/__init__.py +++ b/algorithm/neat/__init__.py @@ -0,0 +1 @@ +from .NEAT import NEAT diff --git a/algorithm/neat/gene/__init__.py b/algorithm/neat/gene/__init__.py new file mode 100644 index 0000000..aebf638 --- /dev/null +++ b/algorithm/neat/gene/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseGene +from .normal import NormalGene diff --git a/algorithm/neat/gene/activation.py b/algorithm/neat/gene/activation.py new file mode 100644 index 0000000..a8075e0 --- /dev/null +++ b/algorithm/neat/gene/activation.py @@ -0,0 +1,108 @@ +import jax.numpy as jnp + + +class Activation: + + @staticmethod + def sigmoid_act(z): + z = jnp.clip(z * 5, -60, 60) + return 1 / (1 + jnp.exp(-z)) + + @staticmethod + def tanh_act(z): + z = jnp.clip(z * 2.5, -60, 60) + return jnp.tanh(z) + + @staticmethod + def sin_act(z): + z = jnp.clip(z * 5, -60, 60) + return jnp.sin(z) + + @staticmethod + def gauss_act(z): + z = jnp.clip(z * 5, -3.4, 3.4) + return jnp.exp(-z ** 2) + + @staticmethod + def relu_act(z): + return jnp.maximum(z, 0) + + @staticmethod + def elu_act(z): + return jnp.where(z > 0, z, jnp.exp(z) - 1) + + @staticmethod + def lelu_act(z): + leaky = 0.005 + return jnp.where(z > 0, z, leaky * z) + + @staticmethod + def selu_act(z): + lam = 1.0507009873554804934193349852946 + alpha = 1.6732632423543772848170429916717 + return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1)) + + @staticmethod + def softplus_act(z): + z = jnp.clip(z * 5, -60, 60) + return 0.2 * jnp.log(1 + jnp.exp(z)) + + @staticmethod + def identity_act(z): + return z + + @staticmethod + def clamped_act(z): + return jnp.clip(z, -1, 1) + + @staticmethod + def inv_act(z): + z = jnp.maximum(z, 1e-7) + return 1 / z + + @staticmethod + def log_act(z): + z = jnp.maximum(z, 1e-7) + return jnp.log(z) + + @staticmethod + def exp_act(z): + z = jnp.clip(z, -60, 60) + return jnp.exp(z) + + @staticmethod + def abs_act(z): + return jnp.abs(z) + + @staticmethod + def hat_act(z): + return jnp.maximum(0, 1 - jnp.abs(z)) + + @staticmethod + def square_act(z): + return z ** 2 + + @staticmethod + def cube_act(z): + return z ** 3 + + name2func = { + 'sigmoid': sigmoid_act, + 'tanh': tanh_act, + 'sin': sin_act, + 'gauss': gauss_act, + 'relu': relu_act, + 'elu': elu_act, + 'lelu': lelu_act, + 'selu': selu_act, + 'softplus': softplus_act, + 'identity': identity_act, + 'clamped': clamped_act, + 'inv': inv_act, + 'log': log_act, + 'exp': exp_act, + 'abs': abs_act, + 'hat': hat_act, + 'square': square_act, + 'cube': cube_act, + } diff --git a/algorithm/neat/gene/aggregation.py b/algorithm/neat/gene/aggregation.py new file mode 100644 index 0000000..c03d960 --- /dev/null +++ b/algorithm/neat/gene/aggregation.py @@ -0,0 +1,60 @@ +import jax.numpy as jnp + + +class Aggregation: + + @staticmethod + def sum_agg(z): + z = jnp.where(jnp.isnan(z), 0, z) + return jnp.sum(z, axis=0) + + @staticmethod + def product_agg(z): + z = jnp.where(jnp.isnan(z), 1, z) + return jnp.prod(z, axis=0) + + @staticmethod + def max_agg(z): + z = jnp.where(jnp.isnan(z), -jnp.inf, z) + return jnp.max(z, axis=0) + + @staticmethod + def min_agg(z): + z = jnp.where(jnp.isnan(z), jnp.inf, z) + return jnp.min(z, axis=0) + + @staticmethod + def maxabs_agg(z): + z = jnp.where(jnp.isnan(z), 0, z) + abs_z = jnp.abs(z) + max_abs_index = jnp.argmax(abs_z) + return z[max_abs_index] + + @staticmethod + def median_agg(z): + n = jnp.sum(~jnp.isnan(z), axis=0) + + z = jnp.sort(z) # sort + + idx1, idx2 = (n - 1) // 2, n // 2 + median = (z[idx1] + z[idx2]) / 2 + + return median + + @staticmethod + def mean_agg(z): + aux = jnp.where(jnp.isnan(z), 0, z) + valid_values_sum = jnp.sum(aux, axis=0) + valid_values_count = jnp.sum(~jnp.isnan(z), axis=0) + mean_without_zeros = valid_values_sum / valid_values_count + return mean_without_zeros + + name2func = { + 'sum': sum_agg, + 'product': product_agg, + 'max': max_agg, + 'min': min_agg, + 'maxabs': maxabs_agg, + 'median': median_agg, + 'mean': mean_agg, + } diff --git a/algorithm/neat/gene/base.py b/algorithm/neat/gene/base.py new file mode 100644 index 0000000..40d2c10 --- /dev/null +++ b/algorithm/neat/gene/base.py @@ -0,0 +1,38 @@ +from jax import Array, numpy as jnp + + +class BaseGene: + node_attrs = [] + conn_attrs = [] + + @staticmethod + def setup(state, config): + return state + + @staticmethod + def new_node_attrs(state): + return jnp.zeros(0) + + @staticmethod + def new_conn_attrs(state): + return jnp.zeros(0) + + @staticmethod + def mutate_node(state, attrs: Array, key): + return attrs + + @staticmethod + def mutate_conn(state, attrs: Array, key): + return attrs + + @staticmethod + def distance_node(state, array: Array): + return array + + @staticmethod + def distance_conn(state, array: Array): + return array + + @staticmethod + def forward(state, array: Array): + return array diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py new file mode 100644 index 0000000..799012e --- /dev/null +++ b/algorithm/neat/gene/normal.py @@ -0,0 +1,40 @@ +from jax import Array, numpy as jnp + +from . import BaseGene + + +class NormalGene(BaseGene): + node_attrs = ['bias', 'response', 'aggregation', 'activation'] + conn_attrs = ['weight'] + + @staticmethod + def setup(state, config): + return state + + @staticmethod + def new_node_attrs(state): + return jnp.array([0, 0, 0, 0]) + + @staticmethod + def new_conn_attrs(state): + return jnp.array([0]) + + @staticmethod + def mutate_node(state, attrs: Array, key): + return attrs + + @staticmethod + def mutate_conn(state, attrs: Array, key): + return attrs + + @staticmethod + def distance_node(state, array: Array): + return array + + @staticmethod + def distance_conn(state, array: Array): + return array + + @staticmethod + def forward(state, array: Array): + return array diff --git a/algorithm/neat/genome/__init__.py b/algorithm/neat/genome/__init__.py new file mode 100644 index 0000000..1cc30c1 --- /dev/null +++ b/algorithm/neat/genome/__init__.py @@ -0,0 +1,2 @@ +from .basic import initialize_genomes +from .mutate import create_mutate diff --git a/algorithm/neat/genome/basic.py b/algorithm/neat/genome/basic.py new file mode 100644 index 0000000..915eebc --- /dev/null +++ b/algorithm/neat/genome/basic.py @@ -0,0 +1,102 @@ +from typing import Type, Tuple + +import numpy as np +import jax +from jax import Array, numpy as jnp + +from algorithm import State +from ..gene import BaseGene +from ..utils import fetch_first + + +def initialize_genomes(state: State, gene_type: Type[BaseGene]): + o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes + o_conns = np.full((state.N, state.CL), np.nan, dtype=np.float32) # original connections + + input_idx = state.input_idx + output_idx = state.output_idx + new_node_key = max([*input_idx, *output_idx]) + 1 + + o_nodes[input_idx, 0] = input_idx + o_nodes[output_idx, 0] = output_idx + o_nodes[new_node_key, 0] = new_node_key + o_nodes[np.concatenate([input_idx, output_idx]), 1:] = jax.device_get(gene_type.new_node_attrs(state)) + o_nodes[new_node_key, 1:] = jax.device_get(gene_type.new_node_attrs(state)) + + input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] + o_conns[input_idx, 0:2] = input_conns # in key, out key + o_conns[input_idx, 2] = True # enabled + o_conns[input_idx, 3:] = jax.device_get(gene_type.new_conn_attrs(state)) + + output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] + o_conns[output_idx, 0:2] = output_conns # in key, out key + o_conns[output_idx, 2] = True # enabled + o_conns[output_idx, 3:] = jax.device_get(gene_type.new_conn_attrs(state)) + + # repeat origin genome for P times to create population + pop_nodes = np.tile(o_nodes, (state.P, 1, 1)) + pop_conns = np.tile(o_conns, (state.P, 1, 1)) + + return pop_nodes, pop_conns + + +def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]: + """ + Add a new node to the genome. + The new node will place at the first NaN row. + """ + exist_keys = nodes[:, 0] + idx = fetch_first(jnp.isnan(exist_keys)) + nodes = nodes.at[idx, 0].set(new_key) + nodes = nodes.at[idx, 1:].set(attrs) + return nodes, cons + + +def delete_node(nodes: Array, cons: Array, node_key: Array) -> Tuple[Array, Array]: + """ + Delete a node from the genome. Only delete the node, regardless of connections. + Delete the node by its key. + """ + node_keys = nodes[:, 0] + idx = fetch_first(node_keys == node_key) + return delete_node_by_idx(nodes, cons, idx) + + +def delete_node_by_idx(nodes: Array, cons: Array, idx: Array) -> Tuple[Array, Array]: + """ + Delete a node from the genome. Only delete the node, regardless of connections. + Delete the node by its idx. + """ + nodes = nodes.at[idx].set(np.nan) + return nodes, cons + + +def add_connection(nodes: Array, cons: Array, i_key: Array, o_key: Array, enable: bool, attrs: Array) -> Tuple[ + Array, Array]: + """ + Add a new connection to the genome. + The new connection will place at the first NaN row. + """ + con_keys = cons[:, 0] + idx = fetch_first(jnp.isnan(con_keys)) + cons = cons.at[idx, 0:3].set(jnp.array([i_key, o_key, enable])) + cons = cons.at[idx, 3:].set(attrs) + return nodes, cons + + +def delete_connection(nodes: Array, cons: Array, i_key: Array, o_key: Array) -> Tuple[Array, Array]: + """ + Delete a connection from the genome. + Delete the connection by its input and output node keys. + """ + idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) + return delete_connection_by_idx(nodes, cons, idx) + + +def delete_connection_by_idx(nodes: Array, cons: Array, idx: Array) -> Tuple[Array, Array]: + """ + Delete a connection from the genome. + Delete the connection by its idx. + """ + cons = cons.at[idx].set(np.nan) + return nodes, cons diff --git a/algorithm/neat/genome/crossover.py b/algorithm/neat/genome/crossover.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm/neat/genome/distance.py b/algorithm/neat/genome/distance.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm/neat/genome/graph.py b/algorithm/neat/genome/graph.py new file mode 100644 index 0000000..79d1810 --- /dev/null +++ b/algorithm/neat/genome/graph.py @@ -0,0 +1,167 @@ +""" +Some graph algorithm implemented in jax. +Only used in feed-forward networks. +""" + +import jax +from jax import jit, Array, numpy as jnp + +from ..utils import fetch_first, I_INT + + +@jit +def topological_sort(nodes: Array, connections: Array) -> Array: + """ + a jit-able version of topological_sort! that's crazy! + :param nodes: nodes array + :param connections: connections array + :return: topological sorted sequence + + Example: + nodes = jnp.array([ + [0], + [1], + [2], + [3] + ]) + connections = jnp.array([ + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ], + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ] + ]) + + topological_sort(nodes, connections) -> [0, 1, 2, 3] + """ + connections_enable = connections[1, :, :] == 1 # forward function. thus use enable + in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0)) + res = jnp.full(in_degree.shape, I_INT) + + def cond_fun(carry): + res_, idx_, in_degree_ = carry + i = fetch_first(in_degree_ == 0.) + return i != I_INT + + def body_func(carry): + res_, idx_, in_degree_ = carry + i = fetch_first(in_degree_ == 0.) + + # add to res and flag it is already in it + res_ = res_.at[idx_].set(i) + in_degree_ = in_degree_.at[i].set(-1) + + # decrease in_degree of all its children + children = connections_enable[i, :] + in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_) + return res_, idx_ + 1, in_degree_ + + res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree)) + return res + + +@jit +def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array: + """ + Check whether a new connection (from_idx -> to_idx) will cause a cycle. + + :param nodes: JAX array + The array of nodes. + :param connections: JAX array + The array of connections. + :param from_idx: int + The index of the starting node. + :param to_idx: int + The index of the ending node. + :return: JAX array + An array indicating if there is a cycle caused by the new connection. + + Example: + nodes = jnp.array([ + [0], + [1], + [2], + [3] + ]) + connections = jnp.array([ + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ], + [ + [0, 0, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0] + ] + ]) + + check_cycles(nodes, connections, 3, 2) -> True + check_cycles(nodes, connections, 2, 3) -> False + check_cycles(nodes, connections, 0, 3) -> False + check_cycles(nodes, connections, 1, 0) -> False + """ + + connections_enable = ~jnp.isnan(connections[0, :, :]) + connections_enable = connections_enable.at[from_idx, to_idx].set(True) + + visited = jnp.full(nodes.shape[0], False) + new_visited = visited.at[to_idx].set(True) + + def cond_func(carry): + visited_, new_visited_ = carry + end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited + end_cond2 = new_visited_[from_idx] # the starting node has been visited + return jnp.logical_not(end_cond1 | end_cond2) + + def body_func(carry): + _, visited_ = carry + new_visited_ = jnp.dot(visited_, connections_enable) + new_visited_ = jnp.logical_or(visited_, new_visited_) + return visited_, new_visited_ + + _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited)) + return visited[from_idx] + + +if __name__ == '__main__': + nodes = jnp.array([ + [0], + [1], + [2], + [3], + [jnp.nan] + ]) + connections = jnp.array([ + [ + [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], + [jnp.nan, jnp.nan, 1, 1, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] + ], + [ + [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], + [jnp.nan, jnp.nan, 1, 1, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], + [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] + ] + ] + ) + + print(topological_sort(nodes, connections)) + + print(check_cycles(nodes, connections, 3, 2)) + print(check_cycles(nodes, connections, 2, 3)) + print(check_cycles(nodes, connections, 0, 3)) + print(check_cycles(nodes, connections, 1, 0)) \ No newline at end of file diff --git a/algorithm/neat/genome/mutate.py b/algorithm/neat/genome/mutate.py new file mode 100644 index 0000000..748f8be --- /dev/null +++ b/algorithm/neat/genome/mutate.py @@ -0,0 +1,206 @@ +from typing import Dict, Tuple, Type + +import numpy as np +import jax +from jax import Array, numpy as jnp, vmap + +from algorithm import State +from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx +from .graph import check_cycles +from ..utils import fetch_random, fetch_first, I_INT, unflatten_connections +from ..gene import BaseGene + + +def create_mutate(config: Dict, gene_type: Type[BaseGene]): + """ + Create function to mutate the whole population + """ + + def mutate_structure(state: State, randkey, nodes, cons, new_node_key): + def nothing(*args): + return nodes, cons + + def mutate_add_node(key_): + i_key, o_key, idx = choice_connection_key(key_, nodes, cons) + + def successful_add_node(): + # disable the connection + aux_nodes, aux_cons = nodes, cons + + # set enable to false + aux_cons = aux_cons.at[idx, 2].set(False) + + # add a new node + aux_nodes, aux_cons = add_node(aux_nodes, aux_cons, new_node_key, gene_type.new_node_attrs(state)) + + # add two new connections + aux_nodes, aux_cons = add_connection(aux_nodes, aux_cons, i_key, new_node_key, True, + gene_type.new_conn_attrs(state)) + aux_nodes, aux_cons = add_connection(aux_nodes, aux_cons, new_node_key, o_key, True, + gene_type.new_conn_attrs(state)) + + return aux_nodes, aux_cons + + # if from_idx == I_INT, that means no connection exist, do nothing + return jax.lax.cond(idx == I_INT, nothing, successful_add_node) + + def mutate_delete_node(key_): + # TODO: Do we really need to delete a node? + # randomly choose a node + key, idx = choice_node_key(key_, nodes, config['input_idx'], config['output_idx'], + allow_input_keys=False, allow_output_keys=False) + + def successful_delete_node(): + # delete the node + aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx) + + # delete all connections + aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None], + jnp.nan, aux_cons) + + return aux_nodes, aux_cons + + return jax.lax.cond(idx == I_INT, nothing, successful_delete_node) + + def mutate_add_conn(key_): + # randomly choose two nodes + k1_, k2_ = jax.random.split(key_, num=2) + i_key, from_idx = choice_node_key(k1_, nodes, config['input_idx'], config['output_idx'], + allow_input_keys=True, allow_output_keys=True) + o_key, to_idx = choice_node_key(k2_, nodes, config['input_idx'], config['output_idx'], + allow_input_keys=False, allow_output_keys=True) + + con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) + + def successful(): + new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, True, gene_type.new_conn_attrs(state)) + return new_nodes, new_cons + + def already_exist(): + new_cons = cons.at[con_idx, 2].set(True) + return nodes, new_cons + + is_already_exist = con_idx != I_INT + + if config['network_type'] == 'feedforward': + u_cons = unflatten_connections(nodes, cons) + is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx) + + choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) + return jax.lax.switch(choice, [already_exist, nothing, successful]) + + elif config['network_type'] == 'recurrent': + return jax.lax.cond(is_already_exist, already_exist, successful) + + else: + raise ValueError(f"Invalid network type: {config['network_type']}") + + def mutate_delete_conn(key_): + # randomly choose a connection + i_key, o_key, idx = choice_connection_key(key_, nodes, cons) + + def successfully_delete_connection(): + return delete_connection_by_idx(nodes, cons, idx) + + return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) + + k, k1, k2, k3, k4 = jax.random.split(randkey, num=5) + r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) + + nodes, cons = jax.lax.cond(r1 < config['node_add_prob'], mutate_add_node, nothing, k1) + nodes, cons = jax.lax.cond(r2 < config['node_delete_prob'], mutate_delete_node, nothing, k2) + nodes, cons = jax.lax.cond(r3 < config['conn_add_prob'], mutate_add_conn, nothing, k3) + nodes, cons = jax.lax.cond(r4 < config['conn_delete_prob'], mutate_delete_conn, nothing, k4) + return nodes, cons + + def mutate_values(state: State, randkey, nodes, conns): + k1, k2 = jax.random.split(randkey, num=2) + nodes_keys = jax.random.split(k1, num=nodes.shape[0]) + conns_keys = jax.random.split(k2, num=conns.shape[0]) + + nodes_attrs, conns_attrs = nodes[:, 1:], conns[:, 3:] + + new_nodes_attrs = vmap(gene_type.mutate_node, in_axes=(None, 0, 0))(state, nodes_attrs, nodes_keys) + new_conns_attrs = vmap(gene_type.mutate_conn, in_axes=(None, 0, 0))(state, conns_attrs, conns_keys) + + # nan nodes not changed + new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs) + new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs) + + new_nodes = nodes.at[:, 1:].set(new_nodes_attrs) + new_conns = conns.at[:, 3:].set(new_conns_attrs) + + return new_nodes, new_conns + + def mutate(state): + pop_nodes, pop_conns = state.pop_nodes, state.pop_conns + pop_size = pop_nodes.shape[0] + + new_node_keys = jnp.arange(pop_size) + state.next_node_key + k1, k2, randkey = jax.random.split(state.randkey, num=3) + structure_randkeys = jax.random.split(k1, num=pop_size) + values_randkeys = jax.random.split(k2, num=pop_size) + + structure_func = jax.vmap(mutate_structure, in_axes=(None, 0, 0, 0, 0)) + pop_nodes, pop_conns = structure_func(state, structure_randkeys, pop_nodes, pop_conns, new_node_keys) + + values_func = jax.vmap(mutate_values, in_axes=(None, 0, 0, 0)) + pop_nodes, pop_conns = values_func(state, values_randkeys, pop_nodes, pop_conns) + + # update next node key + all_nodes_keys = pop_nodes[:, :, 0] + max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) + next_node_key = max_node_key + 1 + + return state.update( + pop_nodes=pop_nodes, + pop_conns=pop_conns, + next_node_key=next_node_key, + randkey=randkey + ) + + return mutate + + +def choice_node_key(rand_key: Array, nodes: Array, + input_keys: Array, output_keys: Array, + allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: + """ + Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. + :param rand_key: + :param nodes: + :param input_keys: + :param output_keys: + :param allow_input_keys: + :param allow_output_keys: + :return: return its key and position(idx) + """ + + node_keys = nodes[:, 0] + mask = ~jnp.isnan(node_keys) + + if not allow_input_keys: + mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys)) + + if not allow_output_keys: + mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys)) + + idx = fetch_random(rand_key, mask) + key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan) + return key, idx + + +def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]: + """ + Randomly choose a connection key from the given connections. + :param rand_key: + :param nodes: + :param cons: + :return: i_key, o_key, idx + """ + + idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0])) + i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan) + o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan) + + return i_key, o_key, idx diff --git a/algorithm/neat/utils.py b/algorithm/neat/utils.py new file mode 100644 index 0000000..f9b5364 --- /dev/null +++ b/algorithm/neat/utils.py @@ -0,0 +1,71 @@ +from functools import partial + +import numpy as np +import jax +from jax import numpy as jnp, Array, jit, vmap + +I_INT = np.iinfo(jnp.int32).max # infinite int +EMPTY_NODE = np.full((1, 5), jnp.nan) +EMPTY_CON = np.full((1, 4), jnp.nan) + + +@jit +def unflatten_connections(nodes: Array, cons: Array): + """ + transform the (C, 4) connections to (2, N, N) + :param nodes: (N, 5) + :param cons: (C, 4) + :return: + """ + N = nodes.shape[0] + node_keys = nodes[:, 0] + i_keys, o_keys = cons[:, 0], cons[:, 1] + i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys) + o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys) + res = jnp.full((2, N, N), jnp.nan) + + # Is interesting that jax use clip when attach data in array + # however, it will do nothing set values in an array + res = res.at[0, i_idxs, o_idxs].set(cons[:, 2]) + res = res.at[1, i_idxs, o_idxs].set(cons[:, 3]) + + return res + + +def key_to_indices(key, keys): + return fetch_first(key == keys) + + +@jit +def fetch_first(mask, default=I_INT) -> Array: + """ + fetch the first True index + :param mask: array of bool + :param default: the default value if no element satisfying the condition + :return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value + """ + idx = jnp.argmax(mask) + return jnp.where(mask[idx], idx, default) + + +@jit +def fetch_random(rand_key, mask, default=I_INT) -> Array: + """ + similar to fetch_first, but fetch a random True index + """ + true_cnt = jnp.sum(mask) + cumsum = jnp.cumsum(mask) + target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1) + mask = jnp.where(true_cnt == 0, False, cumsum >= target) + return fetch_first(mask, default) + + +@partial(jit, static_argnames=['reverse']) +def rank_elements(array, reverse=False): + """ + rank the element in the array. + if reverse is True, the rank is from small to large. default large to small + """ + if not reverse: + array = -array + return jnp.argsort(jnp.argsort(array)) \ No newline at end of file diff --git a/algorithm/state.py b/algorithm/state.py index d20774c..e03d99d 100644 --- a/algorithm/state.py +++ b/algorithm/state.py @@ -1,4 +1,4 @@ -from jax.tree_util import register_pytree_node_class, tree_map +from jax.tree_util import register_pytree_node_class @register_pytree_node_class @@ -20,10 +20,12 @@ class State: return f"State ({self.state_dict})" def tree_flatten(self): + print('tree_flatten_cal') children = list(self.state_dict.values()) aux_data = list(self.state_dict.keys()) return children, aux_data @classmethod def tree_unflatten(cls, aux_data, children): + print('tree_unflatten_cal') return cls(**dict(zip(aux_data, children))) diff --git a/examples/config_test.py b/examples/config_test.py new file mode 100644 index 0000000..aeb50b1 --- /dev/null +++ b/examples/config_test.py @@ -0,0 +1,4 @@ +from algorithm.config import Configer + +config = Configer.load_config() +print(config) \ No newline at end of file diff --git a/examples/state_test.py b/examples/state_test.py index 0a28590..ef2fddf 100644 --- a/examples/state_test.py +++ b/examples/state_test.py @@ -1,6 +1,8 @@ import jax +from jax import numpy as jnp from algorithm.state import State + @jax.jit def func(state: State, a): return state.update(a=a) @@ -9,6 +11,5 @@ def func(state: State, a): state = State(c=1, b=2) print(state) -state = func(state, 1111111) - -print(state) +vmap_func = jax.vmap(func, in_axes=(None, 0)) +print(vmap_func(state, jnp.array([1, 2, 3]))) \ No newline at end of file diff --git a/examples/xor_test.py b/examples/xor_test.py new file mode 100644 index 0000000..46dcf8e --- /dev/null +++ b/examples/xor_test.py @@ -0,0 +1,16 @@ +import jax + +from algorithm.config import Configer +from algorithm.neat import NEAT +from algorithm.neat.genome import create_mutate + +if __name__ == '__main__': + config = Configer.load_config() + neat = NEAT(config) + randkey = jax.random.PRNGKey(42) + state = neat.setup(randkey) + mutate_func = jax.jit(create_mutate(config, neat.gene_type)) + state = mutate_func(state) + print(state) + +