From 1499e062fef34c240697e13090a5465df62e5474 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 2 Aug 2023 13:26:01 +0800 Subject: [PATCH] remove create_func.... --- algorithm/__init__.py | 3 +- algorithm/hyper_neat/__init__.py | 2 - algorithm/hyper_neat/hyper_neat.py | 122 ------------ algorithm/hyper_neat/substrate/__init__.py | 2 - algorithm/hyper_neat/substrate/normal.py | 25 --- algorithm/hyper_neat/substrate/tools.py | 50 ----- algorithm/neat/__init__.py | 1 - algorithm/neat/ga/__init__.py | 3 +- algorithm/neat/ga/crossover.py | 4 +- algorithm/neat/ga/mutate.py | 221 ++++++++++----------- algorithm/neat/ga/operation.py | 40 ++++ algorithm/neat/gene/__init__.py | 1 - algorithm/neat/gene/normal.py | 168 ++++++---------- algorithm/neat/gene/recurrent.py | 84 -------- algorithm/neat/neat.py | 117 +++-------- algorithm/neat/species/__init__.py | 2 +- algorithm/neat/species/distance.py | 106 +++++----- algorithm/neat/species/operations.py | 221 ++++++++++----------- algorithm/neat/species/species_info.py | 2 +- config/__init__.py | 3 +- config/config.py | 1 + config/default_config.ini | 76 ------- core/__init__.py | 2 +- core/algorithm.py | 44 +++- core/gene.py | 51 ++--- core/genome.py | 1 - examples/test.py | 24 +++ examples/xor.py | 11 +- examples/xor_hyperNEAT.py | 49 ----- examples/xor_recurrent.py | 39 ---- pipeline.py | 14 +- utils/__init__.py | 37 +++- utils/activation.py | 30 +-- utils/aggregation.py | 24 ++- 34 files changed, 558 insertions(+), 1022 deletions(-) delete mode 100644 algorithm/hyper_neat/__init__.py delete mode 100644 algorithm/hyper_neat/hyper_neat.py delete mode 100644 algorithm/hyper_neat/substrate/__init__.py delete mode 100644 algorithm/hyper_neat/substrate/normal.py delete mode 100644 algorithm/hyper_neat/substrate/tools.py create mode 100644 algorithm/neat/ga/operation.py delete mode 100644 algorithm/neat/gene/recurrent.py delete mode 100644 config/default_config.ini create mode 100644 examples/test.py delete mode 100644 examples/xor_hyperNEAT.py delete mode 100644 examples/xor_recurrent.py diff --git a/algorithm/__init__.py b/algorithm/__init__.py index 68e966c..6fe56c9 100644 --- a/algorithm/__init__.py +++ b/algorithm/__init__.py @@ -1,2 +1 @@ -from .neat import * -from .hyper_neat import * \ No newline at end of file +from .neat import NEAT diff --git a/algorithm/hyper_neat/__init__.py b/algorithm/hyper_neat/__init__.py deleted file mode 100644 index 4227bc4..0000000 --- a/algorithm/hyper_neat/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .hyper_neat import HyperNEAT -from .substrate import NormalSubstrate, NormalSubstrateConfig \ No newline at end of file diff --git a/algorithm/hyper_neat/hyper_neat.py b/algorithm/hyper_neat/hyper_neat.py deleted file mode 100644 index 6008489..0000000 --- a/algorithm/hyper_neat/hyper_neat.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import Type - -import jax -from jax import numpy as jnp, Array, vmap -import numpy as np - -from config import Config, HyperNeatConfig -from core import Algorithm, Substrate, State, Genome -from utils import Activation, Aggregation -from algorithm.neat import NEAT -from .substrate import analysis_substrate - -class HyperNEAT(Algorithm): - - def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]): - self.config = config - self.neat = neat - self.substrate = substrate - - self.forward_func = None - - def setup(self, randkey, state=State()): - neat_key, randkey = jax.random.split(randkey) - state = state.update( - below_threshold=self.config.hyper_neat.below_threshold, - max_weight=self.config.hyper_neat.max_weight, - ) - state = self.neat.setup(neat_key, state) - state = self.substrate.setup(self.config.substrate, state) - - assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias - assert self.config.hyper_neat.outputs == state.output_coors.shape[0] - - h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state) - h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis] - h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32) - h_conns[:, 0:2] = correspond_keys - - state = state.update( - h_input_idx=h_input_idx, - h_output_idx=h_output_idx, - h_hidden_idx=h_hidden_idx, - h_nodes=h_nodes, - h_conns=h_conns, - query_coors=query_coors, - ) - - self.forward_func = HyperNEATGene.create_forward(self.config.hyper_neat, state) - - return state - def ask(self, state: State): - return state.pop_genomes - - def tell(self, state: State, fitness): - return self.neat.tell(state, fitness) - - def forward(self, inputs: Array, transformed: Array): - return self.forward_func(inputs, transformed) - - def forward_transform(self, state: State, genome: Genome): - t = self.neat.forward_transform(state, genome) - query_res = vmap(self.neat.forward, in_axes=(0, None))(state.query_coors, t) - - # mute the connection with weight below threshold - query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res) - - # make query res in range [-max_weight, max_weight] - query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res) - query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res) - query_res = query_res / (1 - state.below_threshold) * state.max_weight - - h_conns = state.h_conns.at[:, 2:].set(query_res) - return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns)) - - -class HyperNEATGene: - node_attrs = [] # no node attributes - conn_attrs = ['weight'] - - @staticmethod - def forward_transform(genome: Genome): - N = genome.nodes.shape[0] - u_conns = jnp.zeros((N, N), dtype=jnp.float32) - - in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32) - out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32) - weights = genome.conns[:, 2] - - u_conns = u_conns.at[in_keys, out_keys].set(weights) - return genome.nodes, u_conns - - @staticmethod - def create_forward(config: HyperNeatConfig, state: State): - - act = Activation.name2func[config.activation] - agg = Aggregation.name2func[config.aggregation] - - batch_act, batch_agg = jax.vmap(act), jax.vmap(agg) - - def forward(inputs, transform): - - inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0) - nodes, weights = transform - - input_idx = state.h_input_idx - output_idx = state.h_output_idx - - N = nodes.shape[0] - vals = jnp.full((N,), 0.) - - def body_func(i, values): - values = values.at[input_idx].set(inputs_with_bias) - nodes_ins = values * weights.T - values = batch_agg(nodes_ins) # z = agg(ins) - values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias - values = batch_act(values) # z = act(z) - return values - - vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals) - return vals[output_idx] - - return forward \ No newline at end of file diff --git a/algorithm/hyper_neat/substrate/__init__.py b/algorithm/hyper_neat/substrate/__init__.py deleted file mode 100644 index a0378ba..0000000 --- a/algorithm/hyper_neat/substrate/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .normal import NormalSubstrate, NormalSubstrateConfig -from .tools import analysis_substrate diff --git a/algorithm/hyper_neat/substrate/normal.py b/algorithm/hyper_neat/substrate/normal.py deleted file mode 100644 index e16eedd..0000000 --- a/algorithm/hyper_neat/substrate/normal.py +++ /dev/null @@ -1,25 +0,0 @@ -from dataclasses import dataclass -from typing import Tuple - -import numpy as np - -from core import Substrate, State -from config import SubstrateConfig - - -@dataclass(frozen=True) -class NormalSubstrateConfig(SubstrateConfig): - input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1)) - hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0)) - output_coors: Tuple[Tuple[float]] = ((0, 1), ) - - -class NormalSubstrate(Substrate): - - @staticmethod - def setup(config: NormalSubstrateConfig, state: State = State()): - return state.update( - input_coors=np.asarray(config.input_coors, dtype=np.float32), - output_coors=np.asarray(config.output_coors, dtype=np.float32), - hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32), - ) diff --git a/algorithm/hyper_neat/substrate/tools.py b/algorithm/hyper_neat/substrate/tools.py deleted file mode 100644 index 21413be..0000000 --- a/algorithm/hyper_neat/substrate/tools.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Type - -import numpy as np - -def analysis_substrate(state): - cd = state.input_coors.shape[1] # coordinate dimensions - si = state.input_coors.shape[0] # input coordinate size - so = state.output_coors.shape[0] # output coordinate size - sh = state.hidden_coors.shape[0] # hidden coordinate size - - input_idx = np.arange(si) - output_idx = np.arange(si, si + so) - hidden_idx = np.arange(si + so, si + so + sh) - - total_conns = si * sh + sh * sh + sh * so - query_coors = np.zeros((total_conns, cd * 2)) - correspond_keys = np.zeros((total_conns, 2)) - - # connect input to hidden - aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors) - query_coors[0: si * sh, :] = aux_coors - correspond_keys[0: si * sh, :] = aux_keys - - # connect hidden to hidden - aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors) - query_coors[si * sh: si * sh + sh * sh, :] = aux_coors - correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys - - # connect hidden to output - aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors) - query_coors[si * sh + sh * sh:, :] = aux_coors - correspond_keys[si * sh + sh * sh:, :] = aux_keys - - return input_idx, output_idx, hidden_idx, query_coors, correspond_keys - - -def cartesian_product(keys1, keys2, coors1, coors2): - len1 = keys1.shape[0] - len2 = keys2.shape[0] - - repeated_coors1 = np.repeat(coors1, len2, axis=0) - repeated_keys1 = np.repeat(keys1, len2) - - tiled_coors2 = np.tile(coors2, (len1, 1)) - tiled_keys2 = np.tile(keys2, len1) - - new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1) - correspond_keys = np.column_stack((repeated_keys1, tiled_keys2)) - - return new_coors, correspond_keys \ No newline at end of file diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py index d6bb53c..6fe56c9 100644 --- a/algorithm/neat/__init__.py +++ b/algorithm/neat/__init__.py @@ -1,2 +1 @@ from .neat import NEAT -from .gene import * diff --git a/algorithm/neat/ga/__init__.py b/algorithm/neat/ga/__init__.py index 4fb8380..cbb0157 100644 --- a/algorithm/neat/ga/__init__.py +++ b/algorithm/neat/ga/__init__.py @@ -1,2 +1,3 @@ from .crossover import crossover -from .mutate import create_mutate +from .mutate import mutate +from .operation import create_next_generation diff --git a/algorithm/neat/ga/crossover.py b/algorithm/neat/ga/crossover.py index 80810f0..88a40bb 100644 --- a/algorithm/neat/ga/crossover.py +++ b/algorithm/neat/ga/crossover.py @@ -9,7 +9,7 @@ def crossover(randkey, genome1: Genome, genome2: Genome): use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) """ - randkey_1, randkey_2, key= jax.random.split(randkey, 3) + randkey_1, randkey_2, key = jax.random.split(randkey, 3) # crossover nodes keys1, keys2 = genome1.nodes[:, 0], genome2.nodes[:, 0] @@ -67,4 +67,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: only gene with the same key will be crossover, thus don't need to consider change key """ r = jax.random.uniform(rand_key, shape=g1.shape) - return jnp.where(r > 0.5, g1, g2) \ No newline at end of file + return jnp.where(r > 0.5, g1, g2) diff --git a/algorithm/neat/ga/mutate.py b/algorithm/neat/ga/mutate.py index 0a18141..499d4c2 100644 --- a/algorithm/neat/ga/mutate.py +++ b/algorithm/neat/ga/mutate.py @@ -1,4 +1,4 @@ -from typing import Tuple, Type +from typing import Tuple import jax from jax import Array, numpy as jnp, vmap @@ -8,144 +8,141 @@ from core import State, Gene, Genome from utils import check_cycles, fetch_random, fetch_first, I_INT, unflatten_conns -def create_mutate(config: NeatConfig, gene_type: Type[Gene]): +def mutate(config: NeatConfig, gene: Gene, state: State, randkey, genome: Genome, new_node_key): """ - Create function to mutate a single genome + Mutate a population of genomes """ + k1, k2 = jax.random.split(randkey) - def mutate_structure(state: State, randkey, genome: Genome, new_node_key): + genome = mutate_structure(config, gene, state, k1, genome, new_node_key) + genome = mutate_values(gene, state, randkey, genome) - def mutate_add_node(key_, genome_: Genome): - i_key, o_key, idx = choice_connection_key(key_, genome_.conns) - - def nothing(): - return genome_ - - def successful_add_node(): - # disable the connection - new_genome = genome_.update_conns(genome_.conns.at[idx, 2].set(False)) - - # add a new node - new_genome = new_genome.add_node(new_node_key, gene_type.new_node_attrs(state)) - - # add two new connections - new_genome = new_genome.add_conn(i_key, new_node_key, True, gene_type.new_conn_attrs(state)) - new_genome = new_genome.add_conn(new_node_key, o_key, True, gene_type.new_conn_attrs(state)) - - return new_genome - - # if from_idx == I_INT, that means no connection exist, do nothing - return jax.lax.cond(idx == I_INT, nothing, successful_add_node) - - def mutate_delete_node(key_, genome_: Genome): - # TODO: Do we really need to delete a node? - # randomly choose a node - key, idx = choice_node_key(key_, genome_.nodes, state.input_idx, state.output_idx, - allow_input_keys=False, allow_output_keys=False) - def nothing(): - return genome_ - - def successful_delete_node(): - # delete the node - new_genome = genome_.delete_node_by_pos(idx) - - # delete all connections - new_conns = jnp.where(((new_genome.conns[:, 0] == key) | (new_genome.conns[:, 1] == key))[:, None], - jnp.nan, new_genome.conns) - - return new_genome.update_conns(new_conns) - - return jax.lax.cond(idx == I_INT, nothing, successful_delete_node) - - def mutate_add_conn(key_, genome_: Genome): - # randomly choose two nodes - k1_, k2_ = jax.random.split(key_, num=2) - i_key, from_idx = choice_node_key(k1_, genome_.nodes, state.input_idx, state.output_idx, - allow_input_keys=True, allow_output_keys=True) - o_key, to_idx = choice_node_key(k2_, genome_.nodes, state.input_idx, state.output_idx, - allow_input_keys=False, allow_output_keys=True) - - conn_pos = fetch_first((genome_.conns[:, 0] == i_key) & (genome_.conns[:, 1] == o_key)) - - def nothing(): - return genome_ - - def successful(): - return genome_.add_conn(i_key, o_key, True, gene_type.new_conn_attrs(state)) - - def already_exist(): - return genome_.update_conns(genome_.conns.at[conn_pos, 2].set(True)) + return genome - is_already_exist = conn_pos != I_INT +def mutate_structure(config: NeatConfig, gene: Gene, state: State, randkey, genome: Genome, new_node_key): + def mutate_add_node(key_, genome_: Genome): + i_key, o_key, idx = choice_connection_key(key_, genome_.conns) - if config.network_type == 'feedforward': - u_cons = unflatten_conns(genome_.nodes, genome_.conns) - cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False) - is_cycle = check_cycles(genome_.nodes, cons_exist, from_idx, to_idx) + def nothing(): + return genome_ - choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) - return jax.lax.switch(choice, [already_exist, nothing, successful]) + def successful_add_node(): + # disable the connection + new_genome = genome_.update_conns(genome_.conns.at[idx, 2].set(False)) - elif config.network_type == 'recurrent': - return jax.lax.cond(is_already_exist, already_exist, successful) + # add a new node + new_genome = new_genome.add_node(new_node_key, gene.new_node_attrs(state)) - else: - raise ValueError(f"Invalid network type: {config.network_type}") + # add two new connections + new_genome = new_genome.add_conn(i_key, new_node_key, True, gene.new_conn_attrs(state)) + new_genome = new_genome.add_conn(new_node_key, o_key, True, gene.new_conn_attrs(state)) - def mutate_delete_conn(key_, genome_: Genome): - # randomly choose a connection - i_key, o_key, idx = choice_connection_key(key_, genome_.conns) + return new_genome - def nothing(): - return genome_ + # if from_idx == I_INT, that means no connection exist, do nothing + return jax.lax.cond(idx == I_INT, nothing, successful_add_node) - def successfully_delete_connection(): - return genome_.delete_conn_by_pos(idx) + def mutate_delete_node(key_, genome_: Genome): + # TODO: Do we really need to delete a node? + # randomly choose a node + key, idx = choice_node_key(key_, genome_.nodes, state.input_idx, state.output_idx, + allow_input_keys=False, allow_output_keys=False) - return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) + def nothing(): + return genome_ - k1, k2, k3, k4 = jax.random.split(randkey, num=4) - r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) + def successful_delete_node(): + # delete the node + new_genome = genome_.delete_node_by_pos(idx) - def no(k, g): - return g + # delete all connections + new_conns = jnp.where(((new_genome.conns[:, 0] == key) | (new_genome.conns[:, 1] == key))[:, None], + jnp.nan, new_genome.conns) - genome = jax.lax.cond(r1 < config.node_add, mutate_add_node, no, k1, genome) - genome = jax.lax.cond(r2 < config.node_delete, mutate_delete_node, no, k2, genome) - genome = jax.lax.cond(r3 < config.conn_add, mutate_add_conn, no, k3, genome) - genome = jax.lax.cond(r4 < config.conn_delete, mutate_delete_conn, no, k4, genome) + return new_genome.update_conns(new_conns) - return genome + return jax.lax.cond(idx == I_INT, nothing, successful_delete_node) - def mutate_values(state: State, randkey, genome: Genome): - k1, k2 = jax.random.split(randkey, num=2) - nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0]) - conns_keys = jax.random.split(k2, num=genome.conns.shape[0]) + def mutate_add_conn(key_, genome_: Genome): + # randomly choose two nodes + k1_, k2_ = jax.random.split(key_, num=2) + i_key, from_idx = choice_node_key(k1_, genome_.nodes, state.input_idx, state.output_idx, + allow_input_keys=True, allow_output_keys=True) + o_key, to_idx = choice_node_key(k2_, genome_.nodes, state.input_idx, state.output_idx, + allow_input_keys=False, allow_output_keys=True) - nodes_attrs, conns_attrs = genome.nodes[:, 1:], genome.conns[:, 3:] + conn_pos = fetch_first((genome_.conns[:, 0] == i_key) & (genome_.conns[:, 1] == o_key)) - new_nodes_attrs = vmap(gene_type.mutate_node, in_axes=(None, 0, 0))(state, nodes_attrs, nodes_keys) - new_conns_attrs = vmap(gene_type.mutate_conn, in_axes=(None, 0, 0))(state, conns_attrs, conns_keys) + def nothing(): + return genome_ - # nan nodes not changed - new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs) - new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs) + def successful(): + return genome_.add_conn(i_key, o_key, True, gene.new_conn_attrs(state)) - new_nodes = genome.nodes.at[:, 1:].set(new_nodes_attrs) - new_conns = genome.conns.at[:, 3:].set(new_conns_attrs) + def already_exist(): + return genome_.update_conns(genome_.conns.at[conn_pos, 2].set(True)) - return genome.update(new_nodes, new_conns) + is_already_exist = conn_pos != I_INT - def mutate(state, randkey, genome: Genome, new_node_key): - k1, k2 = jax.random.split(randkey) + if config.network_type == 'feedforward': + u_cons = unflatten_conns(genome_.nodes, genome_.conns) + cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False) + is_cycle = check_cycles(genome_.nodes, cons_exist, from_idx, to_idx) - genome = mutate_structure(state, k1, genome, new_node_key) - genome = mutate_values(state, k2, genome) + choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) + return jax.lax.switch(choice, [already_exist, nothing, successful]) - return genome + elif config.network_type == 'recurrent': + return jax.lax.cond(is_already_exist, already_exist, successful) - return mutate + else: + raise ValueError(f"Invalid network type: {config.network_type}") + + def mutate_delete_conn(key_, genome_: Genome): + # randomly choose a connection + i_key, o_key, idx = choice_connection_key(key_, genome_.conns) + + def nothing(): + return genome_ + + def successfully_delete_connection(): + return genome_.delete_conn_by_pos(idx) + + return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) + + k1, k2, k3, k4 = jax.random.split(randkey, num=4) + r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) + + def no(k, g): + return g + + genome = jax.lax.cond(r1 < config.node_add, mutate_add_node, no, k1, genome) + genome = jax.lax.cond(r2 < config.node_delete, mutate_delete_node, no, k2, genome) + genome = jax.lax.cond(r3 < config.conn_add, mutate_add_conn, no, k3, genome) + genome = jax.lax.cond(r4 < config.conn_delete, mutate_delete_conn, no, k4, genome) + + return genome + + +def mutate_values(gene: Gene, state: State, randkey, genome: Genome): + k1, k2 = jax.random.split(randkey, num=2) + nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0]) + conns_keys = jax.random.split(k2, num=genome.conns.shape[0]) + + nodes_attrs, conns_attrs = genome.nodes[:, 1:], genome.conns[:, 3:] + + new_nodes_attrs = vmap(gene.mutate_node, in_axes=(None, 0, 0))(state, nodes_keys, nodes_attrs) + new_conns_attrs = vmap(gene.mutate_conn, in_axes=(None, 0, 0))(state, conns_keys, conns_attrs) + + # nan nodes not changed + new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs) + new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs) + + new_nodes = genome.nodes.at[:, 1:].set(new_nodes_attrs) + new_conns = genome.conns.at[:, 3:].set(new_conns_attrs) + + return genome.update(new_nodes, new_conns) def choice_node_key(rand_key: Array, nodes: Array, @@ -186,4 +183,4 @@ def choice_connection_key(rand_key: Array, conns: Array): i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan) o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan) - return i_key, o_key, idx \ No newline at end of file + return i_key, o_key, idx diff --git a/algorithm/neat/ga/operation.py b/algorithm/neat/ga/operation.py new file mode 100644 index 0000000..d2aff65 --- /dev/null +++ b/algorithm/neat/ga/operation.py @@ -0,0 +1,40 @@ +import jax +from jax import numpy as jnp, vmap + +from config import NeatConfig +from core import Genome, State, Gene +from .mutate import mutate +from .crossover import crossover + + +def create_next_generation(config: NeatConfig, gene: Gene, state: State, randkey, winner, loser, elite_mask): + # prepare random keys + pop_size = state.idx2species.shape[0] + new_node_keys = jnp.arange(pop_size) + state.next_node_key + + k1, k2 = jax.random.split(randkey, 2) + crossover_rand_keys = jax.random.split(k1, pop_size) + mutate_rand_keys = jax.random.split(k2, pop_size) + + # batch crossover + wpn, wpc = state.pop_genomes.nodes[winner], state.pop_genomes.conns[winner] + lpn, lpc = state.pop_genomes.nodes[loser], state.pop_genomes.conns[loser] + n_genomes = vmap(crossover)(crossover_rand_keys, Genome(wpn, wpc), Genome(lpn, lpc)) + + # batch mutation + mutate_func = vmap(mutate, in_axes=(None, None, None, 0, 0, 0)) + m_n_genomes = mutate_func(config, gene, state, mutate_rand_keys, n_genomes, new_node_keys) # mutate_new_pop_nodes + + # elitism don't mutate + pop_nodes = jnp.where(elite_mask[:, None, None], n_genomes.nodes, m_n_genomes.nodes) + pop_conns = jnp.where(elite_mask[:, None, None], n_genomes.conns, m_n_genomes.conns) + + # update next node key + all_nodes_keys = pop_nodes[:, :, 0] + max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) + next_node_key = max_node_key + 1 + + return state.update( + pop_genomes=Genome(pop_nodes, pop_conns), + next_node_key=next_node_key, + ) diff --git a/algorithm/neat/gene/__init__.py b/algorithm/neat/gene/__init__.py index 9c3c16d..02af6ce 100644 --- a/algorithm/neat/gene/__init__.py +++ b/algorithm/neat/gene/__init__.py @@ -1,2 +1 @@ from .normal import NormalGene, NormalGeneConfig -from .recurrent import RecurrentGene, RecurrentGeneConfig diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py index b07d28c..449eda2 100644 --- a/algorithm/neat/gene/normal.py +++ b/algorithm/neat/gene/normal.py @@ -6,7 +6,7 @@ from jax import Array, numpy as jnp from config import GeneConfig from core import Gene, Genome, State -from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT +from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg @dataclass(frozen=True) @@ -66,48 +66,51 @@ class NormalGene(Gene): node_attrs = ['bias', 'response', 'aggregation', 'activation'] conn_attrs = ['weight'] - @staticmethod - def setup(config: NormalGeneConfig, state: State = State()): + def __init__(self, config: NormalGeneConfig): + self.config = config + self.act_funcs = [Activation.name2func[name] for name in config.activation_options] + self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options] + def setup(self, state: State = State()): return state.update( - bias_init_mean=config.bias_init_mean, - bias_init_std=config.bias_init_std, - bias_mutate_power=config.bias_mutate_power, - bias_mutate_rate=config.bias_mutate_rate, - bias_replace_rate=config.bias_replace_rate, + bias_init_mean=self.config.bias_init_mean, + bias_init_std=self.config.bias_init_std, + bias_mutate_power=self.config.bias_mutate_power, + bias_mutate_rate=self.config.bias_mutate_rate, + bias_replace_rate=self.config.bias_replace_rate, - response_init_mean=config.response_init_mean, - response_init_std=config.response_init_std, - response_mutate_power=config.response_mutate_power, - response_mutate_rate=config.response_mutate_rate, - response_replace_rate=config.response_replace_rate, + response_init_mean=self.config.response_init_mean, + response_init_std=self.config.response_init_std, + response_mutate_power=self.config.response_mutate_power, + response_mutate_rate=self.config.response_mutate_rate, + response_replace_rate=self.config.response_replace_rate, - activation_replace_rate=config.activation_replace_rate, + activation_replace_rate=self.config.activation_replace_rate, activation_default=0, - activation_options=jnp.arange(len(config.activation_options)), + activation_options=jnp.arange(len(self.config.activation_options)), - aggregation_replace_rate=config.aggregation_replace_rate, + aggregation_replace_rate=self.config.aggregation_replace_rate, aggregation_default=0, - aggregation_options=jnp.arange(len(config.aggregation_options)), + aggregation_options=jnp.arange(len(self.config.aggregation_options)), - weight_init_mean=config.weight_init_mean, - weight_init_std=config.weight_init_std, - weight_mutate_power=config.weight_mutate_power, - weight_mutate_rate=config.weight_mutate_rate, - weight_replace_rate=config.weight_replace_rate, + weight_init_mean=self.config.weight_init_mean, + weight_init_std=self.config.weight_init_std, + weight_mutate_power=self.config.weight_mutate_power, + weight_mutate_rate=self.config.weight_mutate_rate, + weight_replace_rate=self.config.weight_replace_rate, ) - @staticmethod - def new_node_attrs(state): + def update(self, state): + pass + + def new_node_attrs(self, state): return jnp.array([state.bias_init_mean, state.response_init_mean, state.activation_default, state.aggregation_default]) - @staticmethod - def new_conn_attrs(state): + def new_conn_attrs(self, state): return jnp.array([state.weight_init_mean]) - @staticmethod - def mutate_node(state, attrs: Array, key): + def mutate_node(self, state, key, attrs: Array): k1, k2, k3, k4 = jax.random.split(key, num=4) bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std, @@ -120,26 +123,22 @@ class NormalGene(Gene): return jnp.array([bias, res, act, agg]) - @staticmethod - def mutate_conn(state, attrs: Array, key): + def mutate_conn(self, state, key, attrs: Array): weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std, state.weight_mutate_power, state.weight_mutate_rate, state.weight_replace_rate) return jnp.array([weight]) - @staticmethod - def distance_node(state, node1: Array, node2: Array): + def distance_node(self, state, node1: Array, node2: Array): # bias + response + activation + aggregation return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \ (node1[3] != node2[3]) + (node1[4] != node2[4]) - @staticmethod - def distance_conn(state, con1: Array, con2: Array): + def distance_conn(self, state, con1: Array, con2: Array): return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight - @staticmethod - def forward_transform(state: State, genome: Genome): + def forward_transform(self, state: State, genome: Genome): u_conns = unflatten_conns(genome.nodes, genome.conns) conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) @@ -149,87 +148,46 @@ class NormalGene(Gene): return seqs, genome.nodes, u_conns - @staticmethod - def create_forward(state: State, config: NormalGeneConfig): - activation_funcs = [Activation.name2func[name] for name in config.activation_options] - aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options] + def forward(self, state: State, inputs, transformed): + cal_seqs, nodes, cons = transformed - def act(idx, z): - """ - calculate activation function for each node - """ - idx = jnp.asarray(idx, dtype=jnp.int32) - # change idx from float to int - res = jax.lax.switch(idx, activation_funcs, z) - return res + input_idx = state.input_idx + output_idx = state.output_idx - def agg(idx, z): - """ - calculate activation function for inputs of node - """ - idx = jnp.asarray(idx, dtype=jnp.int32) + N = nodes.shape[0] + ini_vals = jnp.full((N,), jnp.nan) + ini_vals = ini_vals.at[input_idx].set(inputs) - def all_nan(): - return 0. + weights = cons[0, :] - def not_all_nan(): - return jax.lax.switch(idx, aggregation_funcs, z) + def cond_fun(carry): + values, idx = carry + return (idx < N) & (cal_seqs[idx] != I_INT) - return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) + def body_func(carry): + values, idx = carry + i = cal_seqs[idx] - def forward(inputs, transformed) -> Array: - """ - forward for single input shaped (input_num, ) + def hit(): + ins = values * weights[:, i] + z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins) + z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias + z = act(nodes[i, 3], z, self.act_funcs) # z = act(z) - :argument inputs: (input_num, ) - :argument cal_seqs: (N, ) - :argument nodes: (N, 5) - :argument connections: (2, N, N) + new_values = values.at[i].set(z) + return new_values - :return (output_num, ) - """ + def miss(): + return values - cal_seqs, nodes, cons = transformed + # the val of input nodes is obtained by the task, not by calculation + values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit) - input_idx = state.input_idx - output_idx = state.output_idx + return values, idx + 1 - N = nodes.shape[0] - ini_vals = jnp.full((N,), jnp.nan) - ini_vals = ini_vals.at[input_idx].set(inputs) + vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) - weights = cons[0, :] - - def cond_fun(carry): - values, idx = carry - return (idx < N) & (cal_seqs[idx] != I_INT) - - def body_func(carry): - values, idx = carry - i = cal_seqs[idx] - - def hit(): - ins = values * weights[:, i] - z = agg(nodes[i, 4], ins) # z = agg(ins) - z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias - z = act(nodes[i, 3], z) # z = act(z) - - new_values = values.at[i].set(z) - return new_values - - def miss(): - return values - - # the val of input nodes is obtained by the task, not by calculation - values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit) - - return values, idx + 1 - - vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) - - return vals[output_idx] - - return forward + return vals[output_idx] @staticmethod def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate): diff --git a/algorithm/neat/gene/recurrent.py b/algorithm/neat/gene/recurrent.py deleted file mode 100644 index 1d3942b..0000000 --- a/algorithm/neat/gene/recurrent.py +++ /dev/null @@ -1,84 +0,0 @@ -from dataclasses import dataclass - -import jax -from jax import Array, numpy as jnp, vmap - -from .normal import NormalGene, NormalGeneConfig -from core import State, Genome -from utils import Activation, Aggregation, unflatten_conns - - -@dataclass(frozen=True) -class RecurrentGeneConfig(NormalGeneConfig): - activate_times: int = 10 - - def __post_init__(self): - super().__post_init__() - assert self.activate_times > 0 - - -class RecurrentGene(NormalGene): - - @staticmethod - def forward_transform(state: State, genome: Genome): - u_conns = unflatten_conns(genome.nodes, genome.conns) - - # remove un-enable connections and remove enable attr - conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) - u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) - - return genome.nodes, u_conns - - @staticmethod - def create_forward(state: State, config: RecurrentGeneConfig): - activation_funcs = [Activation.name2func[name] for name in config.activation_options] - aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options] - - def act(idx, z): - """ - calculate activation function for each node - """ - idx = jnp.asarray(idx, dtype=jnp.int32) - # change idx from float to int - res = jax.lax.switch(idx, activation_funcs, z) - return res - - def agg(idx, z): - """ - calculate activation function for inputs of node - """ - idx = jnp.asarray(idx, dtype=jnp.int32) - - def all_nan(): - return 0. - - def not_all_nan(): - return jax.lax.switch(idx, aggregation_funcs, z) - - return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) - - batch_act, batch_agg = vmap(act), vmap(agg) - - def forward(inputs, transform) -> Array: - nodes, cons = transform - - input_idx = state.input_idx - output_idx = state.output_idx - - N = nodes.shape[0] - vals = jnp.full((N,), 0.) - - weights = cons[0, :] - - def body_func(i, values): - values = values.at[input_idx].set(inputs) - nodes_ins = values * weights.T - values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins) - values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias - values = batch_act(nodes[:, 3], values) # z = act(z) - return values - - vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals) - return vals[output_idx] - - return forward diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py index e928287..f093ed4 100644 --- a/algorithm/neat/neat.py +++ b/algorithm/neat/neat.py @@ -1,20 +1,18 @@ -from typing import Type - import jax -from jax import numpy as jnp, Array, vmap +from jax import numpy as jnp import numpy as np from config import Config from core import Algorithm, State, Gene, Genome -from .ga import crossover, create_mutate -from .species import SpeciesInfo, update_species, create_speciate +from .ga import create_next_generation +from .species import SpeciesInfo, update_species, speciate class NEAT(Algorithm): - def __init__(self, config: Config, gene_type: Type[Gene]): + def __init__(self, config: Config, gene: Gene): self.config = config - self.gene_type = gene_type + self.gene = gene self.forward_func = None self.tell_func = None @@ -31,8 +29,8 @@ class NEAT(Algorithm): N=self.config.neat.maximum_nodes, C=self.config.neat.maximum_conns, S=self.config.neat.maximum_species, - NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes - CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes + NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes + CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes max_stagnation=self.config.neat.max_stagnation, species_elitism=self.config.neat.species_elitism, spawn_number_change_rate=self.config.neat.spawn_number_change_rate, @@ -46,7 +44,7 @@ class NEAT(Algorithm): output_idx=output_idx, ) - state = self.gene_type.setup(self.config.gene, state) + state = self.gene.setup(state) pop_genomes = self._initialize_genomes(state) species_info = SpeciesInfo.initialize(state) @@ -74,26 +72,32 @@ class NEAT(Algorithm): next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32), ) - self.forward_func = self.gene_type.create_forward(state, self.config.gene) - self.tell_func = self._create_tell() - return jax.device_put(state) - def ask(self, state: State): - """require the population to be evaluated""" + def ask_algorithm(self, state: State): return state.pop_genomes - def tell(self, state: State, fitness): - """update the state of the algorithm""" - return self.tell_func(state, fitness) + def tell_algorithm(self, state: State, fitness): + k1, k2, randkey = jax.random.split(state.randkey, 3) - def forward(self, inputs: Array, transformed: Array): - """the forward function of a single forward transformation""" - return self.forward_func(inputs, transformed) + state = state.update( + generation=state.generation + 1, + randkey=randkey + ) + + state, winner, loser, elite_mask = update_species(state, k1, fitness) + + state = create_next_generation(self.config.neat, self.gene, state, k2, winner, loser, elite_mask) + + state = speciate(self.gene, state) + + return state def forward_transform(self, state: State, genome: Genome): - """create the forward transformation of a genome""" - return self.gene_type.forward_transform(state, genome) + return self.gene.forward_transform(state, genome) + + def forward(self, state: State, inputs, genome: Genome): + return self.gene.forward(state, inputs, genome) def _initialize_genomes(self, state): o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes @@ -106,80 +110,21 @@ class NEAT(Algorithm): o_nodes[input_idx, 0] = input_idx o_nodes[output_idx, 0] = output_idx o_nodes[new_node_key, 0] = new_node_key - o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene_type.new_node_attrs(state) - o_nodes[new_node_key, 1:] = self.gene_type.new_node_attrs(state) + o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene.new_node_attrs(state) + o_nodes[new_node_key, 1:] = self.gene.new_node_attrs(state) input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] o_conns[input_idx, 0:2] = input_conns # in key, out key o_conns[input_idx, 2] = True # enabled - o_conns[input_idx, 3:] = self.gene_type.new_conn_attrs(state) + o_conns[input_idx, 3:] = self.gene.new_conn_attrs(state) output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] o_conns[output_idx, 0:2] = output_conns # in key, out key o_conns[output_idx, 2] = True # enabled - o_conns[output_idx, 3:] = self.gene_type.new_conn_attrs(state) + o_conns[output_idx, 3:] = self.gene.new_conn_attrs(state) # repeat origin genome for P times to create population pop_nodes = np.tile(o_nodes, (state.P, 1, 1)) pop_conns = np.tile(o_conns, (state.P, 1, 1)) return Genome(pop_nodes, pop_conns) - - def _create_tell(self): - mutate = create_mutate(self.config.neat, self.gene_type) - - def create_next_generation(state, randkey, winner, loser, elite_mask): - # prepare random keys - pop_size = state.idx2species.shape[0] - new_node_keys = jnp.arange(pop_size) + state.next_node_key - - k1, k2 = jax.random.split(randkey, 2) - crossover_rand_keys = jax.random.split(k1, pop_size) - mutate_rand_keys = jax.random.split(k2, pop_size) - - # batch crossover - wpn, wpc = state.pop_genomes.nodes[winner], state.pop_genomes.conns[winner] - lpn, lpc = state.pop_genomes.nodes[loser], state.pop_genomes.conns[loser] - n_genomes = vmap(crossover)(crossover_rand_keys, Genome(wpn, wpc), Genome(lpn, lpc)) - - # batch mutation - mutate_func = vmap(mutate, in_axes=(None, 0, 0, 0)) - m_n_genomes = mutate_func(state, mutate_rand_keys, n_genomes, new_node_keys) # mutate_new_pop_nodes - - # elitism don't mutate - pop_nodes = jnp.where(elite_mask[:, None, None], n_genomes.nodes, m_n_genomes.nodes) - pop_conns = jnp.where(elite_mask[:, None, None], n_genomes.conns, m_n_genomes.conns) - - # update next node key - all_nodes_keys = pop_nodes[:, :, 0] - max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) - next_node_key = max_node_key + 1 - - return state.update( - pop_genomes=Genome(pop_nodes, pop_conns), - next_node_key=next_node_key, - ) - - speciate = create_speciate(self.gene_type) - - def tell(state, fitness): - """ - Main update function in NEAT. - """ - - k1, k2, randkey = jax.random.split(state.randkey, 3) - - state = state.update( - generation=state.generation + 1, - randkey=randkey - ) - - state, winner, loser, elite_mask = update_species(state, k1, fitness) - - state = create_next_generation(state, k2, winner, loser, elite_mask) - - state = speciate(state) - - return state - - return tell diff --git a/algorithm/neat/species/__init__.py b/algorithm/neat/species/__init__.py index 8717c2e..eda012f 100644 --- a/algorithm/neat/species/__init__.py +++ b/algorithm/neat/species/__init__.py @@ -1,2 +1,2 @@ -from .operations import update_species, create_speciate from .species_info import SpeciesInfo +from .operations import update_species, speciate diff --git a/algorithm/neat/species/distance.py b/algorithm/neat/species/distance.py index 9667e5a..7150672 100644 --- a/algorithm/neat/species/distance.py +++ b/algorithm/neat/species/distance.py @@ -1,73 +1,71 @@ -from typing import Type - from jax import Array, numpy as jnp, vmap from core import Gene -def create_distance(gene_type: Type[Gene]): - def node_distance(state, nodes1: Array, nodes2: Array): - """ - Calculate the distance between nodes of two genomes. - """ - # statistics nodes count of two genomes - node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) - node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) - max_cnt = jnp.maximum(node_cnt1, node_cnt2) +def distance(gene: Gene, state, genome1, genome2): + return node_distance(gene, state, genome1.nodes, genome2.nodes) + \ + connection_distance(gene, state, genome1.conns, genome2.conns) - # align homologous nodes - # this process is similar to np.intersect1d. - nodes = jnp.concatenate((nodes1, nodes2), axis=0) - keys = nodes[:, 0] - sorted_indices = jnp.argsort(keys, axis=0) - nodes = nodes[sorted_indices] - nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end - fr, sr = nodes[:-1], nodes[1:] # first row, second row - # flag location of homologous nodes - intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) +def node_distance(gene: Gene, state, nodes1: Array, nodes2: Array): + """ + Calculate the distance between nodes of two genomes. + """ + # statistics nodes count of two genomes + node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) + node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) + max_cnt = jnp.maximum(node_cnt1, node_cnt2) - # calculate the count of non_homologous of two genomes - non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) + # align homologous nodes + # this process is similar to np.intersect1d. + nodes = jnp.concatenate((nodes1, nodes2), axis=0) + keys = nodes[:, 0] + sorted_indices = jnp.argsort(keys, axis=0) + nodes = nodes[sorted_indices] + nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end + fr, sr = nodes[:-1], nodes[1:] # first row, second row - # calculate the distance of homologous nodes - hnd = vmap(gene_type.distance_node, in_axes=(None, 0, 0))(state, fr, sr) - hnd = jnp.where(jnp.isnan(hnd), 0, hnd) - homologous_distance = jnp.sum(hnd * intersect_mask) + # flag location of homologous nodes + intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) - val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight + # calculate the count of non_homologous of two genomes + non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) - return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division + # calculate the distance of homologous nodes + hnd = vmap(gene.distance_node, in_axes=(None, 0, 0))(state, fr, sr) + hnd = jnp.where(jnp.isnan(hnd), 0, hnd) + homologous_distance = jnp.sum(hnd * intersect_mask) - def connection_distance(state, cons1: Array, cons2: Array): - """ - Calculate the distance between connections of two genomes. - Similar process as node_distance. - """ - con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0])) - con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0])) - max_cnt = jnp.maximum(con_cnt1, con_cnt2) + val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight - cons = jnp.concatenate((cons1, cons2), axis=0) - keys = cons[:, :2] - sorted_indices = jnp.lexsort(keys.T[::-1]) - cons = cons[sorted_indices] - cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end - fr, sr = cons[:-1], cons[1:] # first row, second row + return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division - # both genome has such connection - intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) - non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) - hcd = vmap(gene_type.distance_conn, in_axes=(None, 0, 0))(state, fr, sr) - hcd = jnp.where(jnp.isnan(hcd), 0, hcd) - homologous_distance = jnp.sum(hcd * intersect_mask) +def connection_distance(gene: Gene, state, cons1: Array, cons2: Array): + """ + Calculate the distance between connections of two genomes. + Similar process as node_distance. + """ + con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0])) + con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0])) + max_cnt = jnp.maximum(con_cnt1, con_cnt2) - val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight + cons = jnp.concatenate((cons1, cons2), axis=0) + keys = cons[:, :2] + sorted_indices = jnp.lexsort(keys.T[::-1]) + cons = cons[sorted_indices] + cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end + fr, sr = cons[:-1], cons[1:] # first row, second row - return jnp.where(max_cnt == 0, 0, val / max_cnt) + # both genome has such connection + intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) - def distance(state, genome1, genome2): - return node_distance(state, genome1.nodes, genome2.nodes) + connection_distance(state, genome1.conns, genome2.conns) + non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) + hcd = vmap(gene.distance_conn, in_axes=(None, 0, 0))(state, fr, sr) + hcd = jnp.where(jnp.isnan(hcd), 0, hcd) + homologous_distance = jnp.sum(hcd * intersect_mask) - return distance + val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight + + return jnp.where(max_cnt == 0, 0, val / max_cnt) diff --git a/algorithm/neat/species/operations.py b/algorithm/neat/species/operations.py index ce3401a..b90c741 100644 --- a/algorithm/neat/species/operations.py +++ b/algorithm/neat/species/operations.py @@ -1,11 +1,9 @@ -from typing import Type - import jax from jax import numpy as jnp, vmap -from core import Gene, Genome +from core import Gene, Genome, State from utils import rank_elements, fetch_first -from .distance import create_distance +from .distance import distance from .species_info import SpeciesInfo @@ -170,154 +168,149 @@ def create_crossover_pair(state, randkey, spawn_number, fitness): return winner, loser, elite_mask -def create_speciate(gene_type: Type[Gene]): - distance = create_distance(gene_type) +def speciate(gene: Gene, state: State): + pop_size, species_size = state.idx2species.shape[0], state.species_info.size() - def speciate(state): - pop_size, species_size = state.idx2species.shape[0], state.species_info.size() + # prepare distance functions + o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0)) # one to population - # prepare distance functions - o2p_distance_func = vmap(distance, in_axes=(None, None, 0)) # one to population + # idx to specie key + idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species - # idx to specie key - idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species + # the distance between genomes to its center genomes + o2c_distances = jnp.full((pop_size,), jnp.inf) - # the distance between genomes to its center genomes - o2c_distances = jnp.full((pop_size,), jnp.inf) + # step 1: find new centers + def cond_func(carry): + i, i2s, cgs, o2c = carry - # step 1: find new centers - def cond_func(carry): - i, i2s, cgs, o2c = carry + return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing - return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing + def body_func(carry): + i, i2s, cgs, o2c = carry - def body_func(carry): - i, i2s, cgs, o2c = carry + distances = o2p_distance_func(gene, state, cgs[i], state.pop_genomes) - distances = o2p_distance_func(state, cgs[i], state.pop_genomes) + # find the closest one + closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) - # find the closest one - closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) + i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i]) + cgs = cgs.set(i, state.pop_genomes[closest_idx]) - i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i]) - cgs = cgs.set(i, state.pop_genomes[closest_idx]) + # the genome with closest_idx will become the new center, thus its distance to center is 0. + o2c = o2c.at[closest_idx].set(0) - # the genome with closest_idx will become the new center, thus its distance to center is 0. - o2c = o2c.at[closest_idx].set(0) + return i + 1, i2s, cgs, o2c - return i + 1, i2s, cgs, o2c + _, idx2species, center_genomes, o2c_distances = \ + jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances)) - _, idx2species, center_genomes, o2c_distances = \ - jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances)) + state = state.update( + idx2species=idx2species, + center_genomes=center_genomes, + ) - state = state.update( - idx2species=idx2species, - center_genomes=center_genomes, + # part 2: assign members to each species + def cond_func(carry): + i, i2s, cgs, sk, o2c, nsk = carry + + current_species_existed = ~jnp.isnan(sk[i]) + not_all_assigned = jnp.any(jnp.isnan(i2s)) + not_reach_species_upper_bounds = i < species_size + return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned) + + def body_func(carry): + i, i2s, cgs, sk, o2c, nsk = carry + + _, i2s, cgs, sk, o2c, nsk = jax.lax.cond( + jnp.isnan(sk[i]), # whether the current species is existing or not + create_new_species, # if not existing, create a new specie + update_exist_specie, # if existing, update the specie + (i, i2s, cgs, sk, o2c, nsk) ) - # part 2: assign members to each species - def cond_func(carry): - i, i2s, cgs, sk, o2c, nsk = carry + return i + 1, i2s, cgs, sk, o2c, nsk - current_species_existed = ~jnp.isnan(sk[i]) - not_all_assigned = jnp.any(jnp.isnan(i2s)) - not_reach_species_upper_bounds = i < species_size - return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned) + def create_new_species(carry): + i, i2s, cgs, sk, o2c, nsk = carry - def body_func(carry): - i, i2s, cgs, sk, o2c, nsk = carry + # pick the first one who has not been assigned to any species + idx = fetch_first(jnp.isnan(i2s)) - _, i2s, cgs, sk, o2c, nsk = jax.lax.cond( - jnp.isnan(sk[i]), # whether the current species is existing or not - create_new_species, # if not existing, create a new specie - update_exist_specie, # if existing, update the specie - (i, i2s, cgs, sk, o2c, nsk) - ) + # assign it to the new species + # [key, best score, last update generation, member_count] + sk = sk.at[i].set(nsk) + i2s = i2s.at[idx].set(nsk) + o2c = o2c.at[idx].set(0) - return i + 1, i2s, cgs, sk, o2c, nsk + # update center genomes + cgs = cgs.set(i, state.pop_genomes[idx]) - def create_new_species(carry): - i, i2s, cgs, sk, o2c, nsk = carry + i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) - # pick the first one who has not been assigned to any species - idx = fetch_first(jnp.isnan(i2s)) + # when a new species is created, it needs to be updated, thus do not change i + return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key - # assign it to the new species - # [key, best score, last update generation, member_count] - sk = sk.at[i].set(nsk) - i2s = i2s.at[idx].set(nsk) - o2c = o2c.at[idx].set(0) + def update_exist_specie(carry): + i, i2s, cgs, sk, o2c, nsk = carry - # update center genomes - cgs = cgs.set(i, state.pop_genomes[idx]) + i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) - i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) + # turn to next species + return i + 1, i2s, cgs, sk, o2c, nsk - # when a new species is created, it needs to be updated, thus do not change i - return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key + def speciate_by_threshold(i, i2s, cgs, sk, o2c): + # distance between such center genome and ppo genomes - def update_exist_specie(carry): - i, i2s, cgs, sk, o2c, nsk = carry + o2p_distance = o2p_distance_func(gene, state, cgs[i], state.pop_genomes) + close_enough_mask = o2p_distance < state.compatibility_threshold - i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) + # when a genome is not assigned or the distance between its current center is bigger than this center + cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c) + # jax.debug.print("{}", o2p_distance) + mask = close_enough_mask & cacheable_mask - # turn to next species - return i + 1, i2s, cgs, sk, o2c, nsk + # update species info + i2s = jnp.where(mask, sk[i], i2s) - def speciate_by_threshold(i, i2s, cgs, sk, o2c): - # distance between such center genome and ppo genomes + # update distance between centers + o2c = jnp.where(mask, o2p_distance, o2c) - o2p_distance = o2p_distance_func(state, cgs[i], state.pop_genomes) - close_enough_mask = o2p_distance < state.compatibility_threshold + return i2s, o2c - # when a genome is not assigned or the distance between its current center is bigger than this center - cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c) - # jax.debug.print("{}", o2p_distance) - mask = close_enough_mask & cacheable_mask + # update idx2species + _, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop( + cond_func, + body_func, + (0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances, + state.next_species_key) + ) - # update species info - i2s = jnp.where(mask, sk[i], i2s) + # if there are still some pop genomes not assigned to any species, add them to the last genome + # this condition can only happen when the number of species is reached species upper bounds + idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) - # update distance between centers - o2c = jnp.where(mask, o2p_distance, o2c) + # complete info of species which is created in this generation + new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness) + best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness) + last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved) - return i2s, o2c + # update members count + def count_members(idx): + key = species_keys[idx] + count = jnp.sum(idx2species == key, dtype=jnp.float32) + count = jnp.where(jnp.isnan(key), jnp.nan, count) - # update idx2species - _, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop( - cond_func, - body_func, - (0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances, state.next_species_key) - ) + return count + member_count = vmap(count_members)(jnp.arange(species_size)) - # if there are still some pop genomes not assigned to any species, add them to the last genome - # this condition can only happen when the number of species is reached species upper bounds - idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) - - # complete info of species which is created in this generation - new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness) - best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness) - last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved) - - # update members count - def count_members(idx): - key = species_keys[idx] - count = jnp.sum(idx2species == key, dtype=jnp.float32) - count = jnp.where(jnp.isnan(key), jnp.nan, count) - - return count - - member_count = vmap(count_members)(jnp.arange(species_size)) - - return state.update( - species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count), - idx2species=idx2species, - center_genomes=center_genomes, - next_species_key=next_species_key - ) - - return speciate + return state.update( + species_info=SpeciesInfo(species_keys, best_fitness, last_improved, member_count), + idx2species=idx2species, + center_genomes=center_genomes, + next_species_key=next_species_key + ) def argmin_with_mask(arr, mask): diff --git a/algorithm/neat/species/species_info.py b/algorithm/neat/species/species_info.py index d2e4788..2dc1c86 100644 --- a/algorithm/neat/species/species_info.py +++ b/algorithm/neat/species/species_info.py @@ -2,6 +2,7 @@ from jax.tree_util import register_pytree_node_class import numpy as np import jax.numpy as jnp + @register_pytree_node_class class SpeciesInfo: @@ -44,7 +45,6 @@ class SpeciesInfo: def size(self): return self.species_keys.shape[0] - def tree_flatten(self): children = self.species_keys, self.best_fitness, self.last_improved, self.member_count aux_data = None diff --git a/config/__init__.py b/config/__init__.py index 473966f..d085c3a 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -1,2 +1 @@ -from .config import * - +from .config import * \ No newline at end of file diff --git a/config/config.py b/config/config.py index ba54ff7..d68accd 100644 --- a/config/config.py +++ b/config/config.py @@ -86,6 +86,7 @@ class HyperNeatConfig: class GeneConfig: pass + @dataclass(frozen=True) class SubstrateConfig: pass diff --git a/config/default_config.ini b/config/default_config.ini deleted file mode 100644 index 921776d..0000000 --- a/config/default_config.ini +++ /dev/null @@ -1,76 +0,0 @@ -[basic] -random_seed = 0 -generation_limit = 1000 -fitness_threshold = 3.9999 -num_inputs = 2 -num_outputs = 1 - -[neat] -network_type = "feedforward" -activate_times = 5 -maximum_nodes = 50 -maximum_conns = 50 -maximum_species = 10 - -compatibility_disjoint = 1.0 -compatibility_weight = 0.5 -conn_add_prob = 0.4 -conn_delete_prob = 0 -node_add_prob = 0.2 -node_delete_prob = 0 - -[hyperneat] -below_threshold = 0.2 -max_weight = 3 -h_activation = "sigmoid" -h_aggregation = "sum" -h_activate_times = 5 - -[substrate] -input_coors = [[-1, 1], [0, 1], [1, 1]] -hidden_coors = [[-1, 0], [0, 0], [1, 0]] -output_coors = [[0, -1]] - -[species] -compatibility_threshold = 3.0 -species_elitism = 2 -max_stagnation = 15 -genome_elitism = 2 -survival_threshold = 0.2 -min_species_size = 1 -spawn_number_change_rate = 0.5 - -[gene] -# bias -bias_init_mean = 0.0 -bias_init_std = 1.0 -bias_mutate_power = 0.5 -bias_mutate_rate = 0.7 -bias_replace_rate = 0.1 - -# response -response_init_mean = 1.0 -response_init_std = 0.0 -response_mutate_power = 0.0 -response_mutate_rate = 0.0 -response_replace_rate = 0.0 - -# activation -activation_default = "sigmoid" -activation_option_names = ["tanh"] -activation_replace_rate = 0.0 - -# aggregation -aggregation_default = "sum" -aggregation_option_names = ["sum"] -aggregation_replace_rate = 0.0 - -# weight -weight_init_mean = 0.0 -weight_init_std = 1.0 -weight_mutate_power = 0.5 -weight_mutate_rate = 0.8 -weight_replace_rate = 0.1 - -[visualize] -renumber_nodes = True \ No newline at end of file diff --git a/core/__init__.py b/core/__init__.py index 8e16999..ad0ee9c 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -2,4 +2,4 @@ from .algorithm import Algorithm from .state import State from .genome import Genome from .gene import Gene -from .substrate import Substrate +from .substrate import Substrate \ No newline at end of file diff --git a/core/algorithm.py b/core/algorithm.py index 51d2fec..0f575a0 100644 --- a/core/algorithm.py +++ b/core/algorithm.py @@ -1,28 +1,50 @@ -from jax import Array +from functools import partial +import jax from .state import State from .genome import Genome -EMPTY = lambda *args: args - class Algorithm: def setup(self, randkey, state: State = State()): """initialize the state of the algorithm""" - pass + raise NotImplementedError + + @partial(jax.jit, static_argnums=(0,)) def ask(self, state: State): """require the population to be evaluated""" - pass + return self.ask_algorithm(state) + + @partial(jax.jit, static_argnums=(0,)) def tell(self, state: State, fitness): """update the state of the algorithm""" - pass - def forward(self, inputs: Array, transformed: Array): - """the forward function of a single forward transformation""" - pass + return self.tell_algorithm(state, fitness) + + @partial(jax.jit, static_argnums=(0,)) + def transform(self, state: State, genome: Genome): + """transform the genome into a neural network""" + + return self.forward_transform(state, genome) + + @partial(jax.jit, static_argnums=(0,)) + def act(self, state: State, inputs, genome: Genome): + return self.forward(state, inputs, genome) def forward_transform(self, state: State, genome: Genome): - """create the forward transformation of a genome""" - pass + raise NotImplementedError + + def forward(self, state: State, inputs, genome: Genome): + raise NotImplementedError + + def ask_algorithm(self, state: State): + """ask the specific algorithm for a new population""" + + raise NotImplementedError + + def tell_algorithm(self, state: State, fitness): + """tell the specific algorithm the fitness of the population""" + + raise NotImplementedError diff --git a/core/gene.py b/core/gene.py index b1b2704..7c6d04e 100644 --- a/core/gene.py +++ b/core/gene.py @@ -1,46 +1,37 @@ -from jax import Array, numpy as jnp - from config import GeneConfig from .state import State -from .genome import Genome class Gene: node_attrs = [] conn_attrs = [] - @staticmethod - def setup(config: GeneConfig, state: State): - return state + def setup(self, state=State()): + raise NotImplementedError - @staticmethod - def new_node_attrs(state: State): - return jnp.zeros(0) + def update(self, state): + raise NotImplementedError - @staticmethod - def new_conn_attrs(state: State): - return jnp.zeros(0) + def new_node_attrs(self, state: State): + raise NotImplementedError - @staticmethod - def mutate_node(state: State, attrs: Array, randkey: Array): - return attrs + def new_conn_attrs(self, state: State): + raise NotImplementedError - @staticmethod - def mutate_conn(state: State, attrs: Array, randkey: Array): - return attrs + def mutate_node(self, state: State, randkey, node_attrs): + raise NotImplementedError - @staticmethod - def distance_node(state: State, node1: Array, node2: Array): - return node1 + def mutate_conn(self, state: State, randkey, conn_attrs): + raise NotImplementedError - @staticmethod - def distance_conn(state: State, conn1: Array, conn2: Array): - return conn1 + def distance_node(self, state: State, node_attrs1, node_attrs2): + raise NotImplementedError - @staticmethod - def forward_transform(state: State, genome: Genome): - return jnp.zeros(0) # transformed + def distance_conn(self, state: State, conn_attrs1, conn_attrs2): + raise NotImplementedError - @staticmethod - def create_forward(state: State, config: GeneConfig): - return lambda *args: args # forward function + def forward_transform(self, state: State, genome): + raise NotImplementedError + + def forward(self, state: State, inputs, transform): + raise NotImplementedError diff --git a/core/genome.py b/core/genome.py index 75d3267..0153bca 100644 --- a/core/genome.py +++ b/core/genome.py @@ -84,4 +84,3 @@ class Genome: def tree_unflatten(cls, aux_data, children): return cls(*children) - diff --git a/examples/test.py b/examples/test.py new file mode 100644 index 0000000..8eef82a --- /dev/null +++ b/examples/test.py @@ -0,0 +1,24 @@ +from functools import partial +import jax + + +class A: + def __init__(self): + self.a = 1 + self.b = 2 + self.isTrue = False + + @partial(jax.jit, static_argnums=(0,)) + def step(self): + if self.isTrue: + return self.a + 1 + else: + return self.b + 1 + + +AA = A() +print(AA.step(), hash(AA)) +print(AA.step(), hash(AA)) +print(AA.step(), hash(AA)) +AA.a = (2, 3, 4) +print(AA.step(), hash(AA)) diff --git a/examples/xor.py b/examples/xor.py index 18afeef..6ee22b9 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -3,7 +3,8 @@ import numpy as np from config import Config, BasicConfig, NeatConfig from pipeline import Pipeline -from algorithm import NEAT, NormalGene, NormalGeneConfig +from algorithm import NEAT +from algorithm.neat.gene import NormalGene, NormalGeneConfig xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) @@ -23,15 +24,15 @@ def evaluate(forward_func): if __name__ == '__main__': config = Config( basic=BasicConfig( - fitness_target=3.99999, + fitness_target=3.9999999, pop_size=10000 ), neat=NeatConfig( maximum_nodes=20, maximum_conns=50, - ), - gene=NormalGeneConfig() + ) ) - algorithm = NEAT(config, NormalGene) + normal_gene = NormalGene(NormalGeneConfig()) + algorithm = NEAT(config, normal_gene) pipeline = Pipeline(config, algorithm) pipeline.auto_run(evaluate) diff --git a/examples/xor_hyperNEAT.py b/examples/xor_hyperNEAT.py deleted file mode 100644 index 8596904..0000000 --- a/examples/xor_hyperNEAT.py +++ /dev/null @@ -1,49 +0,0 @@ -import jax -import numpy as np - -from config import Config, BasicConfig, NeatConfig -from pipeline import Pipeline -from algorithm import NEAT, RecurrentGene, RecurrentGeneConfig -from algorithm import HyperNEAT, NormalSubstrate, NormalSubstrateConfig - - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) - - -def evaluate(forward_func): - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - outs = jax.device_get(outs) - fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return fitnesses - - -if __name__ == '__main__': - config = Config( - basic=BasicConfig( - fitness_target=3.99999, - pop_size=100 - ), - neat=NeatConfig( - network_type="recurrent", - maximum_nodes=50, - maximum_conns=100, - inputs=4, - outputs=1 - - ), - gene=RecurrentGeneConfig( - activation_default="tanh", - activation_options=("tanh", ), - ), - substrate=NormalSubstrateConfig(), - ) - neat = NEAT(config, RecurrentGene) - hyperNEAT = HyperNEAT(config, neat, NormalSubstrate) - - pipeline = Pipeline(config, hyperNEAT) - pipeline.auto_run(evaluate) diff --git a/examples/xor_recurrent.py b/examples/xor_recurrent.py deleted file mode 100644 index bfe6e20..0000000 --- a/examples/xor_recurrent.py +++ /dev/null @@ -1,39 +0,0 @@ -import jax -import numpy as np - -from config import Config, BasicConfig, NeatConfig -from pipeline import Pipeline -from algorithm import NEAT, RecurrentGene, RecurrentGeneConfig - - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) - - -def evaluate(forward_func): - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - outs = jax.device_get(outs) - fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return fitnesses - - -if __name__ == '__main__': - config = Config( - basic=BasicConfig( - fitness_target=3.99999, - pop_size=10000 - ), - neat=NeatConfig( - network_type="recurrent", - maximum_nodes=50, - maximum_conns=100 - ), - gene=RecurrentGeneConfig() - ) - algorithm = NEAT(config, RecurrentGene) - pipeline = Pipeline(config, algorithm) - pipeline.auto_run(evaluate) diff --git a/pipeline.py b/pipeline.py index 5dfd085..a4f9c9f 100644 --- a/pipeline.py +++ b/pipeline.py @@ -27,15 +27,15 @@ class Pipeline: self.evaluate_time = 0 - self.forward_func = jit(self.algorithm.forward) - self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None))) - self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0))) + self.act_func = jit(self.algorithm.act) + self.batch_act_func = jit(vmap(self.act_func, in_axes=(None, 0, None))) + self.pop_batch_act_func = jit(vmap(self.batch_act_func, in_axes=(None, None, 0))) self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0))) self.tell_func = jit(self.algorithm.tell) def ask(self): pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes) - return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms) + return lambda inputs: self.pop_batch_act_func(self.state, inputs, pop_transforms) def tell(self, fitness): # self.state = self.tell_func(self.state, fitness) @@ -80,8 +80,4 @@ class Pipeline: print(f"Generation: {self.state.generation}", f"species: {len(species_sizes)}, {species_sizes}", - f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") - - - - + f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py index af946a4..9820a71 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,4 +1,35 @@ -from .activation import Activation -from .aggregation import Aggregation +from .activation import Activation, act +from .aggregation import Aggregation, agg from .tools import * -from .graph import * \ No newline at end of file +from .graph import * + +Activation.name2func = { + 'sigmoid': Activation.sigmoid_act, + 'tanh': Activation.tanh_act, + 'sin': Activation.sin_act, + 'gauss': Activation.gauss_act, + 'relu': Activation.relu_act, + 'elu': Activation.elu_act, + 'lelu': Activation.lelu_act, + 'selu': Activation.selu_act, + 'softplus': Activation.softplus_act, + 'identity': Activation.identity_act, + 'clamped': Activation.clamped_act, + 'inv': Activation.inv_act, + 'log': Activation.log_act, + 'exp': Activation.exp_act, + 'abs': Activation.abs_act, + 'hat': Activation.hat_act, + 'square': Activation.square_act, + 'cube': Activation.cube_act, +} + +Aggregation.name2func = { + 'sum': Aggregation.sum_agg, + 'product': Aggregation.product_agg, + 'max': Aggregation.max_agg, + 'min': Aggregation.min_agg, + 'maxabs': Aggregation.maxabs_agg, + 'median': Aggregation.median_agg, + 'mean': Aggregation.mean_agg, +} diff --git a/utils/activation.py b/utils/activation.py index 6fdfaa1..f580d57 100644 --- a/utils/activation.py +++ b/utils/activation.py @@ -1,8 +1,8 @@ +import jax import jax.numpy as jnp class Activation: - name2func = {} @staticmethod @@ -89,23 +89,11 @@ class Activation: return z ** 3 -Activation.name2func = { - 'sigmoid': Activation.sigmoid_act, - 'tanh': Activation.tanh_act, - 'sin': Activation.sin_act, - 'gauss': Activation.gauss_act, - 'relu': Activation.relu_act, - 'elu': Activation.elu_act, - 'lelu': Activation.lelu_act, - 'selu': Activation.selu_act, - 'softplus': Activation.softplus_act, - 'identity': Activation.identity_act, - 'clamped': Activation.clamped_act, - 'inv': Activation.inv_act, - 'log': Activation.log_act, - 'exp': Activation.exp_act, - 'abs': Activation.abs_act, - 'hat': Activation.hat_act, - 'square': Activation.square_act, - 'cube': Activation.cube_act, -} \ No newline at end of file +def act(idx, z, act_funcs): + """ + calculate activation function for each node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + # change idx from float to int + res = jax.lax.switch(idx, act_funcs, z) + return res diff --git a/utils/aggregation.py b/utils/aggregation.py index 6868b6b..86c686a 100644 --- a/utils/aggregation.py +++ b/utils/aggregation.py @@ -1,8 +1,8 @@ +import jax import jax.numpy as jnp class Aggregation: - name2func = {} @staticmethod @@ -52,12 +52,16 @@ class Aggregation: return mean_without_zeros -Aggregation.name2func = { - 'sum': Aggregation.sum_agg, - 'product': Aggregation.product_agg, - 'max': Aggregation.max_agg, - 'min': Aggregation.min_agg, - 'maxabs': Aggregation.maxabs_agg, - 'median': Aggregation.median_agg, - 'mean': Aggregation.mean_agg, -} +def agg(idx, z, agg_funcs): + """ + calculate activation function for inputs of node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + + def all_nan(): + return 0. + + def not_all_nan(): + return jax.lax.switch(idx, agg_funcs, z) + + return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)