From aac41a089d58790bc15a08647fb4dbfd6daf9481 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sat, 27 Jan 2024 00:52:39 +0800 Subject: [PATCH] new architecture --- algorithm/__init__.py | 4 +- algorithm/base.py | 24 + algorithm/hyperneat/__init__.py | 2 - algorithm/hyperneat/hyperneat.py | 113 ----- algorithm/hyperneat/substrate/__init__.py | 2 - algorithm/hyperneat/substrate/normal.py | 25 -- algorithm/hyperneat/substrate/tools.py | 49 --- algorithm/neat/__init__.py | 3 +- algorithm/neat/ga/__init__.py | 5 +- algorithm/neat/ga/crossover.py | 70 --- algorithm/neat/ga/crossover/__init__.py | 2 + algorithm/neat/ga/crossover/base.py | 3 + algorithm/neat/ga/crossover/default.py | 66 +++ algorithm/neat/ga/mutate.py | 186 -------- algorithm/neat/ga/mutation/__init__.py | 2 + algorithm/neat/ga/mutation/base.py | 3 + algorithm/neat/ga/mutation/default.py | 201 +++++++++ algorithm/neat/ga/operation.py | 40 -- algorithm/neat/gene/__init__.py | 6 +- algorithm/neat/gene/base.py | 23 + algorithm/neat/gene/conn/__init__.py | 2 + algorithm/neat/gene/conn/base.py | 12 + algorithm/neat/gene/conn/default.py | 51 +++ algorithm/neat/gene/node/__init__.py | 2 + algorithm/neat/gene/node/base.py | 12 + algorithm/neat/gene/node/default.py | 96 ++++ algorithm/neat/gene/normal.py | 210 --------- algorithm/neat/gene/recurrent.py | 57 --- algorithm/neat/genome/__init__.py | 3 + algorithm/neat/genome/base.py | 66 +++ algorithm/neat/genome/default.py | 75 ++++ algorithm/neat/genome/recurrent.py | 58 +++ algorithm/neat/neat.py | 176 +++----- algorithm/neat/species/__init__.py | 4 +- algorithm/neat/species/base.py | 14 + algorithm/neat/species/default.py | 514 ++++++++++++++++++++++ algorithm/neat/species/distance.py | 71 --- algorithm/neat/species/operations.py | 319 -------------- algorithm/neat/species/species_info.py | 55 --- config/__init__.py | 1 - config/config.py | 107 ----- core/__init__.py | 6 - core/algorithm.py | 50 --- core/gene.py | 40 -- core/genome.py | 90 ---- core/problem.py | 29 -- core/substrate.py | 8 - examples/brax/ant.py | 2 +- pipeline.py | 53 +-- problem/__init__.py | 1 + problem/base.py | 44 ++ problem/func_fit/__init__.py | 2 +- problem/func_fit/func_fit.py | 39 +- problem/func_fit/xor.py | 7 +- problem/func_fit/xor3d.py | 11 +- problem/rl_env/brax_env.py | 31 +- problem/rl_env/gymnax_env.py | 23 +- problem/rl_env/rl_jit.py | 27 +- t.py | 64 +++ test/__init__.py | 0 test/test_genome.py | 113 +++++ utils/__init__.py | 3 +- utils/aggregation.py | 12 +- {core => utils}/state.py | 0 utils/tools.py | 45 +- 65 files changed, 1651 insertions(+), 1783 deletions(-) create mode 100644 algorithm/base.py delete mode 100644 algorithm/hyperneat/__init__.py delete mode 100644 algorithm/hyperneat/hyperneat.py delete mode 100644 algorithm/hyperneat/substrate/__init__.py delete mode 100644 algorithm/hyperneat/substrate/normal.py delete mode 100644 algorithm/hyperneat/substrate/tools.py delete mode 100644 algorithm/neat/ga/crossover.py create mode 100644 algorithm/neat/ga/crossover/__init__.py create mode 100644 algorithm/neat/ga/crossover/base.py create mode 100644 algorithm/neat/ga/crossover/default.py delete mode 100644 algorithm/neat/ga/mutate.py create mode 100644 algorithm/neat/ga/mutation/__init__.py create mode 100644 algorithm/neat/ga/mutation/base.py create mode 100644 algorithm/neat/ga/mutation/default.py delete mode 100644 algorithm/neat/ga/operation.py create mode 100644 algorithm/neat/gene/base.py create mode 100644 algorithm/neat/gene/conn/__init__.py create mode 100644 algorithm/neat/gene/conn/base.py create mode 100644 algorithm/neat/gene/conn/default.py create mode 100644 algorithm/neat/gene/node/__init__.py create mode 100644 algorithm/neat/gene/node/base.py create mode 100644 algorithm/neat/gene/node/default.py delete mode 100644 algorithm/neat/gene/normal.py delete mode 100644 algorithm/neat/gene/recurrent.py create mode 100644 algorithm/neat/genome/__init__.py create mode 100644 algorithm/neat/genome/base.py create mode 100644 algorithm/neat/genome/default.py create mode 100644 algorithm/neat/genome/recurrent.py create mode 100644 algorithm/neat/species/base.py create mode 100644 algorithm/neat/species/default.py delete mode 100644 algorithm/neat/species/distance.py delete mode 100644 algorithm/neat/species/operations.py delete mode 100644 algorithm/neat/species/species_info.py delete mode 100644 config/__init__.py delete mode 100644 config/config.py delete mode 100644 core/__init__.py delete mode 100644 core/algorithm.py delete mode 100644 core/gene.py delete mode 100644 core/genome.py delete mode 100644 core/problem.py delete mode 100644 core/substrate.py create mode 100644 problem/base.py create mode 100644 t.py create mode 100644 test/__init__.py create mode 100644 test/test_genome.py rename {core => utils}/state.py (100%) diff --git a/algorithm/__init__.py b/algorithm/__init__.py index e2e54c0..b2f3695 100644 --- a/algorithm/__init__.py +++ b/algorithm/__init__.py @@ -1,2 +1,2 @@ -from .neat import NEAT -from .hyperneat import HyperNEAT +from .base import BaseAlgorithm +from .neat import NEAT \ No newline at end of file diff --git a/algorithm/base.py b/algorithm/base.py new file mode 100644 index 0000000..36789c8 --- /dev/null +++ b/algorithm/base.py @@ -0,0 +1,24 @@ +from utils import State + + +class BaseAlgorithm: + + def setup(self, randkey): + """initialize the state of the algorithm""" + + raise NotImplementedError + + def ask(self, state: State): + """require the population to be evaluated""" + raise NotImplementedError + + def tell(self, state: State, fitness): + """update the state of the algorithm""" + raise NotImplementedError + + def transform(self, state: State): + """transform the genome into a neural network""" + raise NotImplementedError + + def forward(self, inputs, transformed): + raise NotImplementedError \ No newline at end of file diff --git a/algorithm/hyperneat/__init__.py b/algorithm/hyperneat/__init__.py deleted file mode 100644 index 8d106fb..0000000 --- a/algorithm/hyperneat/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .hyperneat import HyperNEAT -from .substrate import NormalSubstrate, NormalSubstrateConfig diff --git a/algorithm/hyperneat/hyperneat.py b/algorithm/hyperneat/hyperneat.py deleted file mode 100644 index 0ef08cf..0000000 --- a/algorithm/hyperneat/hyperneat.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import Type - -import jax -from jax import numpy as jnp, Array, vmap -import numpy as np - -from config import Config, HyperNeatConfig -from core import Algorithm, Substrate, State, Genome, Gene -from .substrate import analysis_substrate -from algorithm import NEAT - - -class HyperNEAT(Algorithm): - - def __init__(self, config: Config, gene: Type[Gene], substrate: Type[Substrate]): - self.config = config - self.neat = NEAT(config, gene) - self.substrate = substrate - - def setup(self, randkey, state=State()): - neat_key, randkey = jax.random.split(randkey) - state = state.update( - below_threshold=self.config.hyperneat.below_threshold, - max_weight=self.config.hyperneat.max_weight, - ) - state = self.neat.setup(neat_key, state) - state = self.substrate.setup(self.config.substrate, state) - - assert self.config.hyperneat.inputs + 1 == state.input_coors.shape[0] # +1 for bias - assert self.config.hyperneat.outputs == state.output_coors.shape[0] - - h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state) - h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis] - h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32) - h_conns[:, 0:2] = correspond_keys - - state = state.update( - h_input_idx=h_input_idx, - h_output_idx=h_output_idx, - h_hidden_idx=h_hidden_idx, - h_nodes=h_nodes, - h_conns=h_conns, - query_coors=query_coors, - ) - - return state - - def ask_algorithm(self, state: State): - return state.pop_genomes - - def tell_algorithm(self, state: State, fitness): - return self.neat.tell(state, fitness) - - def forward(self, state, inputs: Array, transformed: Array): - return HyperNEATGene.forward(self.config.hyperneat, state, inputs, transformed) - - def forward_transform(self, state: State, genome: Genome): - t = self.neat.forward_transform(state, genome) - query_res = vmap(self.neat.forward, in_axes=(None, 0, None))(state, state.query_coors, t) - - # mute the connection with weight below threshold - query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res) - - # make query res in range [-max_weight, max_weight] - query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res) - query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res) - query_res = query_res / (1 - state.below_threshold) * state.max_weight - - h_conns = state.h_conns.at[:, 2:].set(query_res) - - return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns)) - - -class HyperNEATGene: - node_attrs = [] # no node attributes - conn_attrs = ['weight'] - - @staticmethod - def forward_transform(genome: Genome): - N = genome.nodes.shape[0] - u_conns = jnp.zeros((N, N), dtype=jnp.float32) - - in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32) - out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32) - weights = genome.conns[:, 2] - - u_conns = u_conns.at[in_keys, out_keys].set(weights) - return genome.nodes, u_conns - - @staticmethod - def forward(config: HyperNeatConfig, state: State, inputs, transformed): - batch_act, batch_agg = jax.vmap(config.activation), jax.vmap(config.aggregation) - - nodes, weights = transformed - - inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0) - - input_idx = state.h_input_idx - output_idx = state.h_output_idx - - N = nodes.shape[0] - vals = jnp.full((N,), 0.) - - def body_func(i, values): - values = values.at[input_idx].set(inputs_with_bias) - nodes_ins = values * weights.T - values = batch_agg(nodes_ins) # z = agg(ins) - # values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias - values = batch_act(values) # z = act(z) - return values - - vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals) - return vals[output_idx] diff --git a/algorithm/hyperneat/substrate/__init__.py b/algorithm/hyperneat/substrate/__init__.py deleted file mode 100644 index 035c39f..0000000 --- a/algorithm/hyperneat/substrate/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .normal import NormalSubstrate, NormalSubstrateConfig -from .tools import analysis_substrate \ No newline at end of file diff --git a/algorithm/hyperneat/substrate/normal.py b/algorithm/hyperneat/substrate/normal.py deleted file mode 100644 index c06edc3..0000000 --- a/algorithm/hyperneat/substrate/normal.py +++ /dev/null @@ -1,25 +0,0 @@ -from dataclasses import dataclass -from typing import Tuple - -import numpy as np - -from core import Substrate, State -from config import SubstrateConfig - - -@dataclass(frozen=True) -class NormalSubstrateConfig(SubstrateConfig): - input_coors: Tuple = ((-1, -1), (0, -1), (1, -1)) - hidden_coors: Tuple = ((-1, 0), (0, 0), (1, 0)) - output_coors: Tuple = ((0, 1),) - - -class NormalSubstrate(Substrate): - - @staticmethod - def setup(config: NormalSubstrateConfig, state: State = State()): - return state.update( - input_coors=np.asarray(config.input_coors, dtype=np.float32), - output_coors=np.asarray(config.output_coors, dtype=np.float32), - hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32), - ) diff --git a/algorithm/hyperneat/substrate/tools.py b/algorithm/hyperneat/substrate/tools.py deleted file mode 100644 index 8bc4959..0000000 --- a/algorithm/hyperneat/substrate/tools.py +++ /dev/null @@ -1,49 +0,0 @@ -import numpy as np - - -def analysis_substrate(state): - cd = state.input_coors.shape[1] # coordinate dimensions - si = state.input_coors.shape[0] # input coordinate size - so = state.output_coors.shape[0] # output coordinate size - sh = state.hidden_coors.shape[0] # hidden coordinate size - - input_idx = np.arange(si) - output_idx = np.arange(si, si + so) - hidden_idx = np.arange(si + so, si + so + sh) - - total_conns = si * sh + sh * sh + sh * so - query_coors = np.zeros((total_conns, cd * 2)) - correspond_keys = np.zeros((total_conns, 2)) - - # connect input to hidden - aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors) - query_coors[0: si * sh, :] = aux_coors - correspond_keys[0: si * sh, :] = aux_keys - - # connect hidden to hidden - aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors) - query_coors[si * sh: si * sh + sh * sh, :] = aux_coors - correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys - - # connect hidden to output - aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors) - query_coors[si * sh + sh * sh:, :] = aux_coors - correspond_keys[si * sh + sh * sh:, :] = aux_keys - - return input_idx, output_idx, hidden_idx, query_coors, correspond_keys - - -def cartesian_product(keys1, keys2, coors1, coors2): - len1 = keys1.shape[0] - len2 = keys2.shape[0] - - repeated_coors1 = np.repeat(coors1, len2, axis=0) - repeated_keys1 = np.repeat(keys1, len2) - - tiled_coors2 = np.tile(coors2, (len1, 1)) - tiled_keys2 = np.tile(keys2, len1) - - new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1) - correspond_keys = np.column_stack((repeated_keys1, tiled_keys2)) - - return new_coors, correspond_keys diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py index d6bb53c..44bd257 100644 --- a/algorithm/neat/__init__.py +++ b/algorithm/neat/__init__.py @@ -1,2 +1,3 @@ -from .neat import NEAT from .gene import * +from .genome import * +from .neat import NEAT diff --git a/algorithm/neat/ga/__init__.py b/algorithm/neat/ga/__init__.py index cbb0157..198f8ac 100644 --- a/algorithm/neat/ga/__init__.py +++ b/algorithm/neat/ga/__init__.py @@ -1,3 +1,2 @@ -from .crossover import crossover -from .mutate import mutate -from .operation import create_next_generation +from .crossover import BaseCrossover, DefaultCrossover +from .mutation import BaseMutation, DefaultMutation diff --git a/algorithm/neat/ga/crossover.py b/algorithm/neat/ga/crossover.py deleted file mode 100644 index 88a40bb..0000000 --- a/algorithm/neat/ga/crossover.py +++ /dev/null @@ -1,70 +0,0 @@ -import jax -from jax import Array, numpy as jnp - -from core import Genome - - -def crossover(randkey, genome1: Genome, genome2: Genome): - """ - use genome1 and genome2 to generate a new genome - notice that genome1 should have higher fitness than genome2 (genome1 is winner!) - """ - randkey_1, randkey_2, key = jax.random.split(randkey, 3) - - # crossover nodes - keys1, keys2 = genome1.nodes[:, 0], genome2.nodes[:, 0] - # make homologous genes align in nodes2 align with nodes1 - nodes2 = align_array(keys1, keys2, genome2.nodes, False) - nodes1 = genome1.nodes - # For not homologous genes, use the value of nodes1(winner) - # For homologous genes, use the crossover result between nodes1 and nodes2 - new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) - - # crossover connections - con_keys1, con_keys2 = genome1.conns[:, :2], genome2.conns[:, :2] - conns2 = align_array(con_keys1, con_keys2, genome2.conns, True) - conns1 = genome1.conns - - new_cons = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, crossover_gene(randkey_2, conns1, conns2)) - - return genome1.update(new_nodes, new_cons) - - -def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array: - """ - After I review this code, I found that it is the most difficult part of the code. Please never change it! - make ar2 align with ar1. - :param seq1: - :param seq2: - :param ar2: - :param is_conn: - :return: - align means to intersect part of ar2 will be at the same position as ar1, - non-intersect part of ar2 will be set to Nan - """ - seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :] - mask = (seq1 == seq2) & (~jnp.isnan(seq1)) - - if is_conn: - mask = jnp.all(mask, axis=2) - - intersect_mask = mask.any(axis=1) - idx = jnp.arange(0, len(seq1)) - idx_fixed = jnp.dot(mask, idx) - - refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan) - - return refactor_ar2 - - -def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: - """ - crossover two genes - :param rand_key: - :param g1: - :param g2: - :return: - only gene with the same key will be crossover, thus don't need to consider change key - """ - r = jax.random.uniform(rand_key, shape=g1.shape) - return jnp.where(r > 0.5, g1, g2) diff --git a/algorithm/neat/ga/crossover/__init__.py b/algorithm/neat/ga/crossover/__init__.py new file mode 100644 index 0000000..ca6b068 --- /dev/null +++ b/algorithm/neat/ga/crossover/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseCrossover +from .default import DefaultCrossover diff --git a/algorithm/neat/ga/crossover/base.py b/algorithm/neat/ga/crossover/base.py new file mode 100644 index 0000000..9f638a2 --- /dev/null +++ b/algorithm/neat/ga/crossover/base.py @@ -0,0 +1,3 @@ +class BaseCrossover: + def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2): + raise NotImplementedError \ No newline at end of file diff --git a/algorithm/neat/ga/crossover/default.py b/algorithm/neat/ga/crossover/default.py new file mode 100644 index 0000000..adabd2d --- /dev/null +++ b/algorithm/neat/ga/crossover/default.py @@ -0,0 +1,66 @@ +import jax, jax.numpy as jnp + +from .base import BaseCrossover + +class DefaultCrossover(BaseCrossover): + def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2): + """ + use genome1 and genome2 to generate a new genome + notice that genome1 should have higher fitness than genome2 (genome1 is winner!) + """ + randkey_1, randkey_2, key = jax.random.split(randkey, 3) + + # crossover nodes + keys1, keys2 = nodes1[:, 0], nodes2[:, 0] + # make homologous genes align in nodes2 align with nodes1 + nodes2 = self.align_array(keys1, keys2, nodes2, False) + + # For not homologous genes, use the value of nodes1(winner) + # For homologous genes, use the crossover result between nodes1 and nodes2 + new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2)) + + # crossover connections + con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] + conns2 = self.align_array(con_keys1, con_keys2, conns2, True) + + new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2)) + + return new_nodes, new_conns + + def align_array(self, seq1, seq2, ar2, is_conn: bool): + """ + After I review this code, I found that it is the most difficult part of the code. Please never change it! + make ar2 align with ar1. + :param seq1: + :param seq2: + :param ar2: + :param is_conn: + :return: + align means to intersect part of ar2 will be at the same position as ar1, + non-intersect part of ar2 will be set to Nan + """ + seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :] + mask = (seq1 == seq2) & (~jnp.isnan(seq1)) + + if is_conn: + mask = jnp.all(mask, axis=2) + + intersect_mask = mask.any(axis=1) + idx = jnp.arange(0, len(seq1)) + idx_fixed = jnp.dot(mask, idx) + + refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan) + + return refactor_ar2 + + def crossover_gene(self, rand_key, g1, g2): + """ + crossover two genes + :param rand_key: + :param g1: + :param g2: + :return: + only gene with the same key will be crossover, thus don't need to consider change key + """ + r = jax.random.uniform(rand_key, shape=g1.shape) + return jnp.where(r > 0.5, g1, g2) diff --git a/algorithm/neat/ga/mutate.py b/algorithm/neat/ga/mutate.py deleted file mode 100644 index 499d4c2..0000000 --- a/algorithm/neat/ga/mutate.py +++ /dev/null @@ -1,186 +0,0 @@ -from typing import Tuple - -import jax -from jax import Array, numpy as jnp, vmap - -from config import NeatConfig -from core import State, Gene, Genome -from utils import check_cycles, fetch_random, fetch_first, I_INT, unflatten_conns - - -def mutate(config: NeatConfig, gene: Gene, state: State, randkey, genome: Genome, new_node_key): - """ - Mutate a population of genomes - """ - k1, k2 = jax.random.split(randkey) - - genome = mutate_structure(config, gene, state, k1, genome, new_node_key) - genome = mutate_values(gene, state, randkey, genome) - - return genome - - -def mutate_structure(config: NeatConfig, gene: Gene, state: State, randkey, genome: Genome, new_node_key): - def mutate_add_node(key_, genome_: Genome): - i_key, o_key, idx = choice_connection_key(key_, genome_.conns) - - def nothing(): - return genome_ - - def successful_add_node(): - # disable the connection - new_genome = genome_.update_conns(genome_.conns.at[idx, 2].set(False)) - - # add a new node - new_genome = new_genome.add_node(new_node_key, gene.new_node_attrs(state)) - - # add two new connections - new_genome = new_genome.add_conn(i_key, new_node_key, True, gene.new_conn_attrs(state)) - new_genome = new_genome.add_conn(new_node_key, o_key, True, gene.new_conn_attrs(state)) - - return new_genome - - # if from_idx == I_INT, that means no connection exist, do nothing - return jax.lax.cond(idx == I_INT, nothing, successful_add_node) - - def mutate_delete_node(key_, genome_: Genome): - # TODO: Do we really need to delete a node? - # randomly choose a node - key, idx = choice_node_key(key_, genome_.nodes, state.input_idx, state.output_idx, - allow_input_keys=False, allow_output_keys=False) - - def nothing(): - return genome_ - - def successful_delete_node(): - # delete the node - new_genome = genome_.delete_node_by_pos(idx) - - # delete all connections - new_conns = jnp.where(((new_genome.conns[:, 0] == key) | (new_genome.conns[:, 1] == key))[:, None], - jnp.nan, new_genome.conns) - - return new_genome.update_conns(new_conns) - - return jax.lax.cond(idx == I_INT, nothing, successful_delete_node) - - def mutate_add_conn(key_, genome_: Genome): - # randomly choose two nodes - k1_, k2_ = jax.random.split(key_, num=2) - i_key, from_idx = choice_node_key(k1_, genome_.nodes, state.input_idx, state.output_idx, - allow_input_keys=True, allow_output_keys=True) - o_key, to_idx = choice_node_key(k2_, genome_.nodes, state.input_idx, state.output_idx, - allow_input_keys=False, allow_output_keys=True) - - conn_pos = fetch_first((genome_.conns[:, 0] == i_key) & (genome_.conns[:, 1] == o_key)) - - def nothing(): - return genome_ - - def successful(): - return genome_.add_conn(i_key, o_key, True, gene.new_conn_attrs(state)) - - def already_exist(): - return genome_.update_conns(genome_.conns.at[conn_pos, 2].set(True)) - - is_already_exist = conn_pos != I_INT - - if config.network_type == 'feedforward': - u_cons = unflatten_conns(genome_.nodes, genome_.conns) - cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False) - is_cycle = check_cycles(genome_.nodes, cons_exist, from_idx, to_idx) - - choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) - return jax.lax.switch(choice, [already_exist, nothing, successful]) - - elif config.network_type == 'recurrent': - return jax.lax.cond(is_already_exist, already_exist, successful) - - else: - raise ValueError(f"Invalid network type: {config.network_type}") - - def mutate_delete_conn(key_, genome_: Genome): - # randomly choose a connection - i_key, o_key, idx = choice_connection_key(key_, genome_.conns) - - def nothing(): - return genome_ - - def successfully_delete_connection(): - return genome_.delete_conn_by_pos(idx) - - return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) - - k1, k2, k3, k4 = jax.random.split(randkey, num=4) - r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) - - def no(k, g): - return g - - genome = jax.lax.cond(r1 < config.node_add, mutate_add_node, no, k1, genome) - genome = jax.lax.cond(r2 < config.node_delete, mutate_delete_node, no, k2, genome) - genome = jax.lax.cond(r3 < config.conn_add, mutate_add_conn, no, k3, genome) - genome = jax.lax.cond(r4 < config.conn_delete, mutate_delete_conn, no, k4, genome) - - return genome - - -def mutate_values(gene: Gene, state: State, randkey, genome: Genome): - k1, k2 = jax.random.split(randkey, num=2) - nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0]) - conns_keys = jax.random.split(k2, num=genome.conns.shape[0]) - - nodes_attrs, conns_attrs = genome.nodes[:, 1:], genome.conns[:, 3:] - - new_nodes_attrs = vmap(gene.mutate_node, in_axes=(None, 0, 0))(state, nodes_keys, nodes_attrs) - new_conns_attrs = vmap(gene.mutate_conn, in_axes=(None, 0, 0))(state, conns_keys, conns_attrs) - - # nan nodes not changed - new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs) - new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs) - - new_nodes = genome.nodes.at[:, 1:].set(new_nodes_attrs) - new_conns = genome.conns.at[:, 3:].set(new_conns_attrs) - - return genome.update(new_nodes, new_conns) - - -def choice_node_key(rand_key: Array, nodes: Array, - input_keys: Array, output_keys: Array, - allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: - """ - Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. - :param rand_key: - :param nodes: - :param input_keys: - :param output_keys: - :param allow_input_keys: - :param allow_output_keys: - :return: return its key and position(idx) - """ - - node_keys = nodes[:, 0] - mask = ~jnp.isnan(node_keys) - - if not allow_input_keys: - mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys)) - - if not allow_output_keys: - mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys)) - - idx = fetch_random(rand_key, mask) - key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan) - return key, idx - - -def choice_connection_key(rand_key: Array, conns: Array): - """ - Randomly choose a connection key from the given connections. - :return: i_key, o_key, idx - """ - - idx = fetch_random(rand_key, ~jnp.isnan(conns[:, 0])) - i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan) - o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan) - - return i_key, o_key, idx diff --git a/algorithm/neat/ga/mutation/__init__.py b/algorithm/neat/ga/mutation/__init__.py new file mode 100644 index 0000000..599f35c --- /dev/null +++ b/algorithm/neat/ga/mutation/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseMutation +from .default import DefaultMutation \ No newline at end of file diff --git a/algorithm/neat/ga/mutation/base.py b/algorithm/neat/ga/mutation/base.py new file mode 100644 index 0000000..8291522 --- /dev/null +++ b/algorithm/neat/ga/mutation/base.py @@ -0,0 +1,3 @@ +class BaseMutation: + def __call__(self, key, genome, nodes, conns, new_node_key): + raise NotImplementedError \ No newline at end of file diff --git a/algorithm/neat/ga/mutation/default.py b/algorithm/neat/ga/mutation/default.py new file mode 100644 index 0000000..7cc1446 --- /dev/null +++ b/algorithm/neat/ga/mutation/default.py @@ -0,0 +1,201 @@ +import jax, jax.numpy as jnp +from . import BaseMutation +from utils import fetch_first, fetch_random, I_INT, unflatten_conns, check_cycles + + +class DefaultMutation(BaseMutation): + + def __init__( + self, + conn_add: float = 0.4, + conn_delete: float = 0, + node_add: float = 0.2, + node_delete: float = 0, + ): + self.conn_add = conn_add + self.conn_delete = conn_delete + self.node_add = node_add + self.node_delete = node_delete + + def __call__(self, randkey, genome, nodes, conns, new_node_key): + k1, k2 = jax.random.split(randkey) + + nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key) + nodes, conns = self.mutate_values(k2, genome, nodes, conns) + + return nodes, conns + + def mutate_structure(self, randkey, genome, nodes, conns, new_node_key): + def mutate_add_node(key_, nodes_, conns_): + i_key, o_key, idx = self.choice_connection_key(key_, conns_) + + def successful_add_node(): + # disable the connection + new_conns = conns_.at[idx, 2].set(False) + + # add a new node + new_nodes = genome.add_node(nodes_, new_node_key, genome.node_gene.new_custom_attrs()) + + # add two new connections + new_conns = genome.add_conn(new_conns, i_key, new_node_key, True, genome.conn_gene.new_custom_attrs()) + new_conns = genome.add_conn(new_conns, new_node_key, o_key, True, genome.conn_gene.new_custom_attrs()) + + return new_nodes, new_conns + + return jax.lax.cond( + idx == I_INT, + lambda: (nodes_, conns_), # do nothing + successful_add_node + ) + + def mutate_delete_node(key_, nodes_, conns_): + + # randomly choose a node + key, idx = self.choice_node_key(key_, nodes_, genome.input_idx, genome.output_idx, + allow_input_keys=False, allow_output_keys=False) + + def successful_delete_node(): + # delete the node + new_nodes = genome.delete_node_by_pos(nodes_, idx) + + # delete all connections + new_conns = jnp.where( + ((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None], + jnp.nan, + conns_ + ) + + return new_nodes, new_conns + + return jax.lax.cond( + idx == I_INT, + lambda: (nodes_, conns_), # do nothing + successful_delete_node + ) + + def mutate_add_conn(key_, nodes_, conns_): + # randomly choose two nodes + k1_, k2_ = jax.random.split(key_, num=2) + + # input node of the connection can be any node + i_key, from_idx = self.choice_node_key(k1_, nodes_, genome.input_idx, genome.output_idx, + allow_input_keys=True, allow_output_keys=True) + + # output node of the connection can be any node except input node + o_key, to_idx = self.choice_node_key(k2_, nodes_, genome.input_idx, genome.output_idx, + allow_input_keys=False, allow_output_keys=True) + + conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key)) + is_already_exist = conn_pos != I_INT + + def nothing(): + return nodes_, conns_ + + def successful(): + return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conns.new_custom_attrs()) + + def already_exist(): + return nodes_, conns_.at[conn_pos, 2].set(True) + + if genome.network_type == 'feedforward': + u_cons = unflatten_conns(nodes_, conns_) + cons_exist = ~jnp.isnan(u_cons[0, :, :]) + is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx) + + return jax.lax.cond( + is_already_exist, + already_exist, + jax.lax.cond( + is_cycle, + nothing, + successful + ) + ) + + elif genome.network_type == 'recurrent': + return jax.lax.cond( + is_already_exist, + already_exist, + successful + ) + + else: + raise ValueError(f"Invalid network type: {genome.network_type}") + + def mutate_delete_conn(key_, nodes_, conns_): + # randomly choose a connection + i_key, o_key, idx = self.choice_connection_key(key_, conns_) + + def successfully_delete_connection(): + return nodes_, genome.delete_conn_by_pos(conns_, idx) + + return jax.lax.cond( + idx == I_INT, + lambda: (nodes_, conns_), # nothing + successfully_delete_connection + ) + + k1, k2, k3, k4 = jax.random.split(randkey, num=4) + r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) + + def no(k, g): + return g + + genome = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns) + genome = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns) + genome = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns) + genome = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns) + + return genome + + def mutate_values(self, randkey, genome, nodes, conns): + k1, k2 = jax.random.split(randkey, num=2) + nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0]) + conns_keys = jax.random.split(k2, num=genome.conns.shape[0]) + + new_nodes = jax.vmap(genome.nodes.mutate, in_axes=(0, 0))(nodes_keys, nodes) + new_conns = jax.vmap(genome.conns.mutate, in_axes=(0, 0))(conns_keys, conns) + + # nan nodes not changed + new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) + new_conns = jnp.where(jnp.isnan(conns), jnp.nan, new_conns) + + return new_nodes, new_conns + + def choice_node_key(self, rand_key, nodes, input_idx, output_idx, + allow_input_keys: bool = False, allow_output_keys: bool = False): + """ + Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. + :param rand_key: + :param nodes: + :param input_idx: + :param output_idx: + :param allow_input_keys: + :param allow_output_keys: + :return: return its key and position(idx) + """ + + node_keys = nodes[:, 0] + mask = ~jnp.isnan(node_keys) + + if not allow_input_keys: + mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_idx)) + + if not allow_output_keys: + mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx)) + + idx = fetch_random(rand_key, mask) + key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan) + return key, idx + + def choice_connection_key(self, rand_key, conns): + """ + Randomly choose a connection key from the given connections. + :return: i_key, o_key, idx + """ + + idx = fetch_random(rand_key, ~jnp.isnan(conns[:, 0])) + i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan) + o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan) + + return i_key, o_key, idx diff --git a/algorithm/neat/ga/operation.py b/algorithm/neat/ga/operation.py deleted file mode 100644 index d2aff65..0000000 --- a/algorithm/neat/ga/operation.py +++ /dev/null @@ -1,40 +0,0 @@ -import jax -from jax import numpy as jnp, vmap - -from config import NeatConfig -from core import Genome, State, Gene -from .mutate import mutate -from .crossover import crossover - - -def create_next_generation(config: NeatConfig, gene: Gene, state: State, randkey, winner, loser, elite_mask): - # prepare random keys - pop_size = state.idx2species.shape[0] - new_node_keys = jnp.arange(pop_size) + state.next_node_key - - k1, k2 = jax.random.split(randkey, 2) - crossover_rand_keys = jax.random.split(k1, pop_size) - mutate_rand_keys = jax.random.split(k2, pop_size) - - # batch crossover - wpn, wpc = state.pop_genomes.nodes[winner], state.pop_genomes.conns[winner] - lpn, lpc = state.pop_genomes.nodes[loser], state.pop_genomes.conns[loser] - n_genomes = vmap(crossover)(crossover_rand_keys, Genome(wpn, wpc), Genome(lpn, lpc)) - - # batch mutation - mutate_func = vmap(mutate, in_axes=(None, None, None, 0, 0, 0)) - m_n_genomes = mutate_func(config, gene, state, mutate_rand_keys, n_genomes, new_node_keys) # mutate_new_pop_nodes - - # elitism don't mutate - pop_nodes = jnp.where(elite_mask[:, None, None], n_genomes.nodes, m_n_genomes.nodes) - pop_conns = jnp.where(elite_mask[:, None, None], n_genomes.conns, m_n_genomes.conns) - - # update next node key - all_nodes_keys = pop_nodes[:, :, 0] - max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) - next_node_key = max_node_key + 1 - - return state.update( - pop_genomes=Genome(pop_nodes, pop_conns), - next_node_key=next_node_key, - ) diff --git a/algorithm/neat/gene/__init__.py b/algorithm/neat/gene/__init__.py index 49a05b9..7e8e1ce 100644 --- a/algorithm/neat/gene/__init__.py +++ b/algorithm/neat/gene/__init__.py @@ -1,3 +1,3 @@ -from .normal import NormalGene, NormalGeneConfig -from .recurrent import RecurrentGene, RecurrentGeneConfig - +from .base import BaseGene +from .conn import * +from .node import * diff --git a/algorithm/neat/gene/base.py b/algorithm/neat/gene/base.py new file mode 100644 index 0000000..1430171 --- /dev/null +++ b/algorithm/neat/gene/base.py @@ -0,0 +1,23 @@ +class BaseGene: + "Base class for node genes or connection genes." + fixed_attrs = [] + custom_attrs = [] + + def __init__(self): + pass + + def new_custom_attrs(self): + raise NotImplementedError + + def mutate(self, randkey, gene): + raise NotImplementedError + + def distance(self, gene1, gene2): + raise NotImplementedError + + def forward(self, attrs, inputs): + raise NotImplementedError + + @property + def length(self): + return len(self.fixed_attrs) + len(self.custom_attrs) \ No newline at end of file diff --git a/algorithm/neat/gene/conn/__init__.py b/algorithm/neat/gene/conn/__init__.py new file mode 100644 index 0000000..8553a07 --- /dev/null +++ b/algorithm/neat/gene/conn/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseConnGene +from .default import DefaultConnGene diff --git a/algorithm/neat/gene/conn/base.py b/algorithm/neat/gene/conn/base.py new file mode 100644 index 0000000..b7b3bdc --- /dev/null +++ b/algorithm/neat/gene/conn/base.py @@ -0,0 +1,12 @@ +from .. import BaseGene + + +class BaseConnGene(BaseGene): + "Base class for connection genes." + fixed_attrs = ['input_index', 'output_index', 'enabled'] + + def __init__(self): + super().__init__() + + def forward(self, attrs, inputs): + raise NotImplementedError diff --git a/algorithm/neat/gene/conn/default.py b/algorithm/neat/gene/conn/default.py new file mode 100644 index 0000000..f5915c6 --- /dev/null +++ b/algorithm/neat/gene/conn/default.py @@ -0,0 +1,51 @@ +import jax.numpy as jnp + +from utils import mutate_float +from . import BaseConnGene + + +class DefaultConnGene(BaseConnGene): + "Default connection gene, with the same behavior as in NEAT-python." + + fixed_attrs = ['input_index', 'output_index', 'enabled'] + attrs = ['weight'] + + def __init__( + self, + weight_init_mean: float = 0.0, + weight_init_std: float = 1.0, + weight_mutate_power: float = 0.5, + weight_mutate_rate: float = 0.8, + weight_replace_rate: float = 0.1, + ): + super().__init__() + self.weight_init_mean = weight_init_mean + self.weight_init_std = weight_init_std + self.weight_mutate_power = weight_mutate_power + self.weight_mutate_rate = weight_mutate_rate + self.weight_replace_rate = weight_replace_rate + + def new_custom_attrs(self): + return jnp.array([self.weight_init_mean]) + + def mutate(self, key, conn): + input_index = conn[0] + output_index = conn[1] + enabled = conn[2] + weight = mutate_float(key, + conn[3], + self.weight_init_mean, + self.weight_init_std, + self.weight_mutate_power, + self.weight_mutate_rate, + self.weight_replace_rate + ) + + return jnp.array([input_index, output_index, enabled, weight]) + + def distance(self, attrs1, attrs2): + return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight + + def forward(self, attrs, inputs): + weight = attrs[0] + return inputs * weight diff --git a/algorithm/neat/gene/node/__init__.py b/algorithm/neat/gene/node/__init__.py new file mode 100644 index 0000000..b88d714 --- /dev/null +++ b/algorithm/neat/gene/node/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseNodeGene +from .default import DefaultNodeGene diff --git a/algorithm/neat/gene/node/base.py b/algorithm/neat/gene/node/base.py new file mode 100644 index 0000000..465050a --- /dev/null +++ b/algorithm/neat/gene/node/base.py @@ -0,0 +1,12 @@ +from .. import BaseGene + + +class BaseNodeGene(BaseGene): + "Base class for node genes." + fixed_attrs = ["index"] + + def __init__(self): + super().__init__() + + def forward(self, attrs, inputs): + raise NotImplementedError diff --git a/algorithm/neat/gene/node/default.py b/algorithm/neat/gene/node/default.py new file mode 100644 index 0000000..6580072 --- /dev/null +++ b/algorithm/neat/gene/node/default.py @@ -0,0 +1,96 @@ +from typing import Tuple + +import jax, jax.numpy as jnp + +from utils import Act, Agg, act, agg, mutate_int, mutate_float +from . import BaseNodeGene + + +class DefaultNodeGene(BaseNodeGene): + "Default node gene, with the same behavior as in NEAT-python." + + fixed_attrs = ['index'] + custom_attrs = ['bias', 'response', 'aggregation', 'activation'] + + def __init__( + self, + bias_init_mean: float = 0.0, + bias_init_std: float = 1.0, + bias_mutate_power: float = 0.5, + bias_mutate_rate: float = 0.7, + bias_replace_rate: float = 0.1, + + response_init_mean: float = 1.0, + response_init_std: float = 0.0, + response_mutate_power: float = 0.5, + response_mutate_rate: float = 0.7, + response_replace_rate: float = 0.1, + + activation_default: callable = Act.sigmoid, + activation_options: Tuple = (Act.sigmoid,), + activation_replace_rate: float = 0.1, + + aggregation_default: callable = Agg.sum, + aggregation_options: Tuple = (Agg.sum,), + aggregation_replace_rate: float = 0.1, + ): + super().__init__() + self.bias_init_mean = bias_init_mean + self.bias_init_std = bias_init_std + self.bias_mutate_power = bias_mutate_power + self.bias_mutate_rate = bias_mutate_rate + self.bias_replace_rate = bias_replace_rate + + self.response_init_mean = response_init_mean + self.response_init_std = response_init_std + self.response_mutate_power = response_mutate_power + self.response_mutate_rate = response_mutate_rate + self.response_replace_rate = response_replace_rate + + self.activation_default = activation_options.index(activation_default) + self.activation_options = activation_options + self.activation_indices = jnp.arange(len(activation_options)) + self.activation_replace_rate = activation_replace_rate + + self.aggregation_default = aggregation_options.index(aggregation_default) + self.aggregation_options = aggregation_options + self.aggregation_indices = jnp.arange(len(aggregation_options)) + self.aggregation_replace_rate = aggregation_replace_rate + + def new_custom_attrs(self): + return jnp.array( + [self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default] + ) + + def mutate(self, key, node): + k1, k2, k3, k4 = jax.random.split(key, num=4) + index = node[0] + + bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std, + self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate) + + res = mutate_float(k2, node[2], self.response_init_mean, self.response_init_std, + self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate) + + act = mutate_int(k3, node[3], self.activation_indices, self.activation_replace_rate) + + agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate) + + return jnp.array([index, bias, res, act, agg]) + + def distance(self, node1, node2): + return ( + jnp.abs(node1[1] - node2[1]) + + jnp.abs(node1[2] - node2[2]) + + node1[3] != node2[3] + + node1[4] != node2[4] + ) + + def forward(self, attrs, inputs): + bias, res, act_idx, agg_idx = attrs + + z = agg(agg_idx, inputs, self.aggregation_options) + z = bias + res * z + z = act(act_idx, z, self.activation_options) + + return z diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py deleted file mode 100644 index fcff108..0000000 --- a/algorithm/neat/gene/normal.py +++ /dev/null @@ -1,210 +0,0 @@ -from dataclasses import dataclass -from typing import Tuple - -import jax -from jax import Array, numpy as jnp - -from config import GeneConfig -from core import Gene, Genome, State -from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg - - -@dataclass(frozen=True) -class NormalGeneConfig(GeneConfig): - bias_init_mean: float = 0.0 - bias_init_std: float = 1.0 - bias_mutate_power: float = 0.5 - bias_mutate_rate: float = 0.7 - bias_replace_rate: float = 0.1 - - response_init_mean: float = 1.0 - response_init_std: float = 0.0 - response_mutate_power: float = 0.5 - response_mutate_rate: float = 0.7 - response_replace_rate: float = 0.1 - - activation_default: callable = Act.sigmoid - activation_options: Tuple = (Act.sigmoid, ) - activation_replace_rate: float = 0.1 - - aggregation_default: callable = Agg.sum - aggregation_options: Tuple = (Agg.sum, ) - aggregation_replace_rate: float = 0.1 - - weight_init_mean: float = 0.0 - weight_init_std: float = 1.0 - weight_mutate_power: float = 0.5 - weight_mutate_rate: float = 0.8 - weight_replace_rate: float = 0.1 - - def __post_init__(self): - assert self.bias_init_std >= 0.0 - assert self.bias_mutate_power >= 0.0 - assert self.bias_mutate_rate >= 0.0 - assert self.bias_replace_rate >= 0.0 - - assert self.response_init_std >= 0.0 - assert self.response_mutate_power >= 0.0 - assert self.response_mutate_rate >= 0.0 - assert self.response_replace_rate >= 0.0 - - assert self.activation_default == self.activation_options[0] - assert self.aggregation_default == self.aggregation_options[0] - - -class NormalGene(Gene): - node_attrs = ['bias', 'response', 'aggregation', 'activation'] - conn_attrs = ['weight'] - - def __init__(self, config: NormalGeneConfig = NormalGeneConfig()): - self.config = config - - def setup(self, state: State = State()): - return state.update( - bias_init_mean=self.config.bias_init_mean, - bias_init_std=self.config.bias_init_std, - bias_mutate_power=self.config.bias_mutate_power, - bias_mutate_rate=self.config.bias_mutate_rate, - bias_replace_rate=self.config.bias_replace_rate, - - response_init_mean=self.config.response_init_mean, - response_init_std=self.config.response_init_std, - response_mutate_power=self.config.response_mutate_power, - response_mutate_rate=self.config.response_mutate_rate, - response_replace_rate=self.config.response_replace_rate, - - activation_replace_rate=self.config.activation_replace_rate, - activation_default=0, - activation_options=jnp.arange(len(self.config.activation_options)), - - aggregation_replace_rate=self.config.aggregation_replace_rate, - aggregation_default=0, - aggregation_options=jnp.arange(len(self.config.aggregation_options)), - - weight_init_mean=self.config.weight_init_mean, - weight_init_std=self.config.weight_init_std, - weight_mutate_power=self.config.weight_mutate_power, - weight_mutate_rate=self.config.weight_mutate_rate, - weight_replace_rate=self.config.weight_replace_rate, - ) - - def update(self, state): - return state - - def new_node_attrs(self, state): - return jnp.array([state.bias_init_mean, state.response_init_mean, - state.activation_default, state.aggregation_default]) - - def new_conn_attrs(self, state): - return jnp.array([state.weight_init_mean]) - - def mutate_node(self, state, key, attrs: Array): - k1, k2, k3, k4 = jax.random.split(key, num=4) - - bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std, - state.bias_mutate_power, state.bias_mutate_rate, state.bias_replace_rate) - res = NormalGene._mutate_float(k2, attrs[1], state.response_init_mean, state.response_init_std, - state.response_mutate_power, state.response_mutate_rate, - state.response_replace_rate) - act = NormalGene._mutate_int(k3, attrs[2], state.activation_options, state.activation_replace_rate) - agg = NormalGene._mutate_int(k4, attrs[3], state.aggregation_options, state.aggregation_replace_rate) - - return jnp.array([bias, res, act, agg]) - - def mutate_conn(self, state, key, attrs: Array): - weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std, - state.weight_mutate_power, state.weight_mutate_rate, - state.weight_replace_rate) - - return jnp.array([weight]) - - def distance_node(self, state, node1: Array, node2: Array): - # bias + response + activation + aggregation - return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \ - (node1[3] != node2[3]) + (node1[4] != node2[4]) - - def distance_conn(self, state, con1: Array, con2: Array): - return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight - - def forward_transform(self, state: State, genome: Genome): - u_conns = unflatten_conns(genome.nodes, genome.conns) - conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) - - # remove enable attr - u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) - seqs = topological_sort(genome.nodes, conn_enable) - - return seqs, genome.nodes, u_conns - - def forward(self, state: State, inputs, transformed): - cal_seqs, nodes, cons = transformed - - input_idx = state.input_idx - output_idx = state.output_idx - - N = nodes.shape[0] - ini_vals = jnp.full((N,), jnp.nan) - ini_vals = ini_vals.at[input_idx].set(inputs) - - weights = cons[0, :] - - def cond_fun(carry): - values, idx = carry - return (idx < N) & (cal_seqs[idx] != I_INT) - - def body_func(carry): - values, idx = carry - i = cal_seqs[idx] - - def hit(): - ins = values * weights[:, i] - z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins) - z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias - z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z) - - new_values = values.at[i].set(z) - return new_values - - def miss(): - return values - - # the val of input nodes is obtained by the task, not by calculation - values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit) - - return values, idx + 1 - - vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) - - return vals[output_idx] - - @staticmethod - def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate): - k1, k2, k3 = jax.random.split(key, num=3) - noise = jax.random.normal(k1, ()) * mutate_power - replace = jax.random.normal(k2, ()) * init_std + init_mean - r = jax.random.uniform(k3, ()) - - val = jnp.where( - r < mutate_rate, - val + noise, - jnp.where( - (mutate_rate < r) & (r < mutate_rate + replace_rate), - replace, - val - ) - ) - - return val - - @staticmethod - def _mutate_int(key, val, options, replace_rate): - k1, k2 = jax.random.split(key, num=2) - r = jax.random.uniform(k1, ()) - - val = jnp.where( - r < replace_rate, - jax.random.choice(k2, options), - val - ) - - return val diff --git a/algorithm/neat/gene/recurrent.py b/algorithm/neat/gene/recurrent.py deleted file mode 100644 index 632eea7..0000000 --- a/algorithm/neat/gene/recurrent.py +++ /dev/null @@ -1,57 +0,0 @@ -from dataclasses import dataclass - -import jax -from jax import numpy as jnp, vmap - -from .normal import NormalGene, NormalGeneConfig -from core import State, Genome -from utils import unflatten_conns, act, agg - - -@dataclass(frozen=True) -class RecurrentGeneConfig(NormalGeneConfig): - activate_times: int = 10 - - def __post_init__(self): - super().__post_init__() - assert self.activate_times > 0 - - -class RecurrentGene(NormalGene): - - def __init__(self, config: RecurrentGeneConfig = RecurrentGeneConfig()): - self.config = config - super().__init__(config) - - def forward_transform(self, state: State, genome: Genome): - u_conns = unflatten_conns(genome.nodes, genome.conns) - - # remove un-enable connections and remove enable attr - conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) - u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) - - return genome.nodes, u_conns - - def forward(self, state: State, inputs, transformed): - nodes, conns = transformed - - batch_act, batch_agg = vmap(act, in_axes=(0, 0, None)), vmap(agg, in_axes=(0, 0, None)) - - input_idx = state.input_idx - output_idx = state.output_idx - - N = nodes.shape[0] - vals = jnp.full((N,), 0.) - - weights = conns[0, :] - - def body_func(i, values): - values = values.at[input_idx].set(inputs) - nodes_ins = values * weights.T - values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins) - values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias - values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z) - return values - - vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals) - return vals[output_idx] diff --git a/algorithm/neat/genome/__init__.py b/algorithm/neat/genome/__init__.py new file mode 100644 index 0000000..0a636dd --- /dev/null +++ b/algorithm/neat/genome/__init__.py @@ -0,0 +1,3 @@ +from .base import BaseGenome +from .default import DefaultGenome +from .recurrent import RecurrentGenome diff --git a/algorithm/neat/genome/base.py b/algorithm/neat/genome/base.py new file mode 100644 index 0000000..4567193 --- /dev/null +++ b/algorithm/neat/genome/base.py @@ -0,0 +1,66 @@ +import jax.numpy as jnp +from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene +from utils import fetch_first + + +class BaseGenome: + + network_type = None + + def __init__( + self, + num_inputs: int, + num_outputs: int, + max_nodes: int, + max_conns: int, + node_gene: BaseNodeGene = DefaultNodeGene(), + conn_gene: BaseConnGene = DefaultConnGene(), + ): + self.num_inputs = num_inputs + self.num_outputs = num_outputs + self.input_idx = jnp.arange(num_inputs) + self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs) + self.max_nodes = max_nodes + self.max_conns = max_conns + self.node_gene = node_gene + self.conn_gene = conn_gene + + def transform(self, nodes, conns): + raise NotImplementedError + + def forward(self, inputs, transformed): + raise NotImplementedError + + def add_node(self, nodes, new_key: int, attrs): + """ + Add a new node to the genome. + The new node will place at the first NaN row. + """ + exist_keys = nodes[:, 0] + pos = fetch_first(jnp.isnan(exist_keys)) + new_nodes = nodes.at[pos, 0].set(new_key) + return new_nodes.at[pos, 1:].set(attrs) + + def delete_node_by_pos(self, nodes, pos): + """ + Delete a node from the genome. + Delete the node by its pos in nodes. + """ + return nodes.at[pos].set(jnp.nan) + + def add_conn(self, conns, i_key, o_key, enable: bool, attrs): + """ + Add a new connection to the genome. + The new connection will place at the first NaN row. + """ + con_keys = conns[:, 0] + pos = fetch_first(jnp.isnan(con_keys)) + new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable])) + return new_conns.at[pos, 3:].set(attrs) + + def delete_conn_by_pos(self, conns, pos): + """ + Delete a connection from the genome. + Delete the connection by its idx. + """ + return conns.at[pos].set(jnp.nan) diff --git a/algorithm/neat/genome/default.py b/algorithm/neat/genome/default.py new file mode 100644 index 0000000..786d8df --- /dev/null +++ b/algorithm/neat/genome/default.py @@ -0,0 +1,75 @@ +import jax, jax.numpy as jnp +from utils import unflatten_conns, topological_sort, I_INT + +from . import BaseGenome +from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene + + +class DefaultGenome(BaseGenome): + """Default genome class, with the same behavior as the NEAT-Python""" + + network_type = 'feedforward' + + def __init__(self, + num_inputs: int, + num_outputs: int, + node_gene: BaseNodeGene = DefaultNodeGene(), + conn_gene: BaseConnGene = DefaultConnGene(), + ): + super().__init__(num_inputs, num_outputs, node_gene, conn_gene) + + def transform(self, nodes, conns): + u_conns = unflatten_conns(nodes, conns) + + # DONE: Seems like there is a bug in this line + # conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) + # modified: exist conn and enable is true + # conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False) + # advanced modified: when and only when enabled is True + conn_enable = u_conns[0] == 1 + + # remove enable attr + u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) + seqs = topological_sort(nodes, conn_enable) + + return seqs, nodes, u_conns + + def forward(self, inputs, transformed): + cal_seqs, nodes, conns = transformed + + N = nodes.shape[0] + ini_vals = jnp.full((N,), jnp.nan) + ini_vals = ini_vals.at[self.input_idx].set(inputs) + nodes_attrs = nodes[:, 1:] + + def cond_fun(carry): + values, idx = carry + return (idx < N) & (cal_seqs[idx] != I_INT) + + def body_func(carry): + values, idx = carry + i = cal_seqs[idx] + + def hit(): + ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values) + # ins = values * weights[:, i] + + z = self.node_gene.forward(nodes_attrs[i], ins) + # z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins) + # z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias + # z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z) + + new_values = values.at[i].set(z) + return new_values + + def miss(): + return values + + # the val of input nodes is obtained by the task, not by calculation + values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit) + + return values, idx + 1 + + vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) + + return vals[self.output_idx] diff --git a/algorithm/neat/genome/recurrent.py b/algorithm/neat/genome/recurrent.py new file mode 100644 index 0000000..b1cef54 --- /dev/null +++ b/algorithm/neat/genome/recurrent.py @@ -0,0 +1,58 @@ +import jax, jax.numpy as jnp +from utils import unflatten_conns + +from . import BaseGenome +from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene + + +class RecurrentGenome(BaseGenome): + """Default genome class, with the same behavior as the NEAT-Python""" + + network_type = 'recurrent' + + def __init__(self, + num_inputs: int, + num_outputs: int, + node_gene: BaseNodeGene = DefaultNodeGene(), + conn_gene: BaseConnGene = DefaultConnGene(), + activate_time: int = 10, + ): + super().__init__(num_inputs, num_outputs, node_gene, conn_gene) + self.activate_time = activate_time + + def transform(self, nodes, conns): + u_conns = unflatten_conns(nodes, conns) + + # remove un-enable connections and remove enable attr + conn_enable = u_conns[0] == 1 + u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) + + return nodes, u_conns + + def forward(self, inputs, transformed): + nodes, conns = transformed + + N = nodes.shape[0] + vals = jnp.full((N,), jnp.nan) + nodes_attrs = nodes[:, 1:] + + def body_func(_, values): + # set input values + values = values.at[self.input_idx].set(inputs) + + # calculate connections + node_ins = jax.vmap( + jax.vmap( + self.conn_gene.forward, + in_axes=(1, None) + ), + in_axes=(1, 0) + )(conns, values) + + # calculate nodes + values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T) + return values + + vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals) + + return vals[self.output_idx] diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py index 330d819..d7c97f8 100644 --- a/algorithm/neat/neat.py +++ b/algorithm/neat/neat.py @@ -1,87 +1,38 @@ -from typing import Type +import jax, jax.numpy as jnp +from utils import State +from .. import BaseAlgorithm +from .genome import * +from .species import * +from .ga import * -import jax -from jax import numpy as jnp -import numpy as np +class NEAT(BaseAlgorithm): -from config import Config -from core import Algorithm, State, Gene, Genome -from .ga import create_next_generation -from .species import SpeciesInfo, update_species, speciate + def __init__( + self, + genome: BaseGenome, + species: BaseSpecies, + mutation: BaseMutation = DefaultMutation(), + crossover: BaseCrossover = DefaultCrossover(), + ): + self.genome = genome + self.species = species + self.mutation = mutation + self.crossover = crossover - -class NEAT(Algorithm): - - def __init__(self, config: Config, gene_type: Type[Gene]): - self.config = config - self.gene = gene_type(config.gene) - - self.forward_func = None - self.tell_func = None - - def setup(self, randkey, state: State = State()): - """initialize the state of the algorithm""" - - input_idx = np.arange(self.config.neat.inputs) - output_idx = np.arange(self.config.neat.inputs, - self.config.neat.inputs + self.config.neat.outputs) - - state = state.update( - P=self.config.basic.pop_size, - N=self.config.neat.max_nodes, - C=self.config.neat.max_conns, - S=self.config.neat.max_species, - NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes - CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes - max_stagnation=self.config.neat.max_stagnation, - species_elitism=self.config.neat.species_elitism, - spawn_number_change_rate=self.config.neat.spawn_number_change_rate, - genome_elitism=self.config.neat.genome_elitism, - survival_threshold=self.config.neat.survival_threshold, - compatibility_threshold=self.config.neat.compatibility_threshold, - compatibility_disjoint=self.config.neat.compatibility_disjoint, - compatibility_weight=self.config.neat.compatibility_weight, - - input_idx=input_idx, - output_idx=output_idx, + def setup(self, randkey): + k1, k2 = jax.random.split(randkey, 2) + return State( + randkey=k1, + generation=0, + next_node_key=max(*self.genome.input_idx, *self.genome.output_idx) + 2, + # inputs nodes, output nodes, 1 hidden node + species=self.species.setup(k2), ) - state = self.gene.setup(state) - pop_genomes = self._initialize_genomes(state) - - species_info = SpeciesInfo.initialize(state) - idx2species = jnp.zeros(state.P, dtype=jnp.float32) - - center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32) - center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32) - center_genomes = Genome(center_nodes, center_conns) - center_genomes = center_genomes.set(0, pop_genomes[0]) - - generation = 0 - next_node_key = max(*state.input_idx, *state.output_idx) + 2 - next_species_key = 1 - - state = state.update( - randkey=randkey, - pop_genomes=pop_genomes, - species_info=species_info, - idx2species=idx2species, - center_genomes=center_genomes, - - # avoid jax auto cast from int to float. that would cause re-compilation. - generation=jnp.asarray(generation, dtype=jnp.int32), - next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32), - next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32), - ) - - return jax.device_put(state) - - def ask_algorithm(self, state: State): - return state.pop_genomes - - def tell_algorithm(self, state: State, fitness): - state = self.gene.update(state) + def ask(self, state: State): + return self.species.ask(state) + def tell(self, state: State, fitness): k1, k2, randkey = jax.random.split(state.randkey, 3) state = state.update( @@ -89,46 +40,55 @@ class NEAT(Algorithm): randkey=randkey ) - state, winner, loser, elite_mask = update_species(state, k1, fitness) + state, winner, loser, elite_mask = self.species.update_species(state, fitness, state.generation) - state = create_next_generation(self.config.neat, self.gene, state, k2, winner, loser, elite_mask) + state = self.create_next_generation(k2, state, winner, loser, elite_mask) - state = speciate(self.gene, state) + state = self.species.speciate(state, state.generation) return state - def forward_transform(self, state: State, genome: Genome): - return self.gene.forward_transform(state, genome) + def transform(self, state: State): + """transform the genome into a neural network""" + raise NotImplementedError - def forward(self, state: State, inputs, genome: Genome): - return self.gene.forward(state, inputs, genome) + def forward(self, inputs, transformed): + raise NotImplementedError - def _initialize_genomes(self, state): - o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes - o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections + def create_next_generation(self, randkey, state, winner, loser, elite_mask): + # prepare random keys + pop_size = self.species.pop_size + new_node_keys = jnp.arange(pop_size) + state.species.next_node_key - input_idx = state.input_idx - output_idx = state.output_idx - new_node_key = max([*input_idx, *output_idx]) + 1 + k1, k2 = jax.random.split(randkey, 2) + crossover_rand_keys = jax.random.split(k1, pop_size) + mutate_rand_keys = jax.random.split(k2, pop_size) - o_nodes[input_idx, 0] = input_idx - o_nodes[output_idx, 0] = output_idx - o_nodes[new_node_key, 0] = new_node_key - o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene.new_node_attrs(state) - o_nodes[new_node_key, 1:] = self.gene.new_node_attrs(state) + wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner] + lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser] - input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] - o_conns[input_idx, 0:2] = input_conns # in key, out key - o_conns[input_idx, 2] = True # enabled - o_conns[input_idx, 3:] = self.gene.new_conn_attrs(state) + # batch crossover + n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0)) + (crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc)) - output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] - o_conns[output_idx, 0:2] = output_conns # in key, out key - o_conns[output_idx, 2] = True # enabled - o_conns[output_idx, 3:] = self.gene.new_conn_attrs(state) + # batch mutation + m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0)) + (mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys)) - # repeat origin genome for P times to create population - pop_nodes = np.tile(o_nodes, (state.P, 1, 1)) - pop_conns = np.tile(o_conns, (state.P, 1, 1)) + # elitism don't mutate + pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes) + pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns) + + # update next node key + all_nodes_keys = pop_nodes[:, :, 0] + max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) + next_node_key = max_node_key + 1 + + return state.update( + species=state.species.update( + pop_nodes=pop_nodes, + pop_conns=pop_conns, + ), + next_node_key=next_node_key, + ) - return Genome(pop_nodes, pop_conns) diff --git a/algorithm/neat/species/__init__.py b/algorithm/neat/species/__init__.py index eda012f..f52a178 100644 --- a/algorithm/neat/species/__init__.py +++ b/algorithm/neat/species/__init__.py @@ -1,2 +1,2 @@ -from .species_info import SpeciesInfo -from .operations import update_species, speciate +from .base import BaseSpecies +from .default import DefaultSpecies diff --git a/algorithm/neat/species/base.py b/algorithm/neat/species/base.py new file mode 100644 index 0000000..f1294f2 --- /dev/null +++ b/algorithm/neat/species/base.py @@ -0,0 +1,14 @@ +from utils import State + +class BaseSpecies: + def setup(self, randkey): + raise NotImplementedError + + def ask(self, state: State): + raise NotImplementedError + + def update_species(self, state, fitness, generation): + raise NotImplementedError + + def speciate(self, state, generation): + raise NotImplementedError \ No newline at end of file diff --git a/algorithm/neat/species/default.py b/algorithm/neat/species/default.py new file mode 100644 index 0000000..c04e653 --- /dev/null +++ b/algorithm/neat/species/default.py @@ -0,0 +1,514 @@ +import numpy as np +import jax, jax.numpy as jnp +from utils import State, rank_elements, argmin_with_mask, fetch_first +from ..genome import BaseGenome + + +class DefaultSpecies: + + def __init__(self, + genome: BaseGenome, + pop_size, + species_size, + compatibility_disjoint: float = 1.0, + compatibility_weight: float = 0.4, + max_stagnation: int = 15, + species_elitism: int = 2, + spawn_number_change_rate: float = 0.5, + genome_elitism: int = 2, + survival_threshold: float = 0.2, + min_species_size: int = 1, + compatibility_threshold: float = 3.5 + ): + + self.genome = genome + self.pop_size = pop_size + self.species_size = species_size + + self.compatibility_disjoint = compatibility_disjoint + self.compatibility_weight = compatibility_weight + self.max_stagnation = max_stagnation + self.species_elitism = species_elitism + self.spawn_number_change_rate = spawn_number_change_rate + self.genome_elitism = genome_elitism + self.survival_threshold = survival_threshold + self.min_species_size = min_species_size + self.compatibility_threshold = compatibility_threshold + + self.species_arange = jnp.arange(self.species_size) + + def setup(self, randkey): + pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome) + + species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species + best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species + last_improved = jnp.full((self.species_size,), jnp.nan) # the last generation that the species improved + member_count = jnp.full((self.species_size,), jnp.nan) # the number of members of each species + idx2species = jnp.zeros(self.pop_size) # the species index of each individual + + # nodes for each center genome of each species + center_nodes = jnp.full((self.species_size, self.genome.max_nodes, self.genome.node_gene.length), jnp.nan) + + # connections for each center genome of each species + center_conns = jnp.full((self.species_size, self.genome.max_conns, self.genome.conn_gene.length), jnp.nan) + + species_keys = species_keys.at[0].set(0) + best_fitness = best_fitness.at[0].set(-jnp.inf) + last_improved = last_improved.at[0].set(0) + member_count = member_count.at[0].set(self.pop_size) + center_nodes = center_nodes.at[0].set(pop_nodes[0]) + center_conns = center_conns.at[0].set(pop_conns[0]) + + return State( + randkey=randkey, + species_keys=species_keys, + best_fitness=best_fitness, + last_improved=last_improved, + member_count=member_count, + idx2species=idx2species, + center_nodes=center_nodes, + center_conns=center_conns, + next_species_key=1, # 0 is reserved for the first species + ) + + def ask(self, state): + return state.pop_nodes, state.pop_conns + + def update_species(self, state, fitness, generation): + # update the fitness of each species + species_fitness = self.update_species_fitness(state, fitness) + + # stagnation species + state, species_fitness = self.stagnation(state, generation, species_fitness) + + # sort species_info by their fitness. (also push nan to the end) + sort_indices = jnp.argsort(species_fitness)[::-1] + state = state.update( + species_keys=state.species_keys[sort_indices], + best_fitness=state.best_fitness[sort_indices], + last_improved=state.last_improved[sort_indices], + member_count=state.member_count[sort_indices], + center_nodes=state.center_nodes[sort_indices], + center_conns=state.center_conns[sort_indices], + ) + + # decide the number of members of each species by their fitness + spawn_number = self.cal_spawn_numbers(state) + + k1, k2 = jax.random.split(state.randkey) + # crossover info + winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness) + + return state(randkey=k2), winner, loser, elite_mask + + def update_species_fitness(self, state, fitness): + """ + obtain the fitness of the species by the fitness of each individual. + use max criterion. + """ + + def aux_func(idx): + s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf) + val = jnp.max(s_fitness) + return val + + return jax.vmap(aux_func)(self.species_arange) + + def stagnation(self, state, generation, species_fitness): + """ + stagnation species. + those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. + elitism species never stagnation + + generation: the current generation + """ + + def check_stagnation(idx): + # determine whether the species stagnation + st = ( + (species_fitness[idx] <= state.best_fitness[ + idx]) & # not better than the best fitness of the species + (generation - state.last_improved[idx] > self.max_stagnation) # for a long time + ) + + # update last_improved and best_fitness + li, bf = jax.lax.cond( + species_fitness[idx] > state.best_fitness[idx], + lambda: (generation, species_fitness[idx]), # update + lambda: (state.last_improved[idx], state.best_fitness[idx]) # not update + ) + + return st, bf, li + + spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)(self.species_arange) + + # elite species will not be stagnation + species_rank = rank_elements(species_fitness) + spe_st = jnp.where(species_rank < self.species_elitism, False, spe_st) # elitism never stagnation + + # set stagnation species to nan + def update_func(idx): + return jax.lax.cond( + spe_st[idx], + lambda: ( + jnp.nan, # species_key + jnp.nan, # best_fitness + jnp.nan, # last_improved + jnp.nan, # member_count + -jnp.inf, # species_fitness + jnp.full_like(center_nodes[idx], jnp.nan), # center_nodes + jnp.full_like(center_conns[idx], jnp.nan), # center_conns + ), # stagnation species + lambda: ( + species_keys[idx], + best_fitness[idx], + last_improved[idx], + state.member_count[idx], + species_fitness[idx], + center_nodes[idx], + center_conns[idx] + ) # not stagnation species + ) + + ( + species_keys, + best_fitness, + last_improved, + member_count, + species_fitness, + center_nodes, + center_conns + ) = ( + jax.vmap(update_func)(self.species_arange)) + + return state.update( + species_keys=species_keys, + best_fitness=best_fitness, + last_improved=last_improved, + member_count=member_count, + center_nodes=center_nodes, + center_conns=center_conns, + ), species_fitness + + def cal_spawn_numbers(self, state): + """ + decide the number of members of each species by their fitness rank. + the species with higher fitness will have more members + Linear ranking selection + e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] + """ + + species_keys = state.species_keys + + is_species_valid = ~jnp.isnan(species_keys) + valid_species_num = jnp.sum(is_species_valid) + denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 + + rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1] + spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] + spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 + + target_spawn_number = jnp.floor(spawn_number_rate * self.pop_size) # calculate member + + # Avoid too much variation of numbers for a species + previous_size = state.member_count + spawn_number = previous_size + (target_spawn_number - previous_size) * self.spawn_number_change_rate + spawn_number = spawn_number.astype(jnp.int32) + + # must control the sum of spawn_number to be equal to pop_size + error = state.P - jnp.sum(spawn_number) + + # add error to the first species to control the sum of spawn_number + spawn_number = spawn_number.at[0].add(error) + + return spawn_number + + def create_crossover_pair(self, state, randkey, spawn_number, fitness): + s_idx = self.species_arange + p_idx = jnp.arange(self.pop_size) + + def aux_func(key, idx): + members = state.idx2species == state.species_keys[idx] + members_num = jnp.sum(members) + + members_fitness = jnp.where(members, fitness, -jnp.inf) + sorted_member_indices = jnp.argsort(members_fitness)[::-1] + + survive_size = jnp.floor(self.survival_threshold * members_num).astype(jnp.int32) + + select_pro = (p_idx < survive_size) / survive_size + fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, self.pop_size), replace=True, p=select_pro) + + # elite + fa = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, fa) + ma = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, ma) + elite = jnp.where(p_idx < self.genome_elitism, True, False) + return fa, ma, elite + + fas, mas, elites = jax.vmap(aux_func)(jax.random.split(randkey, self.species_size), s_idx) + + spawn_number_cum = jnp.cumsum(spawn_number) + + def aux_func(idx): + loc = jnp.argmax(idx < spawn_number_cum) + + # elite genomes are at the beginning of the species + idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx) + return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species] + + part1, part2, elite_mask = jax.vmap(aux_func)(p_idx) + + is_part1_win = fitness[part1] >= fitness[part2] + winner = jnp.where(is_part1_win, part1, part2) + loser = jnp.where(is_part1_win, part2, part1) + + return winner, loser, elite_mask + + def speciate(self, state, generation): + # prepare distance functions + o2p_distance_func = jax.vmap(self.distance, in_axes=(None, None, 0, 0)) # one to population + + # idx to specie key + idx2species = jnp.full((self.pop_size,), jnp.nan) # NaN means not assigned to any species + + # the distance between genomes to its center genomes + o2c_distances = jnp.full((self.pop_size,), jnp.inf) + + # step 1: find new centers + def cond_func(carry): + # i, idx2species, center_nodes, center_conns, o2c_distances + i, i2s, cns, ccs, o2c = carry + + return ( + (i < self.species_size) & + (~jnp.isnan(state.species_keys[i])) + ) # current species is existing + + def body_func(carry): + i, i2s, cns, ccs, o2c = carry + + distances = o2p_distance_func(cns, ccs, state.pop_nodes, state.pop_conns) + + # find the closest one + closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) + + i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i]) + cns = cns.set(i, state.pop_nodes[closest_idx]) + ccs = ccs.set(i, state.pop_conns[closest_idx]) + + # the genome with closest_idx will become the new center, thus its distance to center is 0. + o2c = o2c.at[closest_idx].set(0) + + return i + 1, i2s, cns, ccs, o2c + + _, idx2species, center_nodes, center_conns, o2c_distances = \ + jax.lax.while_loop(cond_func, body_func, + (0, idx2species, state.center_nodes, state.center_conns, o2c_distances)) + + state = state.update( + idx2species=idx2species, + center_nodes=center_nodes, + center_conns=center_conns, + ) + + # part 2: assign members to each species + def cond_func(carry): + # i, idx2species, center_nodes, center_conns, species_keys, o2c_distances, next_species_key + i, i2s, cns, ccs, sk, o2c, nsk = carry + + current_species_existed = ~jnp.isnan(sk[i]) + not_all_assigned = jnp.any(jnp.isnan(i2s)) + not_reach_species_upper_bounds = i < self.species_size + return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned) + + def body_func(carry): + i, i2s, cns, ccs, sk, o2c, nsk = carry + + _, i2s, cns, ccs, sk, o2c, nsk = jax.lax.cond( + jnp.isnan(sk[i]), # whether the current species is existing or not + create_new_species, # if not existing, create a new specie + update_exist_specie, # if existing, update the specie + (i, i2s, cns, ccs, sk, o2c, nsk) + ) + + return i + 1, i2s, cns, ccs, sk, o2c, nsk + + def create_new_species(carry): + i, i2s, cns, ccs, sk, o2c, nsk = carry + + # pick the first one who has not been assigned to any species + idx = fetch_first(jnp.isnan(i2s)) + + # assign it to the new species + # [key, best score, last update generation, member_count] + sk = sk.at[i].set(nsk) # nsk -> next species key + i2s = i2s.at[idx].set(nsk) + o2c = o2c.at[idx].set(0) + + # update center genomes + cns = cns.set(i, state.pop_nodes[idx]) + ccs = ccs.set(i, state.pop_conns[idx]) + + # find the members for the new species + i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c) + + return i, i2s, cns, ccs, sk, o2c, nsk + 1 # change to next new speciate key + + def update_exist_specie(carry): + i, i2s, cns, ccs, sk, o2c, nsk = carry + + i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c) + + # turn to next species + return i + 1, i2s, cns, ccs, sk, o2c, nsk + + def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c): + # distance between such center genome and ppo genomes + o2p_distance = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns) + + close_enough_mask = o2p_distance < self.compatibility_threshold + # when a genome is not assigned or the distance between its current center is bigger than this center + catchable_mask = jnp.isnan(i2s) | (o2p_distance < o2c) + + mask = close_enough_mask & catchable_mask + + # update species info + i2s = jnp.where(mask, sk[i], i2s) + + # update distance between centers + o2c = jnp.where(mask, o2p_distance, o2c) + + return i2s, o2c + + # update idx2species + _, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop( + cond_func, + body_func, + (0, state.idx2species, state.center_nodes, center_conns, state.species_info.species_keys, o2c_distances, + state.next_species_key) + ) + + # if there are still some pop genomes not assigned to any species, add them to the last genome + # this condition can only happen when the number of species is reached species upper bounds + idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) + + # complete info of species which is created in this generation + new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness) + best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness) + last_improved = jnp.where(new_created_mask, generation, state.last_improved) + + # update members count + def count_members(idx): + return jax.lax.cond( + jnp.isnan(species_keys[idx]), # if the species is not existing + lambda _: jnp.nan, # nan + lambda _: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members + ) + + member_count = jax.vmap(count_members)(self.species_arange) + + return state.update( + species_keys=species_keys, + best_fitness=best_fitness, + last_improved=last_improved, + member_count=member_count, + idx2species=idx2species, + center_nodes=center_nodes, + center_conns=center_conns, + next_species_key=next_species_key + ) + + def distance(self, nodes1, conns1, nodes2, conns2): + """ + The distance between two genomes + """ + return self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2) + + def node_distance(self, nodes1, nodes2): + """ + The distance of the nodes part for two genomes + """ + node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) + node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) + max_cnt = jnp.maximum(node_cnt1, node_cnt2) + + # align homologous nodes + # this process is similar to np.intersect1d. + nodes = jnp.concatenate((nodes1, nodes2), axis=0) + keys = nodes[:, 0] + sorted_indices = jnp.argsort(keys, axis=0) + nodes = nodes[sorted_indices] + nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end + fr, sr = nodes[:-1], nodes[1:] # first row, second row + + # flag location of homologous nodes + intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) + + # calculate the count of non_homologous of two genomes + non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) + + # calculate the distance of homologous nodes + hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(0, 0))(fr, sr) + hnd = jnp.where(jnp.isnan(hnd), 0, hnd) + homologous_distance = jnp.sum(hnd * intersect_mask) + + val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight + + return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division + + def conn_distance(self, conns1, conns2): + """ + The distance of the conns part for two genomes + """ + con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0])) + con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0])) + max_cnt = jnp.maximum(con_cnt1, con_cnt2) + + cons = jnp.concatenate((conns1, conns2), axis=0) + keys = cons[:, :2] + sorted_indices = jnp.lexsort(keys.T[::-1]) + cons = cons[sorted_indices] + cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end + fr, sr = cons[:-1], cons[1:] # first row, second row + + # both genome has such connection + intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) + + non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) + hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(0, 0))(fr, sr) + hcd = jnp.where(jnp.isnan(hcd), 0, hcd) + homologous_distance = jnp.sum(hcd * intersect_mask) + + val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight + + return jnp.where(max_cnt == 0, 0, val / max_cnt) + + +def initialize_population(pop_size, genome): + o_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan) # original nodes + o_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan) # original connections + + input_idx, output_idx = genome.input_idx, genome.output_idx + new_node_key = max([*input_idx, *output_idx]) + 1 + + o_nodes[input_idx, 0] = genome.input_idx + o_nodes[output_idx, 0] = genome.output_idx + o_nodes[new_node_key, 0] = new_node_key # one hidden node + o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_attrs() + o_nodes[new_node_key, 1:] = genome.node_gene.new_attrs() # one hidden node + + input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden + o_conns[input_idx, 0:2] = input_conns # in key, out key + o_conns[input_idx, 2] = True # enabled + o_conns[input_idx, 3:] = genome.conn_gene.new_conn_attrs() + + output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes + o_conns[output_idx, 0:2] = output_conns # in key, out key + o_conns[output_idx, 2] = True # enabled + o_conns[output_idx, 3:] = genome.conn_gene.new_conn_attrs() + + # repeat origin genome for P times to create population + pop_nodes = np.tile(o_nodes, (pop_size, 1, 1)) + pop_conns = np.tile(o_conns, (pop_size, 1, 1)) + + return pop_nodes, pop_conns diff --git a/algorithm/neat/species/distance.py b/algorithm/neat/species/distance.py deleted file mode 100644 index 7150672..0000000 --- a/algorithm/neat/species/distance.py +++ /dev/null @@ -1,71 +0,0 @@ -from jax import Array, numpy as jnp, vmap - -from core import Gene - - -def distance(gene: Gene, state, genome1, genome2): - return node_distance(gene, state, genome1.nodes, genome2.nodes) + \ - connection_distance(gene, state, genome1.conns, genome2.conns) - - -def node_distance(gene: Gene, state, nodes1: Array, nodes2: Array): - """ - Calculate the distance between nodes of two genomes. - """ - # statistics nodes count of two genomes - node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) - node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) - max_cnt = jnp.maximum(node_cnt1, node_cnt2) - - # align homologous nodes - # this process is similar to np.intersect1d. - nodes = jnp.concatenate((nodes1, nodes2), axis=0) - keys = nodes[:, 0] - sorted_indices = jnp.argsort(keys, axis=0) - nodes = nodes[sorted_indices] - nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end - fr, sr = nodes[:-1], nodes[1:] # first row, second row - - # flag location of homologous nodes - intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) - - # calculate the count of non_homologous of two genomes - non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) - - # calculate the distance of homologous nodes - hnd = vmap(gene.distance_node, in_axes=(None, 0, 0))(state, fr, sr) - hnd = jnp.where(jnp.isnan(hnd), 0, hnd) - homologous_distance = jnp.sum(hnd * intersect_mask) - - val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight - - return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division - - -def connection_distance(gene: Gene, state, cons1: Array, cons2: Array): - """ - Calculate the distance between connections of two genomes. - Similar process as node_distance. - """ - con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0])) - con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0])) - max_cnt = jnp.maximum(con_cnt1, con_cnt2) - - cons = jnp.concatenate((cons1, cons2), axis=0) - keys = cons[:, :2] - sorted_indices = jnp.lexsort(keys.T[::-1]) - cons = cons[sorted_indices] - cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end - fr, sr = cons[:-1], cons[1:] # first row, second row - - # both genome has such connection - intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) - - non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) - hcd = vmap(gene.distance_conn, in_axes=(None, 0, 0))(state, fr, sr) - hcd = jnp.where(jnp.isnan(hcd), 0, hcd) - homologous_distance = jnp.sum(hcd * intersect_mask) - - val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight - - return jnp.where(max_cnt == 0, 0, val / max_cnt) diff --git a/algorithm/neat/species/operations.py b/algorithm/neat/species/operations.py deleted file mode 100644 index b90c741..0000000 --- a/algorithm/neat/species/operations.py +++ /dev/null @@ -1,319 +0,0 @@ -import jax -from jax import numpy as jnp, vmap - -from core import Gene, Genome, State -from utils import rank_elements, fetch_first -from .distance import distance -from .species_info import SpeciesInfo - - -def update_species(state, randkey, fitness): - # update the fitness of each species - species_fitness = update_species_fitness(state, fitness) - - # stagnation species - state, species_fitness = stagnation(state, species_fitness) - - # sort species_info by their fitness. (push nan to the end) - sort_indices = jnp.argsort(species_fitness)[::-1] - - state = state.update( - species_info=state.species_info[sort_indices], - center_genomes=state.center_genomes[sort_indices], - ) - - # decide the number of members of each species by their fitness - spawn_number = cal_spawn_numbers(state) - - # crossover info - winner, loser, elite_mask = create_crossover_pair(state, randkey, spawn_number, fitness) - - return state, winner, loser, elite_mask - - -def update_species_fitness(state, fitness): - """ - obtain the fitness of the species by the fitness of each individual. - use max criterion. - """ - - def aux_func(idx): - s_fitness = jnp.where(state.idx2species == state.species_info.species_keys[idx], fitness, -jnp.inf) - f = jnp.max(s_fitness) - return f - - return vmap(aux_func)(jnp.arange(state.species_info.size())) - - -def stagnation(state, species_fitness): - """ - stagnation species. - those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. - elitism species never stagnation - """ - - def aux_func(idx): - s_fitness = species_fitness[idx] - sk, bf, li, _ = state.species_info.get(idx) - st = (s_fitness <= bf) & (state.generation - li > state.max_stagnation) - li = jnp.where(s_fitness > bf, state.generation, li) - bf = jnp.where(s_fitness > bf, s_fitness, bf) - - return st, sk, bf, li - - spe_st, species_keys, best_fitness, last_improved = vmap(aux_func)(jnp.arange(species_fitness.shape[0])) - - # elite species will not be stagnation - species_rank = rank_elements(species_fitness) - spe_st = jnp.where(species_rank < state.species_elitism, False, spe_st) # elitism never stagnation - - # set stagnation species to nan - species_keys = jnp.where(spe_st, jnp.nan, species_keys) - best_fitness = jnp.where(spe_st, jnp.nan, best_fitness) - last_improved = jnp.where(spe_st, jnp.nan, last_improved) - member_count = jnp.where(spe_st, jnp.nan, state.species_info.member_count) - - species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness) - - species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count) - - # TODO: Simplify the coded - center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.nodes) - center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.conns) - - state = state.update( - species_info=species_info, - center_genomes=Genome(center_nodes, center_conns) - ) - - return state, species_fitness - - -def cal_spawn_numbers(state): - """ - decide the number of members of each species by their fitness rank. - the species with higher fitness will have more members - Linear ranking selection - e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] - """ - - species_keys = state.species_info.species_keys - - is_species_valid = ~jnp.isnan(species_keys) - valid_species_num = jnp.sum(is_species_valid) - denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 - - rank_score = valid_species_num - jnp.arange(species_keys.shape[0]) # obtain [3, 2, 1] - spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] - spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 - - target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member - - # Avoid too much variation of numbers in a species - previous_size = state.species_info.member_count - spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate - # jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number) - spawn_number = spawn_number.astype(jnp.int32) - - # must control the sum of spawn_number to be equal to pop_size - error = state.P - jnp.sum(spawn_number) - spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number - - return spawn_number - - -def create_crossover_pair(state, randkey, spawn_number, fitness): - species_size = state.species_info.size() - pop_size = fitness.shape[0] - s_idx = jnp.arange(species_size) - p_idx = jnp.arange(pop_size) - - # def aux_func(key, idx): - def aux_func(key, idx): - members = state.idx2species == state.species_info.species_keys[idx] - members_num = jnp.sum(members) - - members_fitness = jnp.where(members, fitness, -jnp.inf) - sorted_member_indices = jnp.argsort(members_fitness)[::-1] - - elite_size = state.genome_elitism - survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32) - - select_pro = (p_idx < survive_size) / survive_size - fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro) - - # elite - fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa) - ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma) - elite = jnp.where(p_idx < elite_size, True, False) - return fa, ma, elite - - fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx) - - spawn_number_cum = jnp.cumsum(spawn_number) - - def aux_func(idx): - loc = jnp.argmax(idx < spawn_number_cum) - - # elite genomes are at the beginning of the species - idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx) - return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species] - - part1, part2, elite_mask = vmap(aux_func)(p_idx) - - is_part1_win = fitness[part1] >= fitness[part2] - winner = jnp.where(is_part1_win, part1, part2) - loser = jnp.where(is_part1_win, part2, part1) - - return winner, loser, elite_mask - - -def speciate(gene: Gene, state: State): - pop_size, species_size = state.idx2species.shape[0], state.species_info.size() - - # prepare distance functions - o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0)) # one to population - - # idx to specie key - idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species - - # the distance between genomes to its center genomes - o2c_distances = jnp.full((pop_size,), jnp.inf) - - # step 1: find new centers - def cond_func(carry): - i, i2s, cgs, o2c = carry - - return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing - - def body_func(carry): - i, i2s, cgs, o2c = carry - - distances = o2p_distance_func(gene, state, cgs[i], state.pop_genomes) - - # find the closest one - closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) - - i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i]) - cgs = cgs.set(i, state.pop_genomes[closest_idx]) - - # the genome with closest_idx will become the new center, thus its distance to center is 0. - o2c = o2c.at[closest_idx].set(0) - - return i + 1, i2s, cgs, o2c - - _, idx2species, center_genomes, o2c_distances = \ - jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances)) - - state = state.update( - idx2species=idx2species, - center_genomes=center_genomes, - ) - - # part 2: assign members to each species - def cond_func(carry): - i, i2s, cgs, sk, o2c, nsk = carry - - current_species_existed = ~jnp.isnan(sk[i]) - not_all_assigned = jnp.any(jnp.isnan(i2s)) - not_reach_species_upper_bounds = i < species_size - return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned) - - def body_func(carry): - i, i2s, cgs, sk, o2c, nsk = carry - - _, i2s, cgs, sk, o2c, nsk = jax.lax.cond( - jnp.isnan(sk[i]), # whether the current species is existing or not - create_new_species, # if not existing, create a new specie - update_exist_specie, # if existing, update the specie - (i, i2s, cgs, sk, o2c, nsk) - ) - - return i + 1, i2s, cgs, sk, o2c, nsk - - def create_new_species(carry): - i, i2s, cgs, sk, o2c, nsk = carry - - # pick the first one who has not been assigned to any species - idx = fetch_first(jnp.isnan(i2s)) - - # assign it to the new species - # [key, best score, last update generation, member_count] - sk = sk.at[i].set(nsk) - i2s = i2s.at[idx].set(nsk) - o2c = o2c.at[idx].set(0) - - # update center genomes - cgs = cgs.set(i, state.pop_genomes[idx]) - - i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) - - # when a new species is created, it needs to be updated, thus do not change i - return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key - - def update_exist_specie(carry): - i, i2s, cgs, sk, o2c, nsk = carry - - i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) - - # turn to next species - return i + 1, i2s, cgs, sk, o2c, nsk - - def speciate_by_threshold(i, i2s, cgs, sk, o2c): - # distance between such center genome and ppo genomes - - o2p_distance = o2p_distance_func(gene, state, cgs[i], state.pop_genomes) - close_enough_mask = o2p_distance < state.compatibility_threshold - - # when a genome is not assigned or the distance between its current center is bigger than this center - cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c) - # jax.debug.print("{}", o2p_distance) - mask = close_enough_mask & cacheable_mask - - # update species info - i2s = jnp.where(mask, sk[i], i2s) - - # update distance between centers - o2c = jnp.where(mask, o2p_distance, o2c) - - return i2s, o2c - - # update idx2species - _, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop( - cond_func, - body_func, - (0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances, - state.next_species_key) - ) - - # if there are still some pop genomes not assigned to any species, add them to the last genome - # this condition can only happen when the number of species is reached species upper bounds - idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) - - # complete info of species which is created in this generation - new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness) - best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness) - last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved) - - # update members count - def count_members(idx): - key = species_keys[idx] - count = jnp.sum(idx2species == key, dtype=jnp.float32) - count = jnp.where(jnp.isnan(key), jnp.nan, count) - - return count - - member_count = vmap(count_members)(jnp.arange(species_size)) - - return state.update( - species_info=SpeciesInfo(species_keys, best_fitness, last_improved, member_count), - idx2species=idx2species, - center_genomes=center_genomes, - next_species_key=next_species_key - ) - - -def argmin_with_mask(arr, mask): - masked_arr = jnp.where(mask, arr, jnp.inf) - min_idx = jnp.argmin(masked_arr) - return min_idx diff --git a/algorithm/neat/species/species_info.py b/algorithm/neat/species/species_info.py deleted file mode 100644 index 2dc1c86..0000000 --- a/algorithm/neat/species/species_info.py +++ /dev/null @@ -1,55 +0,0 @@ -from jax.tree_util import register_pytree_node_class -import numpy as np -import jax.numpy as jnp - - -@register_pytree_node_class -class SpeciesInfo: - - def __init__(self, species_keys, best_fitness, last_improved, member_count): - self.species_keys = species_keys - self.best_fitness = best_fitness - self.last_improved = last_improved - self.member_count = member_count - - @classmethod - def initialize(cls, state): - species_keys = np.full((state.S,), np.nan, dtype=np.float32) - best_fitness = np.full((state.S,), np.nan, dtype=np.float32) - last_improved = np.full((state.S,), np.nan, dtype=np.float32) - member_count = np.full((state.S,), np.nan, dtype=np.float32) - - species_keys[0] = 0 - best_fitness[0] = -np.inf - last_improved[0] = 0 - member_count[0] = state.P - - return cls(species_keys, best_fitness, last_improved, member_count) - - def __getitem__(self, i): - return SpeciesInfo(self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i]) - - def get(self, i): - return self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i] - - def set(self, idx, value): - species_keys = self.species_keys.at[idx].set(value[0]) - best_fitness = self.best_fitness.at[idx].set(value[1]) - last_improved = self.last_improved.at[idx].set(value[2]) - member_count = self.member_count.at[idx].set(value[3]) - return SpeciesInfo(species_keys, best_fitness, last_improved, member_count) - - def remove(self, idx): - return self.set(idx, jnp.array([jnp.nan] * 4)) - - def size(self): - return self.species_keys.shape[0] - - def tree_flatten(self): - children = self.species_keys, self.best_fitness, self.last_improved, self.member_count - aux_data = None - return children, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(*children) diff --git a/config/__init__.py b/config/__init__.py deleted file mode 100644 index d085c3a..0000000 --- a/config/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .config import * \ No newline at end of file diff --git a/config/config.py b/config/config.py deleted file mode 100644 index 39f9d8f..0000000 --- a/config/config.py +++ /dev/null @@ -1,107 +0,0 @@ -from dataclasses import dataclass -from utils import Act, Agg - -@dataclass(frozen=True) -class BasicConfig: - seed: int = 42 - fitness_target: float = 1 - generation_limit: int = 1000 - pop_size: int = 100 - - def __post_init__(self): - assert self.pop_size > 0, "the population size must be greater than 0" - - -@dataclass(frozen=True) -class NeatConfig: - network_type: str = "feedforward" - inputs: int = 2 - outputs: int = 1 - max_nodes: int = 50 - max_conns: int = 100 - max_species: int = 10 - - # genome config - compatibility_disjoint: float = 1 - compatibility_weight: float = 0.5 - conn_add: float = 0.4 - conn_delete: float = 0 - node_add: float = 0.2 - node_delete: float = 0 - - # species config - compatibility_threshold: float = 3.5 - species_elitism: int = 2 - max_stagnation: int = 15 - genome_elitism: int = 2 - survival_threshold: float = 0.2 - min_species_size: int = 1 - spawn_number_change_rate: float = 0.5 - - def __post_init__(self): - assert self.network_type in ["feedforward", "recurrent"], "the network type must be feedforward or recurrent" - - assert self.inputs > 0, "the inputs number of neat must be greater than 0" - assert self.outputs > 0, "the outputs number of neat must be greater than 0" - - assert self.max_nodes > 0, "the maximum nodes must be greater than 0" - assert self.max_conns > 0, "the maximum connections must be greater than 0" - assert self.max_species > 0, "the maximum species must be greater than 0" - - assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0" - assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0" - assert self.conn_add >= 0, "the connection add probability must be greater than 0" - assert self.conn_delete >= 0, "the connection delete probability must be greater than 0" - assert self.node_add >= 0, "the node add probability must be greater than 0" - assert self.node_delete >= 0, "the node delete probability must be greater than 0" - - assert self.compatibility_threshold > 0, "the compatibility threshold must be greater than 0" - assert self.species_elitism > 0, "the species elitism must be greater than 0" - assert self.max_stagnation > 0, "the max stagnation must be greater than 0" - assert self.genome_elitism > 0, "the genome elitism must be greater than 0" - assert self.survival_threshold > 0, "the survival threshold must be greater than 0" - assert self.min_species_size > 0, "the min species size must be greater than 0" - assert self.spawn_number_change_rate > 0, "the spawn number change rate must be greater than 0" - - -@dataclass(frozen=True) -class HyperNeatConfig: - below_threshold: float = 0.2 - max_weight: float = 3 - activation: callable = Act.sigmoid - aggregation: callable = Agg.sum - activate_times: int = 5 - inputs: int = 2 - outputs: int = 1 - - def __post_init__(self): - assert self.below_threshold > 0, "the below threshold must be greater than 0" - assert self.max_weight > 0, "the max weight must be greater than 0" - assert self.activate_times > 0, "the activate times must be greater than 0" - assert self.inputs > 0, "the inputs number of hyper neat must be greater than 0" - assert self.outputs > 0, "the outputs number of hyper neat must be greater than 0" - - -@dataclass(frozen=True) -class GeneConfig: - pass - - -@dataclass(frozen=True) -class SubstrateConfig: - pass - - -@dataclass(frozen=True) -class ProblemConfig: - pass - - -@dataclass(frozen=True) -class Config: - basic: BasicConfig = BasicConfig() - neat: NeatConfig = NeatConfig() - hyperneat: HyperNeatConfig = HyperNeatConfig() - gene: GeneConfig = GeneConfig() - substrate: SubstrateConfig = SubstrateConfig() - problem: ProblemConfig = ProblemConfig() diff --git a/core/__init__.py b/core/__init__.py deleted file mode 100644 index 12c9675..0000000 --- a/core/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .algorithm import Algorithm -from .state import State -from .genome import Genome -from .gene import Gene -from .substrate import Substrate -from .problem import Problem diff --git a/core/algorithm.py b/core/algorithm.py deleted file mode 100644 index 0f575a0..0000000 --- a/core/algorithm.py +++ /dev/null @@ -1,50 +0,0 @@ -from functools import partial -import jax -from .state import State -from .genome import Genome - - -class Algorithm: - - def setup(self, randkey, state: State = State()): - """initialize the state of the algorithm""" - - raise NotImplementedError - - @partial(jax.jit, static_argnums=(0,)) - def ask(self, state: State): - """require the population to be evaluated""" - - return self.ask_algorithm(state) - - @partial(jax.jit, static_argnums=(0,)) - def tell(self, state: State, fitness): - """update the state of the algorithm""" - - return self.tell_algorithm(state, fitness) - - @partial(jax.jit, static_argnums=(0,)) - def transform(self, state: State, genome: Genome): - """transform the genome into a neural network""" - - return self.forward_transform(state, genome) - - @partial(jax.jit, static_argnums=(0,)) - def act(self, state: State, inputs, genome: Genome): - return self.forward(state, inputs, genome) - - def forward_transform(self, state: State, genome: Genome): - raise NotImplementedError - - def forward(self, state: State, inputs, genome: Genome): - raise NotImplementedError - - def ask_algorithm(self, state: State): - """ask the specific algorithm for a new population""" - - raise NotImplementedError - - def tell_algorithm(self, state: State, fitness): - """tell the specific algorithm the fitness of the population""" - - raise NotImplementedError diff --git a/core/gene.py b/core/gene.py deleted file mode 100644 index 7d58b07..0000000 --- a/core/gene.py +++ /dev/null @@ -1,40 +0,0 @@ -from config import GeneConfig -from .state import State - - -class Gene: - node_attrs = [] - conn_attrs = [] - - def __init__(self, config: GeneConfig = GeneConfig()): - raise NotImplementedError - - def setup(self, state=State()): - raise NotImplementedError - - def update(self, state): - raise NotImplementedError - - def new_node_attrs(self, state: State): - raise NotImplementedError - - def new_conn_attrs(self, state: State): - raise NotImplementedError - - def mutate_node(self, state: State, randkey, node_attrs): - raise NotImplementedError - - def mutate_conn(self, state: State, randkey, conn_attrs): - raise NotImplementedError - - def distance_node(self, state: State, node_attrs1, node_attrs2): - raise NotImplementedError - - def distance_conn(self, state: State, conn_attrs1, conn_attrs2): - raise NotImplementedError - - def forward_transform(self, state: State, genome): - raise NotImplementedError - - def forward(self, state: State, inputs, transform): - raise NotImplementedError diff --git a/core/genome.py b/core/genome.py deleted file mode 100644 index bf42d42..0000000 --- a/core/genome.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import annotations - -from jax.tree_util import register_pytree_node_class -from jax import numpy as jnp - -from utils.tools import fetch_first - - -@register_pytree_node_class -class Genome: - - def __init__(self, nodes, conns): - self.nodes = nodes - self.conns = conns - - def __repr__(self): - return f"Genome(nodes={self.nodes}, conns={self.conns})" - - def __getitem__(self, idx): - return self.__class__(self.nodes[idx], self.conns[idx]) - - def __eq__(self, other): - nodes_eq = jnp.alltrue((self.nodes == other.nodes) | (jnp.isnan(self.nodes) & jnp.isnan(other.nodes))) - conns_eq = jnp.alltrue((self.conns == other.conns) | (jnp.isnan(self.conns) & jnp.isnan(other.conns))) - return nodes_eq & conns_eq - - def set(self, idx, value: Genome): - return self.__class__(self.nodes.at[idx].set(value.nodes), self.conns.at[idx].set(value.conns)) - - def update(self, nodes, conns): - return self.__class__(nodes, conns) - - def update_nodes(self, nodes): - return self.update(nodes, self.conns) - - def update_conns(self, conns): - return self.update(self.nodes, conns) - - def count(self): - """Count how many nodes and connections are in the genome.""" - nodes_cnt = jnp.sum(~jnp.isnan(self.nodes[:, 0])) - conns_cnt = jnp.sum(~jnp.isnan(self.conns[:, 0])) - return nodes_cnt, conns_cnt - - def add_node(self, new_key: int, attrs): - """ - Add a new node to the genome. - The new node will place at the first NaN row. - """ - exist_keys = self.nodes[:, 0] - pos = fetch_first(jnp.isnan(exist_keys)) - new_nodes = self.nodes.at[pos, 0].set(new_key) - new_nodes = new_nodes.at[pos, 1:].set(attrs) - return self.update_nodes(new_nodes) - - def delete_node_by_pos(self, pos): - """ - Delete a node from the genome. - Delete the node by its pos in nodes. - """ - nodes = self.nodes.at[pos].set(jnp.nan) - return self.update_nodes(nodes) - - def add_conn(self, i_key, o_key, enable: bool, attrs): - """ - Add a new connection to the genome. - The new connection will place at the first NaN row. - """ - con_keys = self.conns[:, 0] - pos = fetch_first(jnp.isnan(con_keys)) - new_conns = self.conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable])) - new_conns = new_conns.at[pos, 3:].set(attrs) - return self.update_conns(new_conns) - - def delete_conn_by_pos(self, pos): - """ - Delete a connection from the genome. - Delete the connection by its idx. - """ - conns = self.conns.at[pos].set(jnp.nan) - return self.update_conns(conns) - - def tree_flatten(self): - children = self.nodes, self.conns - aux_data = None - return children, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(*children) diff --git a/core/problem.py b/core/problem.py deleted file mode 100644 index c57e2fd..0000000 --- a/core/problem.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Callable - -from config import ProblemConfig -from .state import State - - -class Problem: - - jitable = None - - def __init__(self, problem_config: ProblemConfig = ProblemConfig()): - self.config = problem_config - - def evaluate(self, randkey, state: State, act_func: Callable, params): - raise NotImplementedError - - @property - def input_shape(self): - raise NotImplementedError - - @property - def output_shape(self): - raise NotImplementedError - - def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs): - """ - show how a genome perform in this problem - """ - raise NotImplementedError diff --git a/core/substrate.py b/core/substrate.py deleted file mode 100644 index 03faa86..0000000 --- a/core/substrate.py +++ /dev/null @@ -1,8 +0,0 @@ -from config import SubstrateConfig - - -class Substrate: - - @staticmethod - def setup(state, config: SubstrateConfig = SubstrateConfig()): - return state diff --git a/examples/brax/ant.py b/examples/brax/ant.py index e8f4b54..60f3d8b 100644 --- a/examples/brax/ant.py +++ b/examples/brax/ant.py @@ -12,7 +12,7 @@ def example_conf(): basic=BasicConfig( seed=42, fitness_target=10000, - pop_size=1000 + pop_size=100 ), neat=NeatConfig( inputs=27, diff --git a/pipeline.py b/pipeline.py index 8694d23..8b6bdcb 100644 --- a/pipeline.py +++ b/pipeline.py @@ -1,7 +1,3 @@ -""" -pipeline for jitable env like func_fit, gymnax -""" - from functools import partial from typing import Type @@ -16,24 +12,28 @@ from core import State, Algorithm, Problem class Pipeline: - def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]): + def __init__( + self, + algorithm: Algorithm, + problem: Problem, + seed: int = 42, + fitness_target: float = 1, + generation_limit: int = 1000, + pop_size: int = 100, + ): + assert problem.jitable, "Currently, problem must be jitable" - assert problem_type.jitable, "problem must be jitable" - - self.config = config self.algorithm = algorithm - self.problem = problem_type(config.problem) + self.problem = problem + self.seed = seed + self.fitness_target = fitness_target + self.generation_limit = generation_limit + self.pop_size = pop_size print(self.problem.input_shape, self.problem.output_shape) - if isinstance(algorithm, NEAT): - assert config.neat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}" - - elif isinstance(algorithm, HyperNEAT): - assert config.hyperneat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}" - - else: - raise NotImplementedError + # TODO: make each algorithm's input_num and output_num + assert algorithm.input_num == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}" self.act_func = self.algorithm.act @@ -45,19 +45,19 @@ class Pipeline: self.generation_timestamp = None def setup(self): - key = jax.random.PRNGKey(self.config.basic.seed) + key = jax.random.PRNGKey(self.seed) algorithm_key, evaluate_key = jax.random.split(key, 2) - state = State() - state = self.algorithm.setup(algorithm_key, state) - return state.update( - evaluate_key=evaluate_key + + # TODO: Problem should has setup function to maintain state + return State( + alg=self.algorithm.setup(algorithm_key), + pro=self.problem.setup(evaluate_key), ) @partial(jax.jit, static_argnums=(0,)) def step(self, state): - key, sub_key = jax.random.split(state.evaluate_key) - keys = jax.random.split(key, self.config.basic.pop_size) + keys = jax.random.split(key, self.pop_size) pop = self.algorithm.ask(state) @@ -72,7 +72,7 @@ class Pipeline: def auto_run(self, ini_state): state = ini_state - for _ in range(self.config.basic.generation_limit): + for _ in range(self.generation_limit): self.generation_timestamp = time.time() @@ -84,7 +84,7 @@ class Pipeline: self.analysis(state, previous_pop, fitnesses) - if max(fitnesses) >= self.config.basic.fitness_target: + if max(fitnesses) >= self.fitness_target: print("Fitness limit reached!") return state, self.best_genome @@ -120,3 +120,4 @@ class Pipeline: print("start compile") self.step.lower(self, state).compile() print(f"compile finished, cost time: {time.time() - tic}s") + diff --git a/problem/__init__.py b/problem/__init__.py index e69de29..3c2cb07 100644 --- a/problem/__init__.py +++ b/problem/__init__.py @@ -0,0 +1 @@ +from .base import BaseProblem diff --git a/problem/base.py b/problem/base.py new file mode 100644 index 0000000..28118f1 --- /dev/null +++ b/problem/base.py @@ -0,0 +1,44 @@ +from typing import Callable + +from config import ProblemConfig +from core.state import State + + +class BaseProblem: + + jitable = None + + def __init__(self): + pass + + def setup(self, randkey, state: State = State()): + """initialize the state of the problem""" + raise NotImplementedError + + def evaluate(self, randkey, state: State, act_func: Callable, params): + """evaluate one individual""" + raise NotImplementedError + + @property + def input_shape(self): + """ + The input shape for the problem to evaluate + In RL problem, it is the observation space + In function fitting problem, it is the input shape of the function + """ + raise NotImplementedError + + @property + def output_shape(self): + """ + The output shape for the problem to evaluate + In RL problem, it is the action space + In function fitting problem, it is the output shape of the function + """ + raise NotImplementedError + + def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs): + """ + show how a genome perform in this problem + """ + raise NotImplementedError diff --git a/problem/func_fit/__init__.py b/problem/func_fit/__init__.py index 9304943..ecad1e1 100644 --- a/problem/func_fit/__init__.py +++ b/problem/func_fit/__init__.py @@ -1,3 +1,3 @@ -from .func_fit import FuncFit, FuncFitConfig +from .func_fit import FuncFit from .xor import XOR from .xor3d import XOR3d diff --git a/problem/func_fit/func_fit.py b/problem/func_fit/func_fit.py index a438972..360796e 100644 --- a/problem/func_fit/func_fit.py +++ b/problem/func_fit/func_fit.py @@ -1,42 +1,35 @@ -from typing import Callable -from dataclasses import dataclass - import jax import jax.numpy as jnp -from config import ProblemConfig -from core import Problem, State +from .. import BaseProblem -@dataclass(frozen=True) -class FuncFitConfig(ProblemConfig): - error_method: str = 'mse' - - def __post_init__(self): - assert self.error_method in {'mse', 'rmse', 'mae', 'mape'} - - -class FuncFit(Problem): +class FuncFit(BaseProblem): jitable = True - def __init__(self, config: FuncFitConfig = FuncFitConfig()): - self.config = config - super().__init__(config) + def __init__(self, + error_method: str = 'mse' + ): + super().__init__() - def evaluate(self, randkey, state: State, act_func: Callable, params): + assert error_method in {'mse', 'rmse', 'mae', 'mape'} + self.error_method = error_method + + + def evaluate(self, randkey, state, act_func, params): predict = act_func(state, self.inputs, params) - if self.config.error_method == 'mse': + if self.error_method == 'mse': loss = jnp.mean((predict - self.targets) ** 2) - elif self.config.error_method == 'rmse': + elif self.error_method == 'rmse': loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2)) - elif self.config.error_method == 'mae': + elif self.error_method == 'mae': loss = jnp.mean(jnp.abs(predict - self.targets)) - elif self.config.error_method == 'mape': + elif self.error_method == 'mape': loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets)) else: @@ -44,7 +37,7 @@ class FuncFit(Problem): return -loss - def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs): + def show(self, randkey, state, act_func, params, *args, **kwargs): predict = act_func(state, self.inputs, params) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) loss = -self.evaluate(randkey, state, act_func, params) diff --git a/problem/func_fit/xor.py b/problem/func_fit/xor.py index fed943b..65b1e0c 100644 --- a/problem/func_fit/xor.py +++ b/problem/func_fit/xor.py @@ -1,13 +1,12 @@ import numpy as np -from .func_fit import FuncFit, FuncFitConfig +from .func_fit import FuncFit class XOR(FuncFit): - def __init__(self, config: FuncFitConfig = FuncFitConfig()): - self.config = config - super().__init__(config) + def __init__(self, error_method: str = 'mse'): + super().__init__(error_method) @property def inputs(self): diff --git a/problem/func_fit/xor3d.py b/problem/func_fit/xor3d.py index 2f070f8..1ae8b1b 100644 --- a/problem/func_fit/xor3d.py +++ b/problem/func_fit/xor3d.py @@ -1,13 +1,12 @@ import numpy as np -from .func_fit import FuncFit, FuncFitConfig +from .func_fit import FuncFit class XOR3d(FuncFit): - def __init__(self, config: FuncFitConfig = FuncFitConfig()): - self.config = config - super().__init__(config) + def __init__(self, error_method: str = 'mse'): + super().__init__(error_method) @property def inputs(self): @@ -37,8 +36,8 @@ class XOR3d(FuncFit): @property def input_shape(self): - return (8, 3) + return 8, 3 @property def output_shape(self): - return (8, 1) + return 8, 1 diff --git a/problem/rl_env/brax_env.py b/problem/rl_env/brax_env.py index 7b82e40..9c34501 100644 --- a/problem/rl_env/brax_env.py +++ b/problem/rl_env/brax_env.py @@ -1,28 +1,13 @@ -from dataclasses import dataclass -from typing import Callable - import jax.numpy as jnp from brax import envs -from core import State -from .rl_jit import RLEnv, RLEnvConfig - -@dataclass(frozen=True) -class BraxConfig(RLEnvConfig): - env_name: str = "ant" - backend: str = "generalized" - - def __post_init__(self): - # TODO: Check if env_name is registered - # assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered" - pass +from .rl_jit import RLEnv class BraxEnv(RLEnv): - def __init__(self, config: BraxConfig = BraxConfig()): - super().__init__(config) - self.config = config - self.env = envs.create(env_name=config.env_name, backend=config.backend) + def __init__(self, env_name: str = "ant", backend: str = "generalized"): + super().__init__() + self.env = envs.create(env_name=env_name, backend=backend) def env_step(self, randkey, env_state, action): state = self.env.step(env_state, action) @@ -40,9 +25,7 @@ class BraxEnv(RLEnv): def output_shape(self): return (self.env.action_size,) - def show(self, randkey, state: State, act_func: Callable, params, save_path=None, height=512, width=512, - duration=0.1, *args, - **kwargs): + def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs): import jax import imageio @@ -56,8 +39,7 @@ class BraxEnv(RLEnv): def step(key, env_state, obs): key, _ = jax.random.split(key) - net_out = act_func(state, obs, params) - action = self.config.output_transform(net_out) + action = act_func(state, obs, params) obs, env_state, r, done, _ = self.step(randkey, env_state, action) return key, env_state, obs, r, done @@ -72,7 +54,6 @@ class BraxEnv(RLEnv): def create_gif(image_list, gif_name, duration): with imageio.get_writer(gif_name, mode='I', duration=duration) as writer: for image in image_list: - # 确保图像的数据类型正确 formatted_image = np.array(image, dtype=np.uint8) writer.append_data(formatted_image) diff --git a/problem/rl_env/gymnax_env.py b/problem/rl_env/gymnax_env.py index 872c63e..5912df5 100644 --- a/problem/rl_env/gymnax_env.py +++ b/problem/rl_env/gymnax_env.py @@ -1,26 +1,15 @@ -from dataclasses import dataclass -from typing import Callable - import gymnax -from core import State -from .rl_jit import RLEnv, RLEnvConfig +from .rl_jit import RLEnv -@dataclass(frozen=True) -class GymNaxConfig(RLEnvConfig): - env_name: str = "CartPole-v1" - - def __post_init__(self): - assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered" - class GymNaxEnv(RLEnv): - def __init__(self, config: GymNaxConfig = GymNaxConfig()): - super().__init__(config) - self.config = config - self.env, self.env_params = gymnax.make(config.env_name) + def __init__(self, env_name): + super().__init__() + assert env_name in gymnax.registered_envs, f"Env {env_name} not registered" + self.env, self.env_params = gymnax.make(env_name) def env_step(self, randkey, env_state, action): return self.env.step(randkey, env_state, action, self.env_params) @@ -36,5 +25,5 @@ class GymNaxEnv(RLEnv): def output_shape(self): return self.env.action_space(self.env_params).shape - def show(self, randkey, state: State, act_func: Callable, params): + def show(self, randkey, state, act_func, params, *args, **kwargs): raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).") diff --git a/problem/rl_env/rl_jit.py b/problem/rl_env/rl_jit.py index 84a512b..07e68f6 100644 --- a/problem/rl_env/rl_jit.py +++ b/problem/rl_env/rl_jit.py @@ -1,28 +1,18 @@ -from dataclasses import dataclass -from typing import Callable from functools import partial import jax -from config import ProblemConfig +from .. import BaseProblem -from core import Problem, State - - -@dataclass(frozen=True) -class RLEnvConfig(ProblemConfig): - output_transform: Callable = lambda x: x - - -class RLEnv(Problem): +class RLEnv(BaseProblem): jitable = True - def __init__(self, config: RLEnvConfig = RLEnvConfig()): - super().__init__(config) - self.config = config + # TODO: move output transform to algorithm + def __init__(self): + super().__init__() - def evaluate(self, randkey, state: State, act_func: Callable, params): + def evaluate(self, randkey, state, act_func, params): rng_reset, rng_episode = jax.random.split(randkey) init_obs, init_env_state = self.reset(rng_reset) @@ -31,8 +21,7 @@ class RLEnv(Problem): return ~done def body_func(carry): obs, env_state, rng, _, tr = carry # total reward - net_out = act_func(state, obs, params) - action = self.config.output_transform(net_out) + action = act_func(state, obs, params) next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) next_rng, _ = jax.random.split(rng) return next_obs, next_env_state, next_rng, done, tr + reward @@ -67,5 +56,5 @@ class RLEnv(Problem): def output_shape(self): raise NotImplementedError - def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs): + def show(self, randkey, state, act_func, params, *args, **kwargs): raise NotImplementedError diff --git a/t.py b/t.py new file mode 100644 index 0000000..2318d8b --- /dev/null +++ b/t.py @@ -0,0 +1,64 @@ +from algorithm.neat import * +from utils import Act, Agg + +import jax, jax.numpy as jnp + + +def main(): + + # index, bias, response, activation, aggregation + nodes = jnp.array([ + [0, 0, 1, 0, 0], # in[0] + [1, 0, 1, 0, 0], # in[1] + [2, 0.5, 1, 0, 0], # out[0], + [3, 1, 1, 0, 0], # hidden[0], + [4, -1, 1, 0, 0], # hidden[1], + ]) + + # in_node, out_node, enable, weight + conns = jnp.array([ + [0, 3, 1, 0.5], # in[0] -> hidden[0] + [1, 4, 1, 0.5], # in[1] -> hidden[1] + [3, 2, 1, 0.5], # hidden[0] -> out[0] + [4, 2, 1, 0.5], # hidden[1] -> out[0] + ]) + + genome = RecurrentGenome( + num_inputs=2, + num_outputs=1, + node_gene=DefaultNodeGene( + activation_default=Act.identity, + activation_options=(Act.identity, ), + aggregation_default=Agg.sum, + aggregation_options=(Agg.sum, ), + ), + activate_time=3 + ) + + transformed = genome.transform(nodes, conns) + print(*transformed, sep='\n') + + inputs = jnp.array([0, 0]) + outputs = genome.forward(inputs, transformed) + print(outputs) + + inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) + outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed) + print(outputs) + expected: [[0.5], [0.75], [0.75], [1]] + + print('\n-------------------------------------------------------\n') + + conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] + print(conns) + + transformed = genome.transform(nodes, conns) + print(*transformed, sep='\n') + + inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) + outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed) + print(outputs) + expected: [[0.5], [0.75], [0.5], [0.75]] + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_genome.py b/test/test_genome.py new file mode 100644 index 0000000..d77a89a --- /dev/null +++ b/test/test_genome.py @@ -0,0 +1,113 @@ +from algorithm.neat import * +from utils import Act, Agg + +import jax, jax.numpy as jnp + + +def test_default(): + + # index, bias, response, activation, aggregation + nodes = jnp.array([ + [0, 0, 1, 0, 0], # in[0] + [1, 0, 1, 0, 0], # in[1] + [2, 0.5, 1, 0, 0], # out[0], + [3, 1, 1, 0, 0], # hidden[0], + [4, -1, 1, 0, 0], # hidden[1], + ]) + + # in_node, out_node, enable, weight + conns = jnp.array([ + [0, 3, 1, 0.5], # in[0] -> hidden[0] + [1, 4, 1, 0.5], # in[1] -> hidden[1] + [3, 2, 1, 0.5], # hidden[0] -> out[0] + [4, 2, 1, 0.5], # hidden[1] -> out[0] + ]) + + genome = DefaultGenome( + num_inputs=2, + num_outputs=1, + node_gene=DefaultNodeGene( + activation_default=Act.identity, + activation_options=(Act.identity, ), + aggregation_default=Agg.sum, + aggregation_options=(Agg.sum, ), + ), + ) + + transformed = genome.transform(nodes, conns) + print(*transformed, sep='\n') + + inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) + outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed) + print(outputs) + assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) + # expected: [[0.5], [0.75], [0.75], [1]] + + print('\n-------------------------------------------------------\n') + + conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] + print(conns) + + transformed = genome.transform(nodes, conns) + print(*transformed, sep='\n') + + inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) + outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed) + print(outputs) + assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) + # expected: [[0.5], [0.75], [0.5], [0.75]] + + +def test_recurrent(): + + # index, bias, response, activation, aggregation + nodes = jnp.array([ + [0, 0, 1, 0, 0], # in[0] + [1, 0, 1, 0, 0], # in[1] + [2, 0.5, 1, 0, 0], # out[0], + [3, 1, 1, 0, 0], # hidden[0], + [4, -1, 1, 0, 0], # hidden[1], + ]) + + # in_node, out_node, enable, weight + conns = jnp.array([ + [0, 3, 1, 0.5], # in[0] -> hidden[0] + [1, 4, 1, 0.5], # in[1] -> hidden[1] + [3, 2, 1, 0.5], # hidden[0] -> out[0] + [4, 2, 1, 0.5], # hidden[1] -> out[0] + ]) + + genome = RecurrentGenome( + num_inputs=2, + num_outputs=1, + node_gene=DefaultNodeGene( + activation_default=Act.identity, + activation_options=(Act.identity, ), + aggregation_default=Agg.sum, + aggregation_options=(Agg.sum, ), + ), + activate_time=3, + ) + + transformed = genome.transform(nodes, conns) + print(*transformed, sep='\n') + + inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) + outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed) + print(outputs) + assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) + # expected: [[0.5], [0.75], [0.75], [1]] + + print('\n-------------------------------------------------------\n') + + conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] + print(conns) + + transformed = genome.transform(nodes, conns) + print(*transformed, sep='\n') + + inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) + outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed) + print(outputs) + assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) + # expected: [[0.5], [0.75], [0.5], [0.75]] \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py index f8237b0..077fcd4 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,4 +1,5 @@ from .activation import Act, act from .aggregation import Agg, agg from .tools import * -from .graph import * \ No newline at end of file +from .graph import * +from .state import State \ No newline at end of file diff --git a/utils/aggregation.py b/utils/aggregation.py index 4b94fe4..2e5d94a 100644 --- a/utils/aggregation.py +++ b/utils/aggregation.py @@ -57,10 +57,8 @@ def agg(idx, z, agg_funcs): """ idx = jnp.asarray(idx, dtype=jnp.int32) - def all_nan(): - return 0. - - def not_all_nan(): - return jax.lax.switch(idx, agg_funcs, z) - - return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) + return jax.lax.cond( + jnp.all(jnp.isnan(z)), + lambda: jnp.nan, # all inputs are nan + lambda: jax.lax.switch(idx, agg_funcs, z) # otherwise + ) diff --git a/core/state.py b/utils/state.py similarity index 100% rename from core/state.py rename to utils/state.py diff --git a/utils/tools.py b/utils/tools.py index 3f22490..0103296 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -5,13 +5,11 @@ import jax from jax import numpy as jnp, Array, jit, vmap I_INT = np.iinfo(jnp.int32).max # infinite int -EMPTY_NODE = np.full((1, 5), jnp.nan) -EMPTY_CON = np.full((1, 4), jnp.nan) def unflatten_conns(nodes, conns): """ - transform the (C, CL) connections to (CL-2, N, N) + transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index) :return: """ N = nodes.shape[0] @@ -66,4 +64,43 @@ def rank_elements(array, reverse=False): """ if not reverse: array = -array - return jnp.argsort(jnp.argsort(array)) \ No newline at end of file + return jnp.argsort(jnp.argsort(array)) + + +@jit +def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate): + k1, k2, k3 = jax.random.split(key, num=3) + noise = jax.random.normal(k1, ()) * mutate_power + replace = jax.random.normal(k2, ()) * init_std + init_mean + r = jax.random.uniform(k3, ()) + + val = jnp.where( + r < mutate_rate, + val + noise, + jnp.where( + (mutate_rate < r) & (r < mutate_rate + replace_rate), + replace, + val + ) + ) + + return val + + +@jit +def mutate_int(key, val, options, replace_rate): + k1, k2 = jax.random.split(key, num=2) + r = jax.random.uniform(k1, ()) + + val = jnp.where( + r < replace_rate, + jax.random.choice(k2, options), + val + ) + + return val + +def argmin_with_mask(arr, mask): + masked_arr = jnp.where(mask, arr, jnp.inf) + min_idx = jnp.argmin(masked_arr) + return min_idx \ No newline at end of file