From 0a2a9fd1be28ff28b6c91f0a7abcaae2d7ff173d Mon Sep 17 00:00:00 2001 From: wls2002 Date: Tue, 18 Jul 2023 23:55:36 +0800 Subject: [PATCH] complete normal neat algorithm --- algorithm/__init__.py | 1 + algorithm/config.py | 49 ---- algorithm/default_config.ini | 17 +- algorithm/neat/NEAT.py | 50 ---- algorithm/neat/__init__.py | 4 +- algorithm/neat/gene/__init__.py | 3 + algorithm/neat/gene/activation.py | 42 ++-- algorithm/neat/gene/aggregation.py | 21 +- algorithm/neat/gene/base.py | 20 +- algorithm/neat/gene/normal.py | 122 +++++++++- algorithm/neat/genome/__init__.py | 1 + algorithm/neat/genome/basic.py | 10 +- algorithm/neat/genome/crossover.py | 12 +- algorithm/neat/genome/graph.py | 18 +- algorithm/neat/genome/mutate.py | 117 +++++---- algorithm/neat/neat.py | 75 ++++++ algorithm/neat/pipeline.py | 79 +++++++ algorithm/neat/population.py | 368 +++++++++++++++++++++++++++++ algorithm/neat/utils.py | 19 +- algorithm/state.py | 2 - examples/xor.ini | 5 + examples/xor.py | 31 +++ examples/xor_test.py | 29 ++- test/__init__.py | 0 test/unit/__init__.py | 0 test/unit/test_utils.py | 36 +++ 26 files changed, 880 insertions(+), 251 deletions(-) delete mode 100644 algorithm/neat/NEAT.py create mode 100644 algorithm/neat/neat.py create mode 100644 algorithm/neat/pipeline.py create mode 100644 algorithm/neat/population.py create mode 100644 examples/xor.ini create mode 100644 examples/xor.py create mode 100644 test/__init__.py create mode 100644 test/unit/__init__.py create mode 100644 test/unit/test_utils.py diff --git a/algorithm/__init__.py b/algorithm/__init__.py index 1703af4..a9899b1 100644 --- a/algorithm/__init__.py +++ b/algorithm/__init__.py @@ -1,2 +1,3 @@ from .state import State from .neat import NEAT +from .config import Configer \ No newline at end of file diff --git a/algorithm/config.py b/algorithm/config.py index 4b8946b..d23db46 100644 --- a/algorithm/config.py +++ b/algorithm/config.py @@ -4,49 +4,6 @@ import configparser import numpy as np -# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. -jit_config_keys = [ - "input_idx", - "output_idx", - "compatibility_disjoint", - "compatibility_weight", - "conn_add_prob", - "conn_add_trials", - "conn_delete_prob", - "node_add_prob", - "node_delete_prob", - "compatibility_threshold", - "bias_init_mean", - "bias_init_std", - "bias_mutate_power", - "bias_mutate_rate", - "bias_replace_rate", - "response_init_mean", - "response_init_std", - "response_mutate_power", - "response_mutate_rate", - "response_replace_rate", - "activation_default", - "activation_options", - "activation_replace_rate", - "aggregation_default", - "aggregation_options", - "aggregation_replace_rate", - "weight_init_mean", - "weight_init_std", - "weight_mutate_power", - "weight_mutate_rate", - "weight_replace_rate", - "enable_mutate_rate", - "max_stagnation", - "pop_size", - "genome_elitism", - "survival_threshold", - "species_elitism", - "spawn_number_move_rate" -] - - class Configer: @classmethod @@ -110,9 +67,3 @@ class Configer: def refactor_aggregation(cls, config): config['aggregation_default'] = 0 config['aggregation_options'] = np.arange(len(config['aggregation_option_names'])) - - @classmethod - def create_jit_config(cls, config): - jit_config = {k: config[k] for k in jit_config_keys} - - return jit_config diff --git a/algorithm/default_config.ini b/algorithm/default_config.ini index 8913b75..c62905c 100644 --- a/algorithm/default_config.ini +++ b/algorithm/default_config.ini @@ -1,29 +1,26 @@ [basic] num_inputs = 2 num_outputs = 1 -maximum_nodes = 5 -maximum_connections = 5 -maximum_species = 10 +maximum_nodes = 100 +maximum_connections = 100 +maximum_species = 100 forward_way = "pop" batch_size = 4 random_seed = 0 network_type = 'feedforward' [population] -fitness_threshold = 3.99999 +fitness_threshold = 3.9999 generation_limit = 1000 fitness_criterion = "max" pop_size = 1000 -[gene] -gene_type = "normal" - [genome] compatibility_disjoint = 1.0 compatibility_weight = 0.5 -conn_add_prob = 0.4 +conn_add_prob = 0.5 conn_add_trials = 1 -conn_delete_prob = 0.4 +conn_delete_prob = 0.5 node_add_prob = 0.2 node_delete_prob = 0.2 @@ -34,7 +31,7 @@ max_stagnation = 15 genome_elitism = 2 survival_threshold = 0.2 min_species_size = 1 -spawn_number_move_rate = 0.5 +spawn_number_change_rate = 0.5 [gene-bias] bias_init_mean = 0.0 diff --git a/algorithm/neat/NEAT.py b/algorithm/neat/NEAT.py deleted file mode 100644 index f9a9343..0000000 --- a/algorithm/neat/NEAT.py +++ /dev/null @@ -1,50 +0,0 @@ -import jax - -from algorithm.state import State -from .gene import * -from .genome import initialize_genomes, create_mutate, create_distance, crossover - - -class NEAT: - def __init__(self, config): - self.config = config - if self.config['gene_type'] == 'normal': - self.gene_type = NormalGene - else: - raise NotImplementedError - - self.mutate = jax.jit(create_mutate(config, self.gene_type)) - self.distance = jax.jit(create_distance(config, self.gene_type)) - self.crossover = jax.jit(crossover) - - def setup(self, randkey): - - state = State( - randkey=randkey, - P=self.config['pop_size'], - N=self.config['maximum_nodes'], - C=self.config['maximum_connections'], - S=self.config['maximum_species'], - NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes - CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes - input_idx=self.config['input_idx'], - output_idx=self.config['output_idx'] - ) - - state = self.gene_type.setup(state, self.config) - - pop_nodes, pop_conns = initialize_genomes(state, self.gene_type) - next_node_key = max(*state.input_idx, *state.output_idx) + 2 - state = state.update( - pop_nodes=pop_nodes, - pop_conns=pop_conns, - next_node_key=next_node_key - ) - - return state - - def tell(self, state, fitness): - return State() - - def ask(self, state): - return State() diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py index b4d1a48..87eba79 100644 --- a/algorithm/neat/__init__.py +++ b/algorithm/neat/__init__.py @@ -1 +1,3 @@ -from .NEAT import NEAT +from .neat import NEAT +from .gene import NormalGene +from .pipeline import Pipeline diff --git a/algorithm/neat/gene/__init__.py b/algorithm/neat/gene/__init__.py index aebf638..dc4c9db 100644 --- a/algorithm/neat/gene/__init__.py +++ b/algorithm/neat/gene/__init__.py @@ -1,2 +1,5 @@ from .base import BaseGene from .normal import NormalGene +from .activation import Activation +from .aggregation import Aggregation + diff --git a/algorithm/neat/gene/activation.py b/algorithm/neat/gene/activation.py index a8075e0..ccbf655 100644 --- a/algorithm/neat/gene/activation.py +++ b/algorithm/neat/gene/activation.py @@ -3,6 +3,8 @@ import jax.numpy as jnp class Activation: + name2func = {} + @staticmethod def sigmoid_act(z): z = jnp.clip(z * 5, -60, 60) @@ -86,23 +88,23 @@ class Activation: def cube_act(z): return z ** 3 - name2func = { - 'sigmoid': sigmoid_act, - 'tanh': tanh_act, - 'sin': sin_act, - 'gauss': gauss_act, - 'relu': relu_act, - 'elu': elu_act, - 'lelu': lelu_act, - 'selu': selu_act, - 'softplus': softplus_act, - 'identity': identity_act, - 'clamped': clamped_act, - 'inv': inv_act, - 'log': log_act, - 'exp': exp_act, - 'abs': abs_act, - 'hat': hat_act, - 'square': square_act, - 'cube': cube_act, - } +Activation.name2func = { + 'sigmoid': Activation.sigmoid_act, + 'tanh': Activation.tanh_act, + 'sin': Activation.sin_act, + 'gauss': Activation.gauss_act, + 'relu': Activation.relu_act, + 'elu': Activation.elu_act, + 'lelu': Activation.lelu_act, + 'selu': Activation.selu_act, + 'softplus': Activation.softplus_act, + 'identity': Activation.identity_act, + 'clamped': Activation.clamped_act, + 'inv': Activation.inv_act, + 'log': Activation.log_act, + 'exp': Activation.exp_act, + 'abs': Activation.abs_act, + 'hat': Activation.hat_act, + 'square': Activation.square_act, + 'cube': Activation.cube_act, +} diff --git a/algorithm/neat/gene/aggregation.py b/algorithm/neat/gene/aggregation.py index c03d960..be85ca4 100644 --- a/algorithm/neat/gene/aggregation.py +++ b/algorithm/neat/gene/aggregation.py @@ -3,6 +3,8 @@ import jax.numpy as jnp class Aggregation: + name2func = {} + @staticmethod def sum_agg(z): z = jnp.where(jnp.isnan(z), 0, z) @@ -49,12 +51,13 @@ class Aggregation: mean_without_zeros = valid_values_sum / valid_values_count return mean_without_zeros - name2func = { - 'sum': sum_agg, - 'product': product_agg, - 'max': max_agg, - 'min': min_agg, - 'maxabs': maxabs_agg, - 'median': median_agg, - 'mean': mean_agg, - } + +Aggregation.name2func = { + 'sum': Aggregation.sum_agg, + 'product': Aggregation.product_agg, + 'max': Aggregation.max_agg, + 'min': Aggregation.min_agg, + 'maxabs': Aggregation.maxabs_agg, + 'median': Aggregation.median_agg, + 'mean': Aggregation.mean_agg, +} \ No newline at end of file diff --git a/algorithm/neat/gene/base.py b/algorithm/neat/gene/base.py index 7a710f2..9036e65 100644 --- a/algorithm/neat/gene/base.py +++ b/algorithm/neat/gene/base.py @@ -1,4 +1,4 @@ -from jax import Array, numpy as jnp +from jax import Array, numpy as jnp, vmap class BaseGene: @@ -26,13 +26,19 @@ class BaseGene: return attrs @staticmethod - def distance_node(state, array1: Array, array2: Array): - return array1 + def distance_node(state, node1: Array, node2: Array): + return node1 @staticmethod - def distance_conn(state, array1: Array, array2: Array): - return array1 + def distance_conn(state, conn1: Array, conn2: Array): + return conn1 + @staticmethod - def forward(state, array: Array): - return array + def forward_transform(nodes, conns): + return nodes, conns + + + @staticmethod + def create_forward(config): + return None \ No newline at end of file diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py index 7468671..c5bcd97 100644 --- a/algorithm/neat/gene/normal.py +++ b/algorithm/neat/gene/normal.py @@ -1,7 +1,11 @@ import jax from jax import Array, numpy as jnp -from . import BaseGene +from .base import BaseGene +from .activation import Activation +from .aggregation import Aggregation +from ..utils import unflatten_connections, I_INT +from ..genome import topological_sort class NormalGene(BaseGene): @@ -70,18 +74,116 @@ class NormalGene(BaseGene): return jnp.array([weight]) @staticmethod - def distance_node(state, array1: Array, array2: Array): + def distance_node(state, node1: Array, node2: Array): # bias + response + activation + aggregation - return jnp.abs(array1[1] - array2[1]) + jnp.abs(array1[2] - array2[2]) + \ - (array1[3] != array2[3]) + (array1[4] != array2[4]) + return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \ + (node1[3] != node2[3]) + (node1[4] != node2[4]) @staticmethod - def distance_conn(state, array1: Array, array2: Array): - return (array1[2] != array2[2]) + jnp.abs(array1[3] - array2[3]) # enable + weight + def distance_conn(state, con1: Array, con2: Array): + return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight @staticmethod - def forward(state, array: Array): - return array + def forward_transform(nodes, conns): + u_conns = unflatten_connections(nodes, conns) + u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns) # enable is false, then the connections is nan + u_conns = u_conns[1:, :] # remove enable attr + conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0) + seqs = topological_sort(nodes, conn_exist) + return seqs, nodes, u_conns + + @staticmethod + def create_forward(config): + config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']] + config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']] + + def act(idx, z): + """ + calculate activation function for each node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + # change idx from float to int + res = jax.lax.switch(idx, config['activation_funcs'], z) + return res + + def agg(idx, z): + """ + calculate activation function for inputs of node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + + def all_nan(): + return 0. + + def not_all_nan(): + return jax.lax.switch(idx, config['aggregation_funcs'], z) + + return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) + + def forward(inputs, transform) -> Array: + """ + jax forward for single input shaped (input_num, ) + nodes, connections are a single genome + + :argument inputs: (input_num, ) + :argument cal_seqs: (N, ) + :argument nodes: (N, 5) + :argument connections: (2, N, N) + + :return (output_num, ) + """ + + cal_seqs, nodes, cons = transform + + input_idx = config['input_idx'] + output_idx = config['output_idx'] + + N = nodes.shape[0] + ini_vals = jnp.full((N,), jnp.nan) + ini_vals = ini_vals.at[input_idx].set(inputs) + + weights = cons[0, :] + + def cond_fun(carry): + values, idx = carry + return (idx < N) & (cal_seqs[idx] != I_INT) + + def body_func(carry): + values, idx = carry + i = cal_seqs[idx] + + def hit(): + ins = values * weights[:, i] + z = agg(nodes[i, 4], ins) # z = agg(ins) + z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias + z = act(nodes[i, 3], z) # z = act(z) + + new_values = values.at[i].set(z) + return new_values + + def miss(): + return values + + # the val of input nodes is obtained by the task, not by calculation + values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit) + + # if jnp.isin(i, input_idx): + # values = miss() + # else: + # values = hit() + + return values, idx + 1 + + # carry = (ini_vals, 0) + # while cond_fun(carry): + # carry = body_func(carry) + # vals, _ = carry + + vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) + + return vals[output_idx] + + return forward @staticmethod def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate): @@ -114,3 +216,7 @@ class NormalGene(BaseGene): ) return val + + + + diff --git a/algorithm/neat/genome/__init__.py b/algorithm/neat/genome/__init__.py index 1b859ae..ec0b7b9 100644 --- a/algorithm/neat/genome/__init__.py +++ b/algorithm/neat/genome/__init__.py @@ -2,3 +2,4 @@ from .basic import initialize_genomes from .mutate import create_mutate from .distance import create_distance from .crossover import crossover +from .graph import topological_sort \ No newline at end of file diff --git a/algorithm/neat/genome/basic.py b/algorithm/neat/genome/basic.py index 915eebc..a71cca8 100644 --- a/algorithm/neat/genome/basic.py +++ b/algorithm/neat/genome/basic.py @@ -37,9 +37,17 @@ def initialize_genomes(state: State, gene_type: Type[BaseGene]): pop_nodes = np.tile(o_nodes, (state.P, 1, 1)) pop_conns = np.tile(o_conns, (state.P, 1, 1)) - return pop_nodes, pop_conns + return jax.device_put([pop_nodes, pop_conns]) +def count(nodes: Array, cons: Array): + """ + Count how many nodes and connections are in the genome. + """ + node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0])) + cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0])) + return node_cnt, cons_cnt + def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]: """ Add a new node to the genome. diff --git a/algorithm/neat/genome/crossover.py b/algorithm/neat/genome/crossover.py index 44ce594..d61448f 100644 --- a/algorithm/neat/genome/crossover.py +++ b/algorithm/neat/genome/crossover.py @@ -4,12 +4,12 @@ import jax from jax import jit, Array, numpy as jnp -def crossover(state, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array): +def crossover(randkey, nodes1: Array, conns1: Array, nodes2: Array, conns2: Array): """ use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) """ - randkey_1, randkey_2, key= jax.random.split(state.randkey, 3) + randkey_1, randkey_2, key= jax.random.split(randkey, 3) # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] @@ -21,11 +21,11 @@ def crossover(state, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array): new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) # crossover connections - con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2] - cons2 = align_array(con_keys1, con_keys2, cons2, True) - new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2)) + con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] + cons2 = align_array(con_keys1, con_keys2, conns2, True) + new_cons = jnp.where(jnp.isnan(conns1) | jnp.isnan(cons2), conns1, crossover_gene(randkey_2, conns1, cons2)) - return state.update(randkey=key), new_nodes, new_cons + return new_nodes, new_cons def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array: diff --git a/algorithm/neat/genome/graph.py b/algorithm/neat/genome/graph.py index 79d1810..8fc9842 100644 --- a/algorithm/neat/genome/graph.py +++ b/algorithm/neat/genome/graph.py @@ -9,12 +9,11 @@ from jax import jit, Array, numpy as jnp from ..utils import fetch_first, I_INT -@jit -def topological_sort(nodes: Array, connections: Array) -> Array: +def topological_sort(nodes: Array, conns: Array) -> Array: """ a jit-able version of topological_sort! that's crazy! :param nodes: nodes array - :param connections: connections array + :param conns: connections array :return: topological sorted sequence Example: @@ -25,12 +24,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array: [3] ]) connections = jnp.array([ - [ - [0, 0, 1, 0], - [0, 0, 1, 1], - [0, 0, 0, 1], - [0, 0, 0, 0] - ], [ [0, 0, 1, 0], [0, 0, 1, 1], @@ -41,8 +34,8 @@ def topological_sort(nodes: Array, connections: Array) -> Array: topological_sort(nodes, connections) -> [0, 1, 2, 3] """ - connections_enable = connections[1, :, :] == 1 # forward function. thus use enable - in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0)) + + in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0)) res = jnp.full(in_degree.shape, I_INT) def cond_fun(carry): @@ -59,7 +52,7 @@ def topological_sort(nodes: Array, connections: Array) -> Array: in_degree_ = in_degree_.at[i].set(-1) # decrease in_degree of all its children - children = connections_enable[i, :] + children = conns[i, :] in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_) return res_, idx_ + 1, in_degree_ @@ -67,7 +60,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array: return res -@jit def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array: """ Check whether a new connection (from_idx -> to_idx) will cause a cycle. diff --git a/algorithm/neat/genome/mutate.py b/algorithm/neat/genome/mutate.py index 98e9d93..c555c39 100644 --- a/algorithm/neat/genome/mutate.py +++ b/algorithm/neat/genome/mutate.py @@ -4,7 +4,7 @@ import jax from jax import Array, numpy as jnp, vmap from algorithm import State -from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx +from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx, count from .graph import check_cycles from ..utils import fetch_random, fetch_first, I_INT, unflatten_connections from ..gene import BaseGene @@ -12,46 +12,51 @@ from ..gene import BaseGene def create_mutate(config: Dict, gene_type: Type[BaseGene]): """ - Create function to mutate the whole population + Create function to mutate a single genome """ - def mutate_structure(state: State, randkey, nodes, cons, new_node_key): - def nothing(*args): - return nodes, cons + def mutate_structure(state: State, randkey, nodes, conns, new_node_key): - def mutate_add_node(key_): - i_key, o_key, idx = choice_connection_key(key_, nodes, cons) + def mutate_add_node(key_, nodes_, conns_): + i_key, o_key, idx = choice_connection_key(key_, nodes_, conns_) + + def nothing(): + return nodes_, conns_ def successful_add_node(): # disable the connection - aux_nodes, aux_cons = nodes, cons + aux_nodes, aux_conns = nodes_, conns_ # set enable to false - aux_cons = aux_cons.at[idx, 2].set(False) + aux_conns = aux_conns.at[idx, 2].set(False) # add a new node - aux_nodes, aux_cons = add_node(aux_nodes, aux_cons, new_node_key, gene_type.new_node_attrs(state)) + aux_nodes, aux_conns = add_node(aux_nodes, aux_conns, new_node_key, gene_type.new_node_attrs(state)) # add two new connections - aux_nodes, aux_cons = add_connection(aux_nodes, aux_cons, i_key, new_node_key, True, + aux_nodes, aux_conns = add_connection(aux_nodes, aux_conns, i_key, new_node_key, True, gene_type.new_conn_attrs(state)) - aux_nodes, aux_cons = add_connection(aux_nodes, aux_cons, new_node_key, o_key, True, + aux_nodes, aux_conns = add_connection(aux_nodes, aux_conns, new_node_key, o_key, True, gene_type.new_conn_attrs(state)) - return aux_nodes, aux_cons + return aux_nodes, aux_conns # if from_idx == I_INT, that means no connection exist, do nothing - return jax.lax.cond(idx == I_INT, nothing, successful_add_node) + new_nodes, new_conns = jax.lax.cond(idx == I_INT, nothing, successful_add_node) - def mutate_delete_node(key_): + return new_nodes, new_conns + + def mutate_delete_node(key_, nodes_, conns_): # TODO: Do we really need to delete a node? # randomly choose a node - key, idx = choice_node_key(key_, nodes, config['input_idx'], config['output_idx'], + key, idx = choice_node_key(key_, nodes_, config['input_idx'], config['output_idx'], allow_input_keys=False, allow_output_keys=False) + def nothing(): + return nodes_, conns_ def successful_delete_node(): # delete the node - aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx) + aux_nodes, aux_cons = delete_node_by_idx(nodes_, conns_, idx) # delete all connections aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None], @@ -61,29 +66,32 @@ def create_mutate(config: Dict, gene_type: Type[BaseGene]): return jax.lax.cond(idx == I_INT, nothing, successful_delete_node) - def mutate_add_conn(key_): + def mutate_add_conn(key_, nodes_, conns_): # randomly choose two nodes k1_, k2_ = jax.random.split(key_, num=2) - i_key, from_idx = choice_node_key(k1_, nodes, config['input_idx'], config['output_idx'], + i_key, from_idx = choice_node_key(k1_, nodes_, config['input_idx'], config['output_idx'], allow_input_keys=True, allow_output_keys=True) - o_key, to_idx = choice_node_key(k2_, nodes, config['input_idx'], config['output_idx'], + o_key, to_idx = choice_node_key(k2_, nodes_, config['input_idx'], config['output_idx'], allow_input_keys=False, allow_output_keys=True) - con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) + con_idx = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key)) + + def nothing(): + return nodes_, conns_ def successful(): - new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, True, gene_type.new_conn_attrs(state)) + new_nodes, new_cons = add_connection(nodes_, conns_, i_key, o_key, True, gene_type.new_conn_attrs(state)) return new_nodes, new_cons def already_exist(): - new_cons = cons.at[con_idx, 2].set(True) - return nodes, new_cons + new_cons = conns_.at[con_idx, 2].set(True) + return nodes_, new_cons is_already_exist = con_idx != I_INT if config['network_type'] == 'feedforward': - u_cons = unflatten_connections(nodes, cons) - is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx) + u_cons = unflatten_connections(nodes_, conns_) + is_cycle = check_cycles(nodes_, u_cons, from_idx, to_idx) choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) return jax.lax.switch(choice, [already_exist, nothing, successful]) @@ -94,23 +102,33 @@ def create_mutate(config: Dict, gene_type: Type[BaseGene]): else: raise ValueError(f"Invalid network type: {config['network_type']}") - def mutate_delete_conn(key_): + def mutate_delete_conn(key_, nodes_, conns_): # randomly choose a connection - i_key, o_key, idx = choice_connection_key(key_, nodes, cons) + i_key, o_key, idx = choice_connection_key(key_, nodes_, conns_) + + def nothing(): + return nodes_, conns_ def successfully_delete_connection(): - return delete_connection_by_idx(nodes, cons, idx) + return delete_connection_by_idx(nodes_, conns_, idx) return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) - k, k1, k2, k3, k4 = jax.random.split(randkey, num=5) + k1, k2, k3, k4 = jax.random.split(randkey, num=4) r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) - nodes, cons = jax.lax.cond(r1 < config['node_add_prob'], mutate_add_node, nothing, k1) - nodes, cons = jax.lax.cond(r2 < config['node_delete_prob'], mutate_delete_node, nothing, k2) - nodes, cons = jax.lax.cond(r3 < config['conn_add_prob'], mutate_add_conn, nothing, k3) - nodes, cons = jax.lax.cond(r4 < config['conn_delete_prob'], mutate_delete_conn, nothing, k4) - return nodes, cons + def no(k, n, c): + return n, c + + nodes, conns = jax.lax.cond(r1 < config['node_add_prob'], mutate_add_node, no, k1, nodes, conns) + + nodes, conns = jax.lax.cond(r2 < config['node_delete_prob'], mutate_delete_node, no, k2, nodes, conns) + + nodes, conns = jax.lax.cond(r3 < config['conn_add_prob'], mutate_add_conn, no, k3, nodes, conns) + + nodes, conns = jax.lax.cond(r4 < config['conn_delete_prob'], mutate_delete_conn, no, k4, nodes, conns) + + return nodes, conns def mutate_values(state: State, randkey, nodes, conns): k1, k2 = jax.random.split(randkey, num=2) @@ -131,32 +149,13 @@ def create_mutate(config: Dict, gene_type: Type[BaseGene]): return new_nodes, new_conns - def mutate(state): - pop_nodes, pop_conns = state.pop_nodes, state.pop_conns - pop_size = pop_nodes.shape[0] + def mutate(state, randkey, nodes, conns, new_node_key): + k1, k2 = jax.random.split(randkey) - new_node_keys = jnp.arange(pop_size) + state.next_node_key - k1, k2, randkey = jax.random.split(state.randkey, num=3) - structure_randkeys = jax.random.split(k1, num=pop_size) - values_randkeys = jax.random.split(k2, num=pop_size) + nodes, conns = mutate_structure(state, k1, nodes, conns, new_node_key) + nodes, conns = mutate_values(state, k2, nodes, conns) - structure_func = jax.vmap(mutate_structure, in_axes=(None, 0, 0, 0, 0)) - pop_nodes, pop_conns = structure_func(state, structure_randkeys, pop_nodes, pop_conns, new_node_keys) - - values_func = jax.vmap(mutate_values, in_axes=(None, 0, 0, 0)) - pop_nodes, pop_conns = values_func(state, values_randkeys, pop_nodes, pop_conns) - - # update next node key - all_nodes_keys = pop_nodes[:, :, 0] - max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) - next_node_key = max_node_key + 1 - - return state.update( - pop_nodes=pop_nodes, - pop_conns=pop_conns, - next_node_key=next_node_key, - randkey=randkey - ) + return nodes, conns return mutate diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py new file mode 100644 index 0000000..044377a --- /dev/null +++ b/algorithm/neat/neat.py @@ -0,0 +1,75 @@ +from typing import Type + +import jax +import jax.numpy as jnp + +from algorithm.state import State +from .gene import BaseGene +from .genome import initialize_genomes, create_mutate, create_distance, crossover +from .population import create_tell + + +class NEAT: + def __init__(self, config, gene_type: Type[BaseGene]): + self.config = config + self.gene_type = gene_type + + self.mutate = jax.jit(create_mutate(config, self.gene_type)) + self.distance = jax.jit(create_distance(config, self.gene_type)) + self.crossover = jax.jit(crossover) + self.pop_forward_transform = jax.jit(jax.vmap(self.gene_type.forward_transform)) + self.forward = jax.jit(self.gene_type.create_forward(config)) + self.tell_func = jax.jit(create_tell(config, self.gene_type)) + + def setup(self, randkey): + + state = State( + P=self.config['pop_size'], + N=self.config['maximum_nodes'], + C=self.config['maximum_connections'], + S=self.config['maximum_species'], + NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes + CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes + input_idx=self.config['input_idx'], + output_idx=self.config['output_idx'], + max_stagnation=self.config['max_stagnation'], + species_elitism=self.config['species_elitism'], + spawn_number_change_rate=self.config['spawn_number_change_rate'], + genome_elitism=self.config['genome_elitism'], + survival_threshold=self.config['survival_threshold'], + compatibility_threshold=self.config['compatibility_threshold'], + ) + + state = self.gene_type.setup(state, self.config) + + randkey = randkey + pop_nodes, pop_conns = initialize_genomes(state, self.gene_type) + species_info = jnp.full((state.S, 4), jnp.nan, + dtype=jnp.float32) # (species_key, best_fitness, last_improved, size) + species_info = species_info.at[0, :].set([0, -jnp.inf, 0, state.P]) + idx2species = jnp.zeros(state.P, dtype=jnp.float32) + center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32) + center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32) + center_nodes = center_nodes.at[0, :, :].set(pop_nodes[0, :, :]) + center_conns = center_conns.at[0, :, :].set(pop_conns[0, :, :]) + generation = 0 + next_node_key = max(*state.input_idx, *state.output_idx) + 2 + next_species_key = 1 + + state = state.update( + randkey=randkey, + pop_nodes=pop_nodes, + pop_conns=pop_conns, + species_info=species_info, + idx2species=idx2species, + center_nodes=center_nodes, + center_conns=center_conns, + generation=generation, + next_node_key=next_node_key, + next_species_key=next_species_key + ) + + return state + + def step(self, state, fitness): + return self.tell_func(state, fitness) diff --git a/algorithm/neat/pipeline.py b/algorithm/neat/pipeline.py new file mode 100644 index 0000000..fba0391 --- /dev/null +++ b/algorithm/neat/pipeline.py @@ -0,0 +1,79 @@ +import time +from typing import Union, Callable + +import jax +from jax import vmap, jit +import numpy as np + +class Pipeline: + """ + Neat algorithm pipeline. + """ + + def __init__(self, config, algorithm): + self.config = config + self.algorithm = algorithm + randkey = jax.random.PRNGKey(config['random_seed']) + self.state = algorithm.setup(randkey) + + self.best_genome = None + self.best_fitness = float('-inf') + self.generation_timestamp = time.time() + + self.evaluate_time = 0 + + self.forward_func = algorithm.gene_type.create_forward(config) + self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None))) + self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0))) + + self.pop_transform_func = jit(vmap(algorithm.gene_type.forward_transform)) + + def ask(self): + pop_transforms = self.pop_transform_func(self.state.pop_nodes, self.state.pop_conns) + return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms) + + def tell(self, fitness): + self.state = self.algorithm.step(self.state, fitness) + from algorithm.neat.genome.basic import count + # print([count(self.state.pop_nodes[i], self.state.pop_conns[i]) for i in range(self.state.P)]) + + + def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): + for _ in range(self.config['generation_limit']): + forward_func = self.ask() + + fitnesses = fitness_func(forward_func) + + if analysis is not None: + if analysis == "default": + self.default_analysis(fitnesses) + else: + assert callable(analysis), f"What the fuck you passed in? A {analysis}?" + analysis(fitnesses) + + if max(fitnesses) >= self.config['fitness_threshold']: + print("Fitness limit reached!") + return self.best_genome + + self.tell(fitnesses) + print("Generation limit reached!") + return self.best_genome + + def default_analysis(self, fitnesses): + max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) + + new_timestamp = time.time() + cost_time = new_timestamp - self.generation_timestamp + self.generation_timestamp = new_timestamp + + max_idx = np.argmax(fitnesses) + if fitnesses[max_idx] > self.best_fitness: + self.best_fitness = fitnesses[max_idx] + self.best_genome = (self.state.pop_nodes[max_idx], self.state.pop_conns[max_idx]) + + member_count = jax.device_get(self.state.species_info[:, 3]) + species_sizes = [int(i) for i in member_count if i > 0] + + print(f"Generation: {self.state.generation}", + f"species: {len(species_sizes)}, {species_sizes}", + f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}") \ No newline at end of file diff --git a/algorithm/neat/population.py b/algorithm/neat/population.py new file mode 100644 index 0000000..4705834 --- /dev/null +++ b/algorithm/neat/population.py @@ -0,0 +1,368 @@ +from typing import Type + +import jax +from jax import numpy as jnp, vmap + +from .utils import rank_elements, fetch_first +from .genome import create_mutate, create_distance, crossover +from .gene import BaseGene + +def create_tell(config, gene_type: Type[BaseGene]): + + mutate = create_mutate(config, gene_type) + distance = create_distance(config, gene_type) + + def update_species(state, randkey, fitness): + # update the fitness of each species + species_fitness = update_species_fitness(state, fitness) + + # stagnation species + state, species_fitness = stagnation(state, species_fitness) + + # sort species_info by their fitness. (push nan to the end) + sort_indices = jnp.argsort(species_fitness)[::-1] + + state = state.update( + species_info=state.species_info[sort_indices], + center_nodes=state.center_nodes[sort_indices], + center_conns=state.center_conns[sort_indices], + ) + + # decide the number of members of each species by their fitness + spawn_number = cal_spawn_numbers(state) + + # crossover info + winner, loser, elite_mask = create_crossover_pair(state, randkey, spawn_number, fitness) + + return state, winner, loser, elite_mask + + + def update_species_fitness(state, fitness): + """ + obtain the fitness of the species by the fitness of each individual. + use max criterion. + """ + + def aux_func(idx): + species_key = state.species_info[idx, 0] + s_fitness = jnp.where(state.idx2species == species_key, fitness, -jnp.inf) + f = jnp.max(s_fitness) + return f + + return vmap(aux_func)(jnp.arange(state.species_info.shape[0])) + + + def stagnation(state, species_fitness): + """ + stagnation species. + those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. + elitism species never stagnation + """ + + def aux_func(idx): + s_fitness = species_fitness[idx] + species_key, best_score, last_update, members_count = state.species_info[idx] + st = (s_fitness <= best_score) & (state.generation - last_update > state.max_stagnation) + last_update = jnp.where(s_fitness > best_score, state.generation, last_update) + best_score = jnp.where(s_fitness > best_score, s_fitness, best_score) + # stagnation condition + return st, jnp.array([species_key, best_score, last_update, members_count]) + + spe_st, species_info = vmap(aux_func)(jnp.arange(species_fitness.shape[0])) + + # elite species will not be stagnation + species_rank = rank_elements(species_fitness) + spe_st = jnp.where(species_rank < state.species_elitism, False, spe_st) # elitism never stagnation + + # set stagnation species to nan + species_info = jnp.where(spe_st[:, None], jnp.nan, species_info) + center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_nodes) + center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_conns) + species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness) + + state = state.update( + species_info=species_info, + center_nodes=center_nodes, + center_conns=center_conns, + ) + + return state, species_fitness + + + def cal_spawn_numbers(state): + """ + decide the number of members of each species by their fitness rank. + the species with higher fitness will have more members + Linear ranking selection + e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] + """ + + is_species_valid = ~jnp.isnan(state.species_info[:, 0]) + valid_species_num = jnp.sum(is_species_valid) + denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 + + rank_score = valid_species_num - jnp.arange(state.species_info.shape[0]) # obtain [3, 2, 1] + spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] + spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 + + target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member + # jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number) + + # Avoid too much variation of numbers in a species + previous_size = state.species_info[:, 3].astype(jnp.int32) + spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate + # jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number) + spawn_number = spawn_number.astype(jnp.int32) + + # spawn_number = target_spawn_number.astype(jnp.int32) + + # must control the sum of spawn_number to be equal to pop_size + error = state.P - jnp.sum(spawn_number) + spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number + + return spawn_number + + + def create_crossover_pair(state, randkey, spawn_number, fitness): + species_size = state.species_info.shape[0] + pop_size = fitness.shape[0] + s_idx = jnp.arange(species_size) + p_idx = jnp.arange(pop_size) + + # def aux_func(key, idx): + def aux_func(key, idx): + members = state.idx2species == state.species_info[idx, 0] + members_num = jnp.sum(members) + + members_fitness = jnp.where(members, fitness, -jnp.inf) + sorted_member_indices = jnp.argsort(members_fitness)[::-1] + + elite_size = state.genome_elitism + survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32) + + select_pro = (p_idx < survive_size) / survive_size + fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro) + + # elite + fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa) + ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma) + elite = jnp.where(p_idx < elite_size, True, False) + return fa, ma, elite + + fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx) + + spawn_number_cum = jnp.cumsum(spawn_number) + + def aux_func(idx): + loc = jnp.argmax(idx < spawn_number_cum) + + # elite genomes are at the beginning of the species + idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx) + return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species] + + part1, part2, elite_mask = vmap(aux_func)(p_idx) + + is_part1_win = fitness[part1] >= fitness[part2] + winner = jnp.where(is_part1_win, part1, part2) + loser = jnp.where(is_part1_win, part2, part1) + + return winner, loser, elite_mask + + def create_next_generation(state, randkey, winner, loser, elite_mask): + # prepare random keys + pop_size = state.pop_nodes.shape[0] + new_node_keys = jnp.arange(pop_size) + state.next_node_key + + k1, k2 = jax.random.split(randkey, 2) + crossover_rand_keys = jax.random.split(k1, pop_size) + mutate_rand_keys = jax.random.split(k2, pop_size) + + # batch crossover + wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner] # winner pop nodes, winner pop connections + lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser] # loser pop nodes, loser pop connections + npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections + + # batch mutation + mutate_func = vmap(mutate, in_axes=(None, 0, 0, 0, 0)) + m_npn, m_npc = mutate_func(state, mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes + + # elitism don't mutate + pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn) + pop_conns = jnp.where(elite_mask[:, None, None], npc, m_npc) + + # update next node key + all_nodes_keys = pop_nodes[:, :, 0] + max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) + next_node_key = max_node_key + 1 + + return state.update( + pop_nodes=pop_nodes, + pop_conns=pop_conns, + next_node_key=next_node_key, + ) + + def speciate(state): + pop_size, species_size = state.pop_nodes.shape[0], state.center_nodes.shape[0] + + # prepare distance functions + o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0, 0)) # one to population + + # idx to specie key + idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species + + # the distance between genomes to its center genomes + o2c_distances = jnp.full((pop_size,), jnp.inf) + + # step 1: find new centers + def cond_func(carry): + i, i2s, cn, cc, o2c = carry + species_key = state.species_info[i, 0] + # jax.debug.print("{}, {}", i, species_key) + return (i < species_size) & (~jnp.isnan(species_key)) # current species is existing + + def body_func(carry): + i, i2s, cn, cc, o2c = carry + distances = o2p_distance_func(state, cn[i], cc[i], state.pop_nodes, state.pop_conns) + + # find the closest one + closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) + # jax.debug.print("closest_idx: {}", closest_idx) + + i2s = i2s.at[closest_idx].set(state.species_info[i, 0]) + cn = cn.at[i].set(state.pop_nodes[closest_idx]) + cc = cc.at[i].set(state.pop_conns[closest_idx]) + + # the genome with closest_idx will become the new center, thus its distance to center is 0. + o2c = o2c.at[closest_idx].set(0) + + return i + 1, i2s, cn, cc, o2c + + _, idx2specie, center_nodes, center_conns, o2c_distances = \ + jax.lax.while_loop(cond_func, body_func, (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances)) + + + # part 2: assign members to each species + def cond_func(carry): + i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key + current_species_existed = ~jnp.isnan(si[i, 0]) + not_all_assigned = jnp.any(jnp.isnan(i2s)) + not_reach_species_upper_bounds = i < species_size + return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned) + + def body_func(carry): + i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_conns + + _, i2s, scn, scc, si, o2c, nsk = jax.lax.cond( + jnp.isnan(si[i, 0]), # whether the current species is existing or not + create_new_species, # if not existing, create a new specie + update_exist_specie, # if existing, update the specie + (i, i2s, cn, cc, si, o2c, nsk) + ) + + return i + 1, i2s, scn, scc, si, o2c, nsk + + def create_new_species(carry): + i, i2s, cn, cc, si, o2c, nsk = carry + + # pick the first one who has not been assigned to any species + idx = fetch_first(jnp.isnan(i2s)) + + # assign it to the new species + # [key, best score, last update generation, members_count] + si = si.at[i].set(jnp.array([nsk, -jnp.inf, state.generation, 0])) + i2s = i2s.at[idx].set(nsk) + o2c = o2c.at[idx].set(0) + + # update center genomes + cn = cn.at[i].set(state.pop_nodes[idx]) + cc = cc.at[i].set(state.pop_conns[idx]) + + i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) + + # when a new species is created, it needs to be updated, thus do not change i + return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key + + def update_exist_specie(carry): + i, i2s, cn, cc, si, o2c, nsk = carry + i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) + + # turn to next species + return i + 1, i2s, cn, cc, si, o2c, nsk + + def speciate_by_threshold(carry): + i, i2s, cn, cc, si, o2c = carry + + # distance between such center genome and ppo genomes + o2p_distance = o2p_distance_func(state, cn[i], cc[i], state.pop_nodes, state.pop_conns) + close_enough_mask = o2p_distance < state.compatibility_threshold + + # when a genome is not assigned or the distance between its current center is bigger than this center + cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c) + # jax.debug.print("{}", o2p_distance) + mask = close_enough_mask & cacheable_mask + + # update species info + i2s = jnp.where(mask, si[i, 0], i2s) + + # update distance between centers + o2c = jnp.where(mask, o2p_distance, o2c) + + return i2s, o2c + + # update idx2specie + _, idx2specie, center_nodes, center_conns, species_info, _, next_species_key = jax.lax.while_loop( + cond_func, + body_func, + (0, idx2specie, center_nodes, center_conns, state.species_info, o2c_distances, state.next_species_key) + ) + + # if there are still some pop genomes not assigned to any species, add them to the last genome + # this condition can only happen when the number of species is reached species upper bounds + idx2specie = jnp.where(jnp.isnan(idx2specie), species_info[-1, 0], idx2specie) + + # update members count + def count_members(idx): + key = species_info[idx, 0] + count = jnp.sum(idx2specie == key) + count = jnp.where(jnp.isnan(key), jnp.nan, count) + return count + + species_member_counts = vmap(count_members)(jnp.arange(species_size)) + species_info = species_info.at[:, 3].set(species_member_counts) + + return state.update( + idx2specie=idx2specie, + center_nodes=center_nodes, + center_conns=center_conns, + species_info=species_info, + next_species_key=next_species_key + ) + + def tell(state, fitness): + """ + Main update function in NEAT. + """ + + k1, k2, randkey = jax.random.split(state.randkey, 3) + + state = state.update( + generation=state.generation + 1, + randkey=randkey + ) + + state, winner, loser, elite_mask = update_species(state, k1, fitness) + + state = create_next_generation(state, k2, winner, loser, elite_mask) + + state = speciate(state) + + return state + + + return tell + + +def argmin_with_mask(arr, mask): + masked_arr = jnp.where(mask, arr, jnp.inf) + min_idx = jnp.argmin(masked_arr) + return min_idx \ No newline at end of file diff --git a/algorithm/neat/utils.py b/algorithm/neat/utils.py index f9b5364..a64735c 100644 --- a/algorithm/neat/utils.py +++ b/algorithm/neat/utils.py @@ -10,24 +10,25 @@ EMPTY_CON = np.full((1, 4), jnp.nan) @jit -def unflatten_connections(nodes: Array, cons: Array): +def unflatten_connections(nodes: Array, conns: Array): """ - transform the (C, 4) connections to (2, N, N) - :param nodes: (N, 5) - :param cons: (C, 4) + transform the (C, CL) connections to (CL-2, N, N) + :param nodes: (N, NL) + :param cons: (C, CL) :return: """ N = nodes.shape[0] + CL = conns.shape[1] node_keys = nodes[:, 0] - i_keys, o_keys = cons[:, 0], cons[:, 1] + i_keys, o_keys = conns[:, 0], conns[:, 1] i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys) o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys) - res = jnp.full((2, N, N), jnp.nan) + res = jnp.full((CL - 2, N, N), jnp.nan) # Is interesting that jax use clip when attach data in array # however, it will do nothing set values in an array - res = res.at[0, i_idxs, o_idxs].set(cons[:, 2]) - res = res.at[1, i_idxs, o_idxs].set(cons[:, 3]) + # put all attributes include enable in res + res = res.at[:, i_idxs, o_idxs].set(conns[:, 2:].T) return res @@ -68,4 +69,4 @@ def rank_elements(array, reverse=False): """ if not reverse: array = -array - return jnp.argsort(jnp.argsort(array)) \ No newline at end of file + return jnp.argsort(jnp.argsort(array)) diff --git a/algorithm/state.py b/algorithm/state.py index e03d99d..b24ff62 100644 --- a/algorithm/state.py +++ b/algorithm/state.py @@ -20,12 +20,10 @@ class State: return f"State ({self.state_dict})" def tree_flatten(self): - print('tree_flatten_cal') children = list(self.state_dict.values()) aux_data = list(self.state_dict.keys()) return children, aux_data @classmethod def tree_unflatten(cls, aux_data, children): - print('tree_unflatten_cal') return cls(**dict(zip(aux_data, children))) diff --git a/examples/xor.ini b/examples/xor.ini new file mode 100644 index 0000000..893fff7 --- /dev/null +++ b/examples/xor.ini @@ -0,0 +1,5 @@ +[basic] +forward_way = "common" + +[population] +fitness_threshold = 4 \ No newline at end of file diff --git a/examples/xor.py b/examples/xor.py new file mode 100644 index 0000000..94a3b8b --- /dev/null +++ b/examples/xor.py @@ -0,0 +1,31 @@ +import jax +import numpy as np + +from algorithm import Configer, NEAT +from algorithm.neat import NormalGene, Pipeline + +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) +xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) + + +def evaluate(forward_func): + """ + :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) + :return: + """ + outs = forward_func(xor_inputs) + outs = jax.device_get(outs) + # print(outs) + fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) + return fitnesses + + +def main(): + config = Configer.load_config("xor.ini") + algorithm = NEAT(config, NormalGene) + pipeline = Pipeline(config, algorithm) + pipeline.auto_run(evaluate) + + +if __name__ == '__main__': + main() diff --git a/examples/xor_test.py b/examples/xor_test.py index bb2931a..9ca0c40 100644 --- a/examples/xor_test.py +++ b/examples/xor_test.py @@ -1,17 +1,32 @@ import jax +import numpy as np from algorithm.config import Configer -from algorithm.neat import NEAT +from algorithm.neat import NEAT, NormalGene, Pipeline +from algorithm.neat.genome import create_mutate + +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) + +def single_genome(func, nodes, conns): + t = NormalGene.forward_transform(nodes, conns) + out1 = func(xor_inputs[0], t) + out2 = func(xor_inputs[1], t) + out3 = func(xor_inputs[2], t) + out4 = func(xor_inputs[3], t) + print(out1, out2, out3, out4) if __name__ == '__main__': config = Configer.load_config() - neat = NEAT(config) + neat = NEAT(config, NormalGene) randkey = jax.random.PRNGKey(42) state = neat.setup(randkey) - state = neat.mutate(state) - print(state) - pop_nodes, pop_conns = state.pop_nodes, state.pop_conns - print(neat.distance(state, pop_nodes[0], pop_conns[0], pop_nodes[1], pop_conns[1])) - print(neat.crossover(state, pop_nodes[0], pop_conns[0], pop_nodes[1], pop_conns[1])) + forward_func = NormalGene.create_forward(config) + mutate_func = create_mutate(config, NormalGene) + + + nodes, conns = state.pop_nodes[0], state.pop_conns[0] + single_genome(forward_func, nodes, conns) + nodes, conns = mutate_func(state, randkey, nodes, conns, 10000) + single_genome(forward_func, nodes, conns) diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit/__init__.py b/test/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py new file mode 100644 index 0000000..b66c1aa --- /dev/null +++ b/test/unit/test_utils.py @@ -0,0 +1,36 @@ +import pytest +import jax + +from algorithm.neat.utils import * + + +def test_unflatten(): + nodes = jnp.array([ + [0, 0, 0, 0], + [1, 1, 1, 1], + [2, 2, 2, 2], + [3, 3, 3, 3], + [jnp.nan, jnp.nan, jnp.nan, jnp.nan] + ]) + + + conns = jnp.array([ + [0, 1, True, 0.1, 0.11], + [0, 2, False, 0.2, 0.22], + [1, 2, True, 0.3, 0.33], + [1, 3, False, 0.4, 0.44], + ]) + + res = unflatten_connections(nodes, conns) + + assert jnp.all(res[:, 0, 1] == jnp.array([True, 0.1, 0.11])) + assert jnp.all(res[:, 0, 2] == jnp.array([False, 0.2, 0.22])) + assert jnp.all(res[:, 1, 2] == jnp.array([True, 0.3, 0.33])) + assert jnp.all(res[:, 1, 3] == jnp.array([False, 0.4, 0.44])) + + # Create a mask that excludes the indices we've already checked + mask = jnp.ones(res.shape, dtype=bool) + mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False) + + # Ensure all other places are jnp.nan + assert jnp.all(jnp.isnan(res[mask])) \ No newline at end of file