diff --git a/algorithm/__init__.py b/algorithm/__init__.py index 1bfd121..e69de29 100644 --- a/algorithm/__init__.py +++ b/algorithm/__init__.py @@ -1,4 +0,0 @@ -from .base import Algorithm -from .state import State -from .neat import NEAT -from .hyperneat import HyperNEAT diff --git a/algorithm/base.py b/algorithm/base.py deleted file mode 100644 index 188e96d..0000000 --- a/algorithm/base.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Callable - -from .state import State - -EMPTY = lambda *args: args - - -class Algorithm: - - def __init__(self): - self.tell: Callable = EMPTY - self.ask: Callable = EMPTY - self.forward: Callable = EMPTY - self.forward_transform: Callable = EMPTY - - def setup(self, randkey, state=State()): - pass diff --git a/algorithm/hyperneat/__init__.py b/algorithm/hyperneat/__init__.py deleted file mode 100644 index 17af79f..0000000 --- a/algorithm/hyperneat/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .hyperneat import HyperNEAT -from .substrate import BaseSubstrate diff --git a/algorithm/hyperneat/hyperneat.py b/algorithm/hyperneat/hyperneat.py deleted file mode 100644 index 79ff6e0..0000000 --- a/algorithm/hyperneat/hyperneat.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import Type - -import jax -import numpy as np - -from .substrate import BaseSubstrate, analysis_substrate -from .hyperneat_gene import HyperNEATGene -from algorithm import State, Algorithm, neat - - -class HyperNEAT(Algorithm): - - def __init__(self, config, gene_type: Type[neat.BaseGene], substrate: Type[BaseSubstrate]): - super().__init__() - self.config = config - self.gene_type = gene_type - self.substrate = substrate - self.neat = neat.NEAT(config, gene_type) - - self.tell = create_tell(self.neat) - self.forward_transform = create_forward_transform(config, self.neat) - self.forward = HyperNEATGene.create_forward(config) - - def setup(self, randkey, state=State()): - state = state.update( - below_threshold=self.config['below_threshold'], - max_weight=self.config['max_weight'] - ) - - state = self.substrate.setup(state, self.config) - h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state) - h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis] - h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32) - h_conns[:, 0:2] = correspond_keys - - state = state.update( - # h is short for hyperneat - h_input_idx=h_input_idx, - h_output_idx=h_output_idx, - h_hidden_idx=h_hidden_idx, - query_coors=query_coors, - correspond_keys=correspond_keys, - h_nodes=h_nodes, - h_conns=h_conns - ) - state = self.neat.setup(randkey, state=state) - - self.config['h_input_idx'] = h_input_idx - self.config['h_output_idx'] = h_output_idx - - return state - - -def create_tell(neat_instance): - def tell(state, fitness): - return neat_instance.tell(state, fitness) - - return tell - - -def create_forward_transform(config, neat_instance): - def forward_transform(state, nodes, conns): - t = neat_instance.forward_transform(state, nodes, conns) - batch_forward_func = jax.vmap(neat_instance.forward, in_axes=(0, None)) - query_res = batch_forward_func(state.query_coors, t) # hyperneat connections - h_nodes = state.h_nodes - h_conns = state.h_conns.at[:, 2:].set(query_res) - return HyperNEATGene.forward_transform(state, h_nodes, h_conns) - - return forward_transform diff --git a/algorithm/hyperneat/hyperneat_gene.py b/algorithm/hyperneat/hyperneat_gene.py deleted file mode 100644 index 247c95c..0000000 --- a/algorithm/hyperneat/hyperneat_gene.py +++ /dev/null @@ -1,54 +0,0 @@ -import jax -from jax import numpy as jnp, vmap - -from algorithm.neat import BaseGene -from algorithm.neat.gene import Activation -from algorithm.neat.gene import Aggregation - - -class HyperNEATGene(BaseGene): - node_attrs = [] # no node attributes - conn_attrs = ['weight'] - - @staticmethod - def forward_transform(state, nodes, conns): - N = nodes.shape[0] - u_conns = jnp.zeros((N, N), dtype=jnp.float32) - - in_keys = jnp.asarray(conns[:, 0], jnp.int32) - out_keys = jnp.asarray(conns[:, 1], jnp.int32) - weights = conns[:, 2] - - u_conns = u_conns.at[in_keys, out_keys].set(weights) - return nodes, u_conns - - @staticmethod - def create_forward(config): - act = Activation.name2func[config['h_activation']] - agg = Aggregation.name2func[config['h_aggregation']] - - batch_act, batch_agg = vmap(act), vmap(agg) - - def forward(inputs, transform): - - inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0) - nodes, weights = transform - - input_idx = config['h_input_idx'] - output_idx = config['h_output_idx'] - - N = nodes.shape[0] - vals = jnp.full((N,), 0.) - - def body_func(i, values): - values = values.at[input_idx].set(inputs_with_bias) - nodes_ins = values * weights.T - values = batch_agg(nodes_ins) # z = agg(ins) - values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias - values = batch_act(values) # z = act(z) - return values - - vals = jax.lax.fori_loop(0, config['h_activate_times'], body_func, vals) - return vals[output_idx] - - return forward diff --git a/algorithm/hyperneat/substrate/__init__.py b/algorithm/hyperneat/substrate/__init__.py deleted file mode 100644 index 366d01b..0000000 --- a/algorithm/hyperneat/substrate/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .base import BaseSubstrate -from .tools import analysis_substrate diff --git a/algorithm/hyperneat/substrate/base.py b/algorithm/hyperneat/substrate/base.py deleted file mode 100644 index 586ce20..0000000 --- a/algorithm/hyperneat/substrate/base.py +++ /dev/null @@ -1,12 +0,0 @@ -import numpy as np - - -class BaseSubstrate: - - @staticmethod - def setup(state, config): - return state.update( - input_coors=np.asarray(config['input_coors'], dtype=np.float32), - output_coors=np.asarray(config['output_coors'], dtype=np.float32), - hidden_coors=np.asarray(config['hidden_coors'], dtype=np.float32), - ) diff --git a/algorithm/hyperneat/substrate/tools.py b/algorithm/hyperneat/substrate/tools.py deleted file mode 100644 index 9eb4720..0000000 --- a/algorithm/hyperneat/substrate/tools.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Type - -import numpy as np - -from .base import BaseSubstrate - - -def analysis_substrate(state): - cd = state.input_coors.shape[1] # coordinate dimensions - si = state.input_coors.shape[0] # input coordinate size - so = state.output_coors.shape[0] # output coordinate size - sh = state.hidden_coors.shape[0] # hidden coordinate size - - input_idx = np.arange(si) - output_idx = np.arange(si, si + so) - hidden_idx = np.arange(si + so, si + so + sh) - - total_conns = si * sh + sh * sh + sh * so - query_coors = np.zeros((total_conns, cd * 2)) - correspond_keys = np.zeros((total_conns, 2)) - - # connect input to hidden - aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors) - query_coors[0: si * sh, :] = aux_coors - correspond_keys[0: si * sh, :] = aux_keys - - # connect hidden to hidden - aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors) - query_coors[si * sh: si * sh + sh * sh, :] = aux_coors - correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys - - # connect hidden to output - aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors) - query_coors[si * sh + sh * sh:, :] = aux_coors - correspond_keys[si * sh + sh * sh:, :] = aux_keys - - return input_idx, output_idx, hidden_idx, query_coors, correspond_keys - - -def cartesian_product(keys1, keys2, coors1, coors2): - len1 = keys1.shape[0] - len2 = keys2.shape[0] - - repeated_coors1 = np.repeat(coors1, len2, axis=0) - repeated_keys1 = np.repeat(keys1, len2) - - tiled_coors2 = np.tile(coors2, (len1, 1)) - tiled_keys2 = np.tile(keys2, len1) - - new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1) - correspond_keys = np.column_stack((repeated_keys1, tiled_keys2)) - - return new_coors, correspond_keys diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py index dc30798..e69de29 100644 --- a/algorithm/neat/__init__.py +++ b/algorithm/neat/__init__.py @@ -1,2 +0,0 @@ -from .neat import NEAT -from .gene import BaseGene, NormalGene, RecurrentGene diff --git a/algorithm/neat/ga/__init__.py b/algorithm/neat/ga/__init__.py new file mode 100644 index 0000000..4fb8380 --- /dev/null +++ b/algorithm/neat/ga/__init__.py @@ -0,0 +1,2 @@ +from .crossover import crossover +from .mutate import create_mutate diff --git a/algorithm/neat/genome/crossover.py b/algorithm/neat/ga/crossover.py similarity index 73% rename from algorithm/neat/genome/crossover.py rename to algorithm/neat/ga/crossover.py index 302c82d..80810f0 100644 --- a/algorithm/neat/genome/crossover.py +++ b/algorithm/neat/ga/crossover.py @@ -1,8 +1,10 @@ import jax -from jax import jit, Array, numpy as jnp +from jax import Array, numpy as jnp + +from core import Genome -def crossover(randkey, nodes1: Array, conns1: Array, nodes2: Array, conns2: Array): +def crossover(randkey, genome1: Genome, genome2: Genome): """ use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) @@ -10,20 +12,22 @@ def crossover(randkey, nodes1: Array, conns1: Array, nodes2: Array, conns2: Arra randkey_1, randkey_2, key= jax.random.split(randkey, 3) # crossover nodes - keys1, keys2 = nodes1[:, 0], nodes2[:, 0] + keys1, keys2 = genome1.nodes[:, 0], genome2.nodes[:, 0] # make homologous genes align in nodes2 align with nodes1 - nodes2 = align_array(keys1, keys2, nodes2, False) - + nodes2 = align_array(keys1, keys2, genome2.nodes, False) + nodes1 = genome1.nodes # For not homologous genes, use the value of nodes1(winner) # For homologous genes, use the crossover result between nodes1 and nodes2 new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) # crossover connections - con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] - cons2 = align_array(con_keys1, con_keys2, conns2, True) - new_cons = jnp.where(jnp.isnan(conns1) | jnp.isnan(cons2), conns1, crossover_gene(randkey_2, conns1, cons2)) + con_keys1, con_keys2 = genome1.conns[:, :2], genome2.conns[:, :2] + conns2 = align_array(con_keys1, con_keys2, genome2.conns, True) + conns1 = genome1.conns - return new_nodes, new_cons + new_cons = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, crossover_gene(randkey_2, conns1, conns2)) + + return genome1.update(new_nodes, new_cons) def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array: @@ -63,4 +67,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: only gene with the same key will be crossover, thus don't need to consider change key """ r = jax.random.uniform(rand_key, shape=g1.shape) - return jnp.where(r > 0.5, g1, g2) + return jnp.where(r > 0.5, g1, g2) \ No newline at end of file diff --git a/algorithm/neat/ga/mutate.py b/algorithm/neat/ga/mutate.py new file mode 100644 index 0000000..0a18141 --- /dev/null +++ b/algorithm/neat/ga/mutate.py @@ -0,0 +1,189 @@ +from typing import Tuple, Type + +import jax +from jax import Array, numpy as jnp, vmap + +from config import NeatConfig +from core import State, Gene, Genome +from utils import check_cycles, fetch_random, fetch_first, I_INT, unflatten_conns + + +def create_mutate(config: NeatConfig, gene_type: Type[Gene]): + """ + Create function to mutate a single genome + """ + + def mutate_structure(state: State, randkey, genome: Genome, new_node_key): + + def mutate_add_node(key_, genome_: Genome): + i_key, o_key, idx = choice_connection_key(key_, genome_.conns) + + def nothing(): + return genome_ + + def successful_add_node(): + # disable the connection + new_genome = genome_.update_conns(genome_.conns.at[idx, 2].set(False)) + + # add a new node + new_genome = new_genome.add_node(new_node_key, gene_type.new_node_attrs(state)) + + # add two new connections + new_genome = new_genome.add_conn(i_key, new_node_key, True, gene_type.new_conn_attrs(state)) + new_genome = new_genome.add_conn(new_node_key, o_key, True, gene_type.new_conn_attrs(state)) + + return new_genome + + # if from_idx == I_INT, that means no connection exist, do nothing + return jax.lax.cond(idx == I_INT, nothing, successful_add_node) + + def mutate_delete_node(key_, genome_: Genome): + # TODO: Do we really need to delete a node? + # randomly choose a node + key, idx = choice_node_key(key_, genome_.nodes, state.input_idx, state.output_idx, + allow_input_keys=False, allow_output_keys=False) + def nothing(): + return genome_ + + def successful_delete_node(): + # delete the node + new_genome = genome_.delete_node_by_pos(idx) + + # delete all connections + new_conns = jnp.where(((new_genome.conns[:, 0] == key) | (new_genome.conns[:, 1] == key))[:, None], + jnp.nan, new_genome.conns) + + return new_genome.update_conns(new_conns) + + return jax.lax.cond(idx == I_INT, nothing, successful_delete_node) + + def mutate_add_conn(key_, genome_: Genome): + # randomly choose two nodes + k1_, k2_ = jax.random.split(key_, num=2) + i_key, from_idx = choice_node_key(k1_, genome_.nodes, state.input_idx, state.output_idx, + allow_input_keys=True, allow_output_keys=True) + o_key, to_idx = choice_node_key(k2_, genome_.nodes, state.input_idx, state.output_idx, + allow_input_keys=False, allow_output_keys=True) + + conn_pos = fetch_first((genome_.conns[:, 0] == i_key) & (genome_.conns[:, 1] == o_key)) + + def nothing(): + return genome_ + + def successful(): + return genome_.add_conn(i_key, o_key, True, gene_type.new_conn_attrs(state)) + + def already_exist(): + return genome_.update_conns(genome_.conns.at[conn_pos, 2].set(True)) + + + is_already_exist = conn_pos != I_INT + + if config.network_type == 'feedforward': + u_cons = unflatten_conns(genome_.nodes, genome_.conns) + cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False) + is_cycle = check_cycles(genome_.nodes, cons_exist, from_idx, to_idx) + + choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) + return jax.lax.switch(choice, [already_exist, nothing, successful]) + + elif config.network_type == 'recurrent': + return jax.lax.cond(is_already_exist, already_exist, successful) + + else: + raise ValueError(f"Invalid network type: {config.network_type}") + + def mutate_delete_conn(key_, genome_: Genome): + # randomly choose a connection + i_key, o_key, idx = choice_connection_key(key_, genome_.conns) + + def nothing(): + return genome_ + + def successfully_delete_connection(): + return genome_.delete_conn_by_pos(idx) + + return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) + + k1, k2, k3, k4 = jax.random.split(randkey, num=4) + r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) + + def no(k, g): + return g + + genome = jax.lax.cond(r1 < config.node_add, mutate_add_node, no, k1, genome) + genome = jax.lax.cond(r2 < config.node_delete, mutate_delete_node, no, k2, genome) + genome = jax.lax.cond(r3 < config.conn_add, mutate_add_conn, no, k3, genome) + genome = jax.lax.cond(r4 < config.conn_delete, mutate_delete_conn, no, k4, genome) + + return genome + + def mutate_values(state: State, randkey, genome: Genome): + k1, k2 = jax.random.split(randkey, num=2) + nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0]) + conns_keys = jax.random.split(k2, num=genome.conns.shape[0]) + + nodes_attrs, conns_attrs = genome.nodes[:, 1:], genome.conns[:, 3:] + + new_nodes_attrs = vmap(gene_type.mutate_node, in_axes=(None, 0, 0))(state, nodes_attrs, nodes_keys) + new_conns_attrs = vmap(gene_type.mutate_conn, in_axes=(None, 0, 0))(state, conns_attrs, conns_keys) + + # nan nodes not changed + new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs) + new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs) + + new_nodes = genome.nodes.at[:, 1:].set(new_nodes_attrs) + new_conns = genome.conns.at[:, 3:].set(new_conns_attrs) + + return genome.update(new_nodes, new_conns) + + def mutate(state, randkey, genome: Genome, new_node_key): + k1, k2 = jax.random.split(randkey) + + genome = mutate_structure(state, k1, genome, new_node_key) + genome = mutate_values(state, k2, genome) + + return genome + + return mutate + + +def choice_node_key(rand_key: Array, nodes: Array, + input_keys: Array, output_keys: Array, + allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: + """ + Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. + :param rand_key: + :param nodes: + :param input_keys: + :param output_keys: + :param allow_input_keys: + :param allow_output_keys: + :return: return its key and position(idx) + """ + + node_keys = nodes[:, 0] + mask = ~jnp.isnan(node_keys) + + if not allow_input_keys: + mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys)) + + if not allow_output_keys: + mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys)) + + idx = fetch_random(rand_key, mask) + key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan) + return key, idx + + +def choice_connection_key(rand_key: Array, conns: Array): + """ + Randomly choose a connection key from the given connections. + :return: i_key, o_key, idx + """ + + idx = fetch_random(rand_key, ~jnp.isnan(conns[:, 0])) + i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan) + o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan) + + return i_key, o_key, idx \ No newline at end of file diff --git a/algorithm/neat/gene/__init__.py b/algorithm/neat/gene/__init__.py index e1188c1..02af6ce 100644 --- a/algorithm/neat/gene/__init__.py +++ b/algorithm/neat/gene/__init__.py @@ -1,6 +1 @@ -from .base import BaseGene -from .normal import NormalGene -from .activation import Activation -from .aggregation import Aggregation -from .recurrent import RecurrentGene - +from .normal import NormalGene, NormalGeneConfig diff --git a/algorithm/neat/gene/base.py b/algorithm/neat/gene/base.py deleted file mode 100644 index 4f2e43d..0000000 --- a/algorithm/neat/gene/base.py +++ /dev/null @@ -1,42 +0,0 @@ -from jax import Array, numpy as jnp, vmap - - -class BaseGene: - node_attrs = [] - conn_attrs = [] - - @staticmethod - def setup(state, config): - return state - - @staticmethod - def new_node_attrs(state): - return jnp.zeros(0) - - @staticmethod - def new_conn_attrs(state): - return jnp.zeros(0) - - @staticmethod - def mutate_node(state, attrs: Array, key): - return attrs - - @staticmethod - def mutate_conn(state, attrs: Array, key): - return attrs - - @staticmethod - def distance_node(state, node1: Array, node2: Array): - return node1 - - @staticmethod - def distance_conn(state, conn1: Array, conn2: Array): - return conn1 - - @staticmethod - def forward_transform(state, nodes, conns): - return nodes, conns - - @staticmethod - def create_forward(config): - return None diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py index 0cc1c80..b07d28c 100644 --- a/algorithm/neat/gene/normal.py +++ b/algorithm/neat/gene/normal.py @@ -1,45 +1,100 @@ +from dataclasses import dataclass +from typing import Tuple + import jax from jax import Array, numpy as jnp -from .base import BaseGene -from .activation import Activation -from .aggregation import Aggregation -from algorithm.utils import unflatten_connections, I_INT -from ..genome import topological_sort +from config import GeneConfig +from core import Gene, Genome, State +from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT -class NormalGene(BaseGene): +@dataclass(frozen=True) +class NormalGeneConfig(GeneConfig): + bias_init_mean: float = 0.0 + bias_init_std: float = 1.0 + bias_mutate_power: float = 0.5 + bias_mutate_rate: float = 0.7 + bias_replace_rate: float = 0.1 + + response_init_mean: float = 1.0 + response_init_std: float = 0.0 + response_mutate_power: float = 0.5 + response_mutate_rate: float = 0.7 + response_replace_rate: float = 0.1 + + activation_default: str = 'sigmoid' + activation_options: Tuple[str] = ('sigmoid',) + activation_replace_rate: float = 0.1 + + aggregation_default: str = 'sum' + aggregation_options: Tuple[str] = ('sum',) + aggregation_replace_rate: float = 0.1 + + weight_init_mean: float = 0.0 + weight_init_std: float = 1.0 + weight_mutate_power: float = 0.5 + weight_mutate_rate: float = 0.8 + weight_replace_rate: float = 0.1 + + def __post_init__(self): + assert self.bias_init_std >= 0.0 + assert self.bias_mutate_power >= 0.0 + assert self.bias_mutate_rate >= 0.0 + assert self.bias_replace_rate >= 0.0 + + assert self.response_init_std >= 0.0 + assert self.response_mutate_power >= 0.0 + assert self.response_mutate_rate >= 0.0 + assert self.response_replace_rate >= 0.0 + + assert self.activation_default == self.activation_options[0] + + for name in self.activation_options: + assert name in Activation.name2func, f"Activation function: {name} not found" + + assert self.aggregation_default == self.aggregation_options[0] + + assert self.aggregation_default in Aggregation.name2func, \ + f"Aggregation function: {self.aggregation_default} not found" + + for name in self.aggregation_options: + assert name in Aggregation.name2func, f"Aggregation function: {name} not found" + + +class NormalGene(Gene): node_attrs = ['bias', 'response', 'aggregation', 'activation'] conn_attrs = ['weight'] @staticmethod - def setup(state, config): + def setup(config: NormalGeneConfig, state: State = State()): + return state.update( - bias_init_mean=config['bias_init_mean'], - bias_init_std=config['bias_init_std'], - bias_mutate_power=config['bias_mutate_power'], - bias_mutate_rate=config['bias_mutate_rate'], - bias_replace_rate=config['bias_replace_rate'], + bias_init_mean=config.bias_init_mean, + bias_init_std=config.bias_init_std, + bias_mutate_power=config.bias_mutate_power, + bias_mutate_rate=config.bias_mutate_rate, + bias_replace_rate=config.bias_replace_rate, - response_init_mean=config['response_init_mean'], - response_init_std=config['response_init_std'], - response_mutate_power=config['response_mutate_power'], - response_mutate_rate=config['response_mutate_rate'], - response_replace_rate=config['response_replace_rate'], + response_init_mean=config.response_init_mean, + response_init_std=config.response_init_std, + response_mutate_power=config.response_mutate_power, + response_mutate_rate=config.response_mutate_rate, + response_replace_rate=config.response_replace_rate, - activation_default=config['activation_default'], - activation_options=config['activation_options'], - activation_replace_rate=config['activation_replace_rate'], + activation_replace_rate=config.activation_replace_rate, + activation_default=0, + activation_options=jnp.arange(len(config.activation_options)), - aggregation_default=config['aggregation_default'], - aggregation_options=config['aggregation_options'], - aggregation_replace_rate=config['aggregation_replace_rate'], + aggregation_replace_rate=config.aggregation_replace_rate, + aggregation_default=0, + aggregation_options=jnp.arange(len(config.aggregation_options)), - weight_init_mean=config['weight_init_mean'], - weight_init_std=config['weight_init_std'], - weight_mutate_power=config['weight_mutate_power'], - weight_mutate_rate=config['weight_mutate_rate'], - weight_replace_rate=config['weight_replace_rate'], + weight_init_mean=config.weight_init_mean, + weight_init_std=config.weight_init_std, + weight_mutate_power=config.weight_mutate_power, + weight_mutate_rate=config.weight_mutate_rate, + weight_replace_rate=config.weight_replace_rate, ) @staticmethod @@ -84,20 +139,20 @@ class NormalGene(BaseGene): return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight @staticmethod - def forward_transform(state, nodes, conns): - u_conns = unflatten_connections(nodes, conns) + def forward_transform(state: State, genome: Genome): + u_conns = unflatten_conns(genome.nodes, genome.conns) conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) # remove enable attr u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) - seqs = topological_sort(nodes, conn_enable) + seqs = topological_sort(genome.nodes, conn_enable) - return seqs, nodes, u_conns + return seqs, genome.nodes, u_conns @staticmethod - def create_forward(config): - config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']] - config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']] + def create_forward(state: State, config: NormalGeneConfig): + activation_funcs = [Activation.name2func[name] for name in config.activation_options] + aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options] def act(idx, z): """ @@ -105,7 +160,7 @@ class NormalGene(BaseGene): """ idx = jnp.asarray(idx, dtype=jnp.int32) # change idx from float to int - res = jax.lax.switch(idx, config['activation_funcs'], z) + res = jax.lax.switch(idx, activation_funcs, z) return res def agg(idx, z): @@ -118,14 +173,13 @@ class NormalGene(BaseGene): return 0. def not_all_nan(): - return jax.lax.switch(idx, config['aggregation_funcs'], z) + return jax.lax.switch(idx, aggregation_funcs, z) return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) - def forward(inputs, transform) -> Array: + def forward(inputs, transformed) -> Array: """ - jax forward for single input shaped (input_num, ) - nodes, connections are a single genome + forward for single input shaped (input_num, ) :argument inputs: (input_num, ) :argument cal_seqs: (N, ) @@ -135,10 +189,10 @@ class NormalGene(BaseGene): :return (output_num, ) """ - cal_seqs, nodes, cons = transform + cal_seqs, nodes, cons = transformed - input_idx = config['input_idx'] - output_idx = config['output_idx'] + input_idx = state.input_idx + output_idx = state.output_idx N = nodes.shape[0] ini_vals = jnp.full((N,), jnp.nan) diff --git a/algorithm/neat/gene/recurrent.py b/algorithm/neat/gene/recurrent.py deleted file mode 100644 index f723748..0000000 --- a/algorithm/neat/gene/recurrent.py +++ /dev/null @@ -1,90 +0,0 @@ -import jax -from jax import Array, numpy as jnp, vmap - -from .normal import NormalGene -from .activation import Activation -from .aggregation import Aggregation -from algorithm.utils import unflatten_connections - - -class RecurrentGene(NormalGene): - - @staticmethod - def forward_transform(state, nodes, conns): - u_conns = unflatten_connections(nodes, conns) - - # remove un-enable connections and remove enable attr - conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) - u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) - - return nodes, u_conns - - @staticmethod - def create_forward(config): - config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']] - config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']] - - def act(idx, z): - """ - calculate activation function for each node - """ - idx = jnp.asarray(idx, dtype=jnp.int32) - # change idx from float to int - res = jax.lax.switch(idx, config['activation_funcs'], z) - return res - - def agg(idx, z): - """ - calculate activation function for inputs of node - """ - idx = jnp.asarray(idx, dtype=jnp.int32) - - def all_nan(): - return 0. - - def not_all_nan(): - return jax.lax.switch(idx, config['aggregation_funcs'], z) - - return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) - - batch_act, batch_agg = vmap(act), vmap(agg) - - def forward(inputs, transform) -> Array: - """ - jax forward for single input shaped (input_num, ) - nodes, connections are a single genome - - :argument inputs: (input_num, ) - :argument cal_seqs: (N, ) - :argument nodes: (N, 5) - :argument connections: (2, N, N) - - :return (output_num, ) - """ - - nodes, cons = transform - - input_idx = config['input_idx'] - output_idx = config['output_idx'] - - N = nodes.shape[0] - vals = jnp.full((N,), 0.) - - weights = cons[0, :] - - def body_func(i, values): - values = values.at[input_idx].set(inputs) - nodes_ins = values * weights.T - values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins) - values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias - values = batch_act(nodes[:, 3], values) # z = act(z) - return values - - # for i in range(config['activate_times']): - # vals = body_func(i, vals) - # - # return vals[output_idx] - vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals) - return vals[output_idx] - - return forward diff --git a/algorithm/neat/genome/__init__.py b/algorithm/neat/genome/__init__.py deleted file mode 100644 index ec0b7b9..0000000 --- a/algorithm/neat/genome/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .basic import initialize_genomes -from .mutate import create_mutate -from .distance import create_distance -from .crossover import crossover -from .graph import topological_sort \ No newline at end of file diff --git a/algorithm/neat/genome/basic.py b/algorithm/neat/genome/basic.py deleted file mode 100644 index 5635280..0000000 --- a/algorithm/neat/genome/basic.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import Type, Tuple - -import numpy as np -import jax -from jax import Array, numpy as jnp - -from algorithm import State -from ..gene import BaseGene -from algorithm.utils import fetch_first - - -def initialize_genomes(state: State, gene_type: Type[BaseGene]): - o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes - o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections - - input_idx = state.input_idx - output_idx = state.output_idx - new_node_key = max([*input_idx, *output_idx]) + 1 - - o_nodes[input_idx, 0] = input_idx - o_nodes[output_idx, 0] = output_idx - o_nodes[new_node_key, 0] = new_node_key - o_nodes[np.concatenate([input_idx, output_idx]), 1:] = jax.device_get(gene_type.new_node_attrs(state)) - o_nodes[new_node_key, 1:] = jax.device_get(gene_type.new_node_attrs(state)) - - input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] - o_conns[input_idx, 0:2] = input_conns # in key, out key - o_conns[input_idx, 2] = True # enabled - o_conns[input_idx, 3:] = jax.device_get(gene_type.new_conn_attrs(state)) - - output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] - o_conns[output_idx, 0:2] = output_conns # in key, out key - o_conns[output_idx, 2] = True # enabled - o_conns[output_idx, 3:] = jax.device_get(gene_type.new_conn_attrs(state)) - - # repeat origin genome for P times to create population - pop_nodes = np.tile(o_nodes, (state.P, 1, 1)) - pop_conns = np.tile(o_conns, (state.P, 1, 1)) - - return jax.device_put([pop_nodes, pop_conns]) - - -def count(nodes: Array, cons: Array): - """ - Count how many nodes and connections are in the genome. - """ - node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0])) - cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0])) - return node_cnt, cons_cnt - - -def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]: - """ - Add a new node to the genome. - The new node will place at the first NaN row. - """ - exist_keys = nodes[:, 0] - idx = fetch_first(jnp.isnan(exist_keys)) - nodes = nodes.at[idx, 0].set(new_key) - nodes = nodes.at[idx, 1:].set(attrs) - return nodes, cons - - -def delete_node(nodes: Array, cons: Array, node_key: Array) -> Tuple[Array, Array]: - """ - Delete a node from the genome. Only delete the node, regardless of connections. - Delete the node by its key. - """ - node_keys = nodes[:, 0] - idx = fetch_first(node_keys == node_key) - return delete_node_by_idx(nodes, cons, idx) - - -def delete_node_by_idx(nodes: Array, cons: Array, idx: Array) -> Tuple[Array, Array]: - """ - Delete a node from the genome. Only delete the node, regardless of connections. - Delete the node by its idx. - """ - nodes = nodes.at[idx].set(np.nan) - return nodes, cons - - -def add_connection(nodes: Array, cons: Array, i_key: Array, o_key: Array, enable: bool, attrs: Array) -> Tuple[ - Array, Array]: - """ - Add a new connection to the genome. - The new connection will place at the first NaN row. - """ - con_keys = cons[:, 0] - idx = fetch_first(jnp.isnan(con_keys)) - cons = cons.at[idx, 0:3].set(jnp.array([i_key, o_key, enable])) - cons = cons.at[idx, 3:].set(attrs) - return nodes, cons - - -def delete_connection(nodes: Array, cons: Array, i_key: Array, o_key: Array) -> Tuple[Array, Array]: - """ - Delete a connection from the genome. - Delete the connection by its input and output node keys. - """ - idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key)) - return delete_connection_by_idx(nodes, cons, idx) - - -def delete_connection_by_idx(nodes: Array, cons: Array, idx: Array) -> Tuple[Array, Array]: - """ - Delete a connection from the genome. - Delete the connection by its idx. - """ - cons = cons.at[idx].set(np.nan) - return nodes, cons diff --git a/algorithm/neat/genome/mutate.py b/algorithm/neat/genome/mutate.py deleted file mode 100644 index 47849fd..0000000 --- a/algorithm/neat/genome/mutate.py +++ /dev/null @@ -1,205 +0,0 @@ -from typing import Dict, Tuple, Type - -import jax -from jax import Array, numpy as jnp, vmap - -from algorithm import State -from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx -from .graph import check_cycles -from algorithm.utils import fetch_random, fetch_first, I_INT, unflatten_connections -from ..gene import BaseGene - - -def create_mutate(config: Dict, gene_type: Type[BaseGene]): - """ - Create function to mutate a single genome - """ - - def mutate_structure(state: State, randkey, nodes, conns, new_node_key): - - def mutate_add_node(key_, nodes_, conns_): - i_key, o_key, idx = choice_connection_key(key_, nodes_, conns_) - - def nothing(): - return nodes_, conns_ - - def successful_add_node(): - # disable the connection - aux_nodes, aux_conns = nodes_, conns_ - - # set enable to false - aux_conns = aux_conns.at[idx, 2].set(False) - - # add a new node - aux_nodes, aux_conns = add_node(aux_nodes, aux_conns, new_node_key, gene_type.new_node_attrs(state)) - - # add two new connections - aux_nodes, aux_conns = add_connection(aux_nodes, aux_conns, i_key, new_node_key, True, - gene_type.new_conn_attrs(state)) - aux_nodes, aux_conns = add_connection(aux_nodes, aux_conns, new_node_key, o_key, True, - gene_type.new_conn_attrs(state)) - - return aux_nodes, aux_conns - - # if from_idx == I_INT, that means no connection exist, do nothing - new_nodes, new_conns = jax.lax.cond(idx == I_INT, nothing, successful_add_node) - - return new_nodes, new_conns - - def mutate_delete_node(key_, nodes_, conns_): - # TODO: Do we really need to delete a node? - # randomly choose a node - key, idx = choice_node_key(key_, nodes_, config['input_idx'], config['output_idx'], - allow_input_keys=False, allow_output_keys=False) - def nothing(): - return nodes_, conns_ - - def successful_delete_node(): - # delete the node - aux_nodes, aux_cons = delete_node_by_idx(nodes_, conns_, idx) - - # delete all connections - aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None], - jnp.nan, aux_cons) - - return aux_nodes, aux_cons - - return jax.lax.cond(idx == I_INT, nothing, successful_delete_node) - - def mutate_add_conn(key_, nodes_, conns_): - # randomly choose two nodes - k1_, k2_ = jax.random.split(key_, num=2) - i_key, from_idx = choice_node_key(k1_, nodes_, config['input_idx'], config['output_idx'], - allow_input_keys=True, allow_output_keys=True) - o_key, to_idx = choice_node_key(k2_, nodes_, config['input_idx'], config['output_idx'], - allow_input_keys=False, allow_output_keys=True) - - con_idx = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key)) - - def nothing(): - return nodes_, conns_ - - def successful(): - new_nodes, new_cons = add_connection(nodes_, conns_, i_key, o_key, True, gene_type.new_conn_attrs(state)) - return new_nodes, new_cons - - def already_exist(): - new_cons = conns_.at[con_idx, 2].set(True) - return nodes_, new_cons - - is_already_exist = con_idx != I_INT - - if config['network_type'] == 'feedforward': - u_cons = unflatten_connections(nodes_, conns_) - cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False) - is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx) - - choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) - return jax.lax.switch(choice, [already_exist, nothing, successful]) - - elif config['network_type'] == 'recurrent': - return jax.lax.cond(is_already_exist, already_exist, successful) - - else: - raise ValueError(f"Invalid network type: {config['network_type']}") - - def mutate_delete_conn(key_, nodes_, conns_): - # randomly choose a connection - i_key, o_key, idx = choice_connection_key(key_, nodes_, conns_) - - def nothing(): - return nodes_, conns_ - - def successfully_delete_connection(): - return delete_connection_by_idx(nodes_, conns_, idx) - - return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) - - k1, k2, k3, k4 = jax.random.split(randkey, num=4) - r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) - - def no(k, n, c): - return n, c - - nodes, conns = jax.lax.cond(r1 < config['node_add_prob'], mutate_add_node, no, k1, nodes, conns) - - nodes, conns = jax.lax.cond(r2 < config['node_delete_prob'], mutate_delete_node, no, k2, nodes, conns) - - nodes, conns = jax.lax.cond(r3 < config['conn_add_prob'], mutate_add_conn, no, k3, nodes, conns) - - nodes, conns = jax.lax.cond(r4 < config['conn_delete_prob'], mutate_delete_conn, no, k4, nodes, conns) - - return nodes, conns - - def mutate_values(state: State, randkey, nodes, conns): - k1, k2 = jax.random.split(randkey, num=2) - nodes_keys = jax.random.split(k1, num=nodes.shape[0]) - conns_keys = jax.random.split(k2, num=conns.shape[0]) - - nodes_attrs, conns_attrs = nodes[:, 1:], conns[:, 3:] - - new_nodes_attrs = vmap(gene_type.mutate_node, in_axes=(None, 0, 0))(state, nodes_attrs, nodes_keys) - new_conns_attrs = vmap(gene_type.mutate_conn, in_axes=(None, 0, 0))(state, conns_attrs, conns_keys) - - # nan nodes not changed - new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs) - new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs) - - new_nodes = nodes.at[:, 1:].set(new_nodes_attrs) - new_conns = conns.at[:, 3:].set(new_conns_attrs) - - return new_nodes, new_conns - - def mutate(state, randkey, nodes, conns, new_node_key): - k1, k2 = jax.random.split(randkey) - - nodes, conns = mutate_structure(state, k1, nodes, conns, new_node_key) - nodes, conns = mutate_values(state, k2, nodes, conns) - - return nodes, conns - - return mutate - - -def choice_node_key(rand_key: Array, nodes: Array, - input_keys: Array, output_keys: Array, - allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: - """ - Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. - :param rand_key: - :param nodes: - :param input_keys: - :param output_keys: - :param allow_input_keys: - :param allow_output_keys: - :return: return its key and position(idx) - """ - - node_keys = nodes[:, 0] - mask = ~jnp.isnan(node_keys) - - if not allow_input_keys: - mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys)) - - if not allow_output_keys: - mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys)) - - idx = fetch_random(rand_key, mask) - key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan) - return key, idx - - -def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]: - """ - Randomly choose a connection key from the given connections. - :param rand_key: - :param nodes: - :param cons: - :return: i_key, o_key, idx - """ - - idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0])) - i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan) - o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan) - - return i_key, o_key, idx diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py index 1f15745..4e5ef1f 100644 --- a/algorithm/neat/neat.py +++ b/algorithm/neat/neat.py @@ -1,67 +1,84 @@ from typing import Type import jax -import jax.numpy as jnp +from jax import numpy as jnp, Array, vmap +import numpy as np -from algorithm import Algorithm, State -from .gene import BaseGene -from .genome import initialize_genomes -from .population import create_tell +from config import Config +from core import Algorithm, State, Gene, Genome +from .ga import crossover, create_mutate +from .species import update_species, create_speciate class NEAT(Algorithm): - def __init__(self, config, gene_type: Type[BaseGene]): - super().__init__() + + def __init__(self, config: Config, gene_type: Type[Gene]): self.config = config self.gene_type = gene_type - self.tell = create_tell(config, self.gene_type) - self.ask = None - self.forward = self.gene_type.create_forward(config) - self.forward_transform = self.gene_type.forward_transform + self.forward_func = None + self.tell_func = None + + def setup(self, randkey, state: State = State()): + """initialize the state of the algorithm""" + + input_idx = np.arange(self.config.basic.num_inputs) + output_idx = np.arange(self.config.basic.num_inputs, + self.config.basic.num_inputs + self.config.basic.num_outputs) - def setup(self, randkey, state=State()): state = state.update( - P=self.config['pop_size'], - N=self.config['maximum_nodes'], - C=self.config['maximum_conns'], - S=self.config['maximum_species'], + P=self.config.basic.pop_size, + N=self.config.neat.maximum_nodes, + C=self.config.neat.maximum_conns, + S=self.config.neat.maximum_species, NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes - input_idx=self.config['input_idx'], - output_idx=self.config['output_idx'], - max_stagnation=self.config['max_stagnation'], - species_elitism=self.config['species_elitism'], - spawn_number_change_rate=self.config['spawn_number_change_rate'], - genome_elitism=self.config['genome_elitism'], - survival_threshold=self.config['survival_threshold'], - compatibility_threshold=self.config['compatibility_threshold'], + max_stagnation=self.config.neat.max_stagnation, + species_elitism=self.config.neat.species_elitism, + spawn_number_change_rate=self.config.neat.spawn_number_change_rate, + genome_elitism=self.config.neat.genome_elitism, + survival_threshold=self.config.neat.survival_threshold, + compatibility_threshold=self.config.neat.compatibility_threshold, + compatibility_disjoint=self.config.neat.compatibility_disjoint, + compatibility_weight=self.config.neat.compatibility_weight, + + input_idx=input_idx, + output_idx=output_idx, ) - state = self.gene_type.setup(state, self.config) + state = self.gene_type.setup(self.config.gene, state) + pop_genomes = self._initialize_genomes(state) - randkey = randkey - pop_nodes, pop_conns = initialize_genomes(state, self.gene_type) - species_info = jnp.full((state.S, 4), jnp.nan, - dtype=jnp.float32) # (species_key, best_fitness, last_improved, size) - species_info = species_info.at[0, :].set([0, -jnp.inf, 0, state.P]) + species_keys = np.full((state.S,), np.nan, dtype=np.float32) + best_fitness = np.full((state.S,), np.nan, dtype=np.float32) + last_improved = np.full((state.S,), np.nan, dtype=np.float32) + member_count = np.full((state.S,), np.nan, dtype=np.float32) idx2species = jnp.zeros(state.P, dtype=jnp.float32) + + species_keys[0] = 0 + best_fitness[0] = -np.inf + last_improved[0] = 0 + member_count[0] = state.P + center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32) center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32) - center_nodes = center_nodes.at[0, :, :].set(pop_nodes[0, :, :]) - center_conns = center_conns.at[0, :, :].set(pop_conns[0, :, :]) + center_nodes = center_nodes.at[0, :, :].set(pop_genomes.nodes[0, :, :]) + center_conns = center_conns.at[0, :, :].set(pop_genomes.conns[0, :, :]) + center_genomes = vmap(Genome)(center_nodes, center_conns) + generation = 0 next_node_key = max(*state.input_idx, *state.output_idx) + 2 next_species_key = 1 state = state.update( randkey=randkey, - pop_nodes=pop_nodes, - pop_conns=pop_conns, - species_info=species_info, + pop_genomes=pop_genomes, + species_keys=species_keys, + best_fitness=best_fitness, + last_improved=last_improved, + member_count=member_count, idx2species=idx2species, - center_nodes=center_nodes, - center_conns=center_conns, + center_genomes=center_genomes, # avoid jax auto cast from int to float. that would cause re-compilation. generation=jnp.asarray(generation, dtype=jnp.int32), @@ -69,7 +86,112 @@ class NEAT(Algorithm): next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32), ) - # move to device - state = jax.device_put(state) + self.forward_func = self.gene_type.create_forward(state, self.config.gene) + self.tell_func = self._create_tell() - return state \ No newline at end of file + return jax.device_put(state) + + def ask(self, state: State): + """require the population to be evaluated""" + return state.pop_genomes + + def tell(self, state: State, fitness): + """update the state of the algorithm""" + return self.tell_func(state, fitness) + + def forward(self, inputs: Array, transformed: Array): + """the forward function of a single forward transformation""" + return self.forward_func(inputs, transformed) + + def forward_transform(self, state: State, genome: Genome): + """create the forward transformation of a genome""" + return self.gene_type.forward_transform(state, genome) + + def _initialize_genomes(self, state): + o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes + o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections + + input_idx = state.input_idx + output_idx = state.output_idx + new_node_key = max([*input_idx, *output_idx]) + 1 + + o_nodes[input_idx, 0] = input_idx + o_nodes[output_idx, 0] = output_idx + o_nodes[new_node_key, 0] = new_node_key + o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene_type.new_node_attrs(state) + o_nodes[new_node_key, 1:] = self.gene_type.new_node_attrs(state) + + input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] + o_conns[input_idx, 0:2] = input_conns # in key, out key + o_conns[input_idx, 2] = True # enabled + o_conns[input_idx, 3:] = self.gene_type.new_conn_attrs(state) + + output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] + o_conns[output_idx, 0:2] = output_conns # in key, out key + o_conns[output_idx, 2] = True # enabled + o_conns[output_idx, 3:] = self.gene_type.new_conn_attrs(state) + + # repeat origin genome for P times to create population + pop_nodes = np.tile(o_nodes, (state.P, 1, 1)) + pop_conns = np.tile(o_conns, (state.P, 1, 1)) + + return vmap(Genome)(pop_nodes, pop_conns) + + def _create_tell(self): + mutate = create_mutate(self.config.neat, self.gene_type) + + def create_next_generation(state, randkey, winner, loser, elite_mask): + # prepare random keys + pop_size = state.idx2species.shape[0] + new_node_keys = jnp.arange(pop_size) + state.next_node_key + + k1, k2 = jax.random.split(randkey, 2) + crossover_rand_keys = jax.random.split(k1, pop_size) + mutate_rand_keys = jax.random.split(k2, pop_size) + + # batch crossover + wpn, wpc = state.pop_genomes.nodes[winner], state.pop_genomes.conns[winner] + lpn, lpc = state.pop_genomes.nodes[loser], state.pop_genomes.conns[loser] + n_genomes = vmap(crossover)(crossover_rand_keys, Genome(wpn, wpc), Genome(lpn, lpc)) + + # batch mutation + mutate_func = vmap(mutate, in_axes=(None, 0, 0, 0)) + m_n_genomes = mutate_func(state, mutate_rand_keys, n_genomes, new_node_keys) # mutate_new_pop_nodes + + # elitism don't mutate + pop_nodes = jnp.where(elite_mask[:, None, None], n_genomes.nodes, m_n_genomes.nodes) + pop_conns = jnp.where(elite_mask[:, None, None], n_genomes.conns, m_n_genomes.conns) + + # update next node key + all_nodes_keys = pop_nodes[:, :, 0] + max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) + next_node_key = max_node_key + 1 + + return state.update( + pop_genomes=Genome(pop_nodes, pop_conns), + next_node_key=next_node_key, + ) + + speciate = create_speciate(self.gene_type) + + def tell(state, fitness): + """ + Main update function in NEAT. + """ + + k1, k2, randkey = jax.random.split(state.randkey, 3) + + state = state.update( + generation=state.generation + 1, + randkey=randkey + ) + + state, winner, loser, elite_mask = update_species(state, k1, fitness) + + state = create_next_generation(state, k2, winner, loser, elite_mask) + + state = speciate(state) + + return state + + return tell diff --git a/algorithm/neat/population.py b/algorithm/neat/population.py deleted file mode 100644 index a89d178..0000000 --- a/algorithm/neat/population.py +++ /dev/null @@ -1,363 +0,0 @@ -from typing import Type - -import jax -from jax import numpy as jnp, vmap - -from algorithm.utils import rank_elements, fetch_first -from .genome import create_mutate, create_distance, crossover -from .gene import BaseGene - - -def create_tell(config, gene_type: Type[BaseGene]): - mutate = create_mutate(config, gene_type) - distance = create_distance(config, gene_type) - - def update_species(state, randkey, fitness): - # update the fitness of each species - species_fitness = update_species_fitness(state, fitness) - - # stagnation species - state, species_fitness = stagnation(state, species_fitness) - - # sort species_info by their fitness. (push nan to the end) - sort_indices = jnp.argsort(species_fitness)[::-1] - - state = state.update( - species_info=state.species_info[sort_indices], - center_nodes=state.center_nodes[sort_indices], - center_conns=state.center_conns[sort_indices], - ) - - # decide the number of members of each species by their fitness - spawn_number = cal_spawn_numbers(state) - - # crossover info - winner, loser, elite_mask = create_crossover_pair(state, randkey, spawn_number, fitness) - - return state, winner, loser, elite_mask - - def update_species_fitness(state, fitness): - """ - obtain the fitness of the species by the fitness of each individual. - use max criterion. - """ - - def aux_func(idx): - species_key = state.species_info[idx, 0] - s_fitness = jnp.where(state.idx2species == species_key, fitness, -jnp.inf) - f = jnp.max(s_fitness) - return f - - return vmap(aux_func)(jnp.arange(state.species_info.shape[0])) - - def stagnation(state, species_fitness): - """ - stagnation species. - those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. - elitism species never stagnation - """ - - def aux_func(idx): - s_fitness = species_fitness[idx] - species_key, best_score, last_update, members_count = state.species_info[idx] - st = (s_fitness <= best_score) & (state.generation - last_update > state.max_stagnation) - last_update = jnp.where(s_fitness > best_score, state.generation, last_update) - best_score = jnp.where(s_fitness > best_score, s_fitness, best_score) - # stagnation condition - return st, jnp.array([species_key, best_score, last_update, members_count]) - - spe_st, species_info = vmap(aux_func)(jnp.arange(species_fitness.shape[0])) - - # elite species will not be stagnation - species_rank = rank_elements(species_fitness) - spe_st = jnp.where(species_rank < state.species_elitism, False, spe_st) # elitism never stagnation - - # set stagnation species to nan - species_info = jnp.where(spe_st[:, None], jnp.nan, species_info) - center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_nodes) - center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_conns) - species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness) - - state = state.update( - species_info=species_info, - center_nodes=center_nodes, - center_conns=center_conns, - ) - - return state, species_fitness - - def cal_spawn_numbers(state): - """ - decide the number of members of each species by their fitness rank. - the species with higher fitness will have more members - Linear ranking selection - e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] - """ - - is_species_valid = ~jnp.isnan(state.species_info[:, 0]) - valid_species_num = jnp.sum(is_species_valid) - denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 - - rank_score = valid_species_num - jnp.arange(state.species_info.shape[0]) # obtain [3, 2, 1] - spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] - spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 - - target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member - - # Avoid too much variation of numbers in a species - previous_size = state.species_info[:, 3].astype(jnp.int32) - spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate - # jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number) - spawn_number = spawn_number.astype(jnp.int32) - - # spawn_number = target_spawn_number.astype(jnp.int32) - - # must control the sum of spawn_number to be equal to pop_size - error = state.P - jnp.sum(spawn_number) - spawn_number = spawn_number.at[0].add( - error) # add error to the first species to control the sum of spawn_number - - return spawn_number - - def create_crossover_pair(state, randkey, spawn_number, fitness): - species_size = state.species_info.shape[0] - pop_size = fitness.shape[0] - s_idx = jnp.arange(species_size) - p_idx = jnp.arange(pop_size) - - # def aux_func(key, idx): - def aux_func(key, idx): - members = state.idx2species == state.species_info[idx, 0] - members_num = jnp.sum(members) - - members_fitness = jnp.where(members, fitness, -jnp.inf) - sorted_member_indices = jnp.argsort(members_fitness)[::-1] - - elite_size = state.genome_elitism - survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32) - - select_pro = (p_idx < survive_size) / survive_size - fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro) - - # elite - fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa) - ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma) - elite = jnp.where(p_idx < elite_size, True, False) - return fa, ma, elite - - fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx) - - spawn_number_cum = jnp.cumsum(spawn_number) - - def aux_func(idx): - loc = jnp.argmax(idx < spawn_number_cum) - - # elite genomes are at the beginning of the species - idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx) - return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species] - - part1, part2, elite_mask = vmap(aux_func)(p_idx) - - is_part1_win = fitness[part1] >= fitness[part2] - winner = jnp.where(is_part1_win, part1, part2) - loser = jnp.where(is_part1_win, part2, part1) - - return winner, loser, elite_mask - - def create_next_generation(state, randkey, winner, loser, elite_mask): - # prepare random keys - pop_size = state.pop_nodes.shape[0] - new_node_keys = jnp.arange(pop_size) + state.next_node_key - - k1, k2 = jax.random.split(randkey, 2) - crossover_rand_keys = jax.random.split(k1, pop_size) - mutate_rand_keys = jax.random.split(k2, pop_size) - - # batch crossover - wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner] # winner pop nodes, winner pop connections - lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser] # loser pop nodes, loser pop connections - npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections - - # batch mutation - mutate_func = vmap(mutate, in_axes=(None, 0, 0, 0, 0)) - m_npn, m_npc = mutate_func(state, mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes - - # elitism don't mutate - pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn) - pop_conns = jnp.where(elite_mask[:, None, None], npc, m_npc) - - # update next node key - all_nodes_keys = pop_nodes[:, :, 0] - max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) - next_node_key = max_node_key + 1 - - return state.update( - pop_nodes=pop_nodes, - pop_conns=pop_conns, - next_node_key=next_node_key, - ) - - def speciate(state): - pop_size, species_size = state.pop_nodes.shape[0], state.center_nodes.shape[0] - - # prepare distance functions - o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0, 0)) # one to population - - # idx to specie key - idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species - - # the distance between genomes to its center genomes - o2c_distances = jnp.full((pop_size,), jnp.inf) - - # step 1: find new centers - def cond_func(carry): - i, i2s, cn, cc, o2c = carry - species_key = state.species_info[i, 0] - # jax.debug.print("{}, {}", i, species_key) - return (i < species_size) & (~jnp.isnan(species_key)) # current species is existing - - def body_func(carry): - i, i2s, cn, cc, o2c = carry - distances = o2p_distance_func(state, cn[i], cc[i], state.pop_nodes, state.pop_conns) - - # find the closest one - closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) - # jax.debug.print("closest_idx: {}", closest_idx) - - i2s = i2s.at[closest_idx].set(state.species_info[i, 0]) - cn = cn.at[i].set(state.pop_nodes[closest_idx]) - cc = cc.at[i].set(state.pop_conns[closest_idx]) - - # the genome with closest_idx will become the new center, thus its distance to center is 0. - o2c = o2c.at[closest_idx].set(0) - - return i + 1, i2s, cn, cc, o2c - - _, idx2specie, center_nodes, center_conns, o2c_distances = \ - jax.lax.while_loop(cond_func, body_func, - (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances)) - - # part 2: assign members to each species - def cond_func(carry): - i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key - current_species_existed = ~jnp.isnan(si[i, 0]) - not_all_assigned = jnp.any(jnp.isnan(i2s)) - not_reach_species_upper_bounds = i < species_size - return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned) - - def body_func(carry): - i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_conns - - _, i2s, scn, scc, si, o2c, nsk = jax.lax.cond( - jnp.isnan(si[i, 0]), # whether the current species is existing or not - create_new_species, # if not existing, create a new specie - update_exist_specie, # if existing, update the specie - (i, i2s, cn, cc, si, o2c, nsk) - ) - - return i + 1, i2s, scn, scc, si, o2c, nsk - - def create_new_species(carry): - i, i2s, cn, cc, si, o2c, nsk = carry - - # pick the first one who has not been assigned to any species - idx = fetch_first(jnp.isnan(i2s)) - - # assign it to the new species - # [key, best score, last update generation, members_count] - si = si.at[i].set(jnp.array([nsk, -jnp.inf, state.generation, 0])) - i2s = i2s.at[idx].set(nsk) - o2c = o2c.at[idx].set(0) - - # update center genomes - cn = cn.at[i].set(state.pop_nodes[idx]) - cc = cc.at[i].set(state.pop_conns[idx]) - - i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) - - # when a new species is created, it needs to be updated, thus do not change i - return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key - - def update_exist_specie(carry): - i, i2s, cn, cc, si, o2c, nsk = carry - i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) - - # turn to next species - return i + 1, i2s, cn, cc, si, o2c, nsk - - def speciate_by_threshold(carry): - i, i2s, cn, cc, si, o2c = carry - - # distance between such center genome and ppo genomes - o2p_distance = o2p_distance_func(state, cn[i], cc[i], state.pop_nodes, state.pop_conns) - close_enough_mask = o2p_distance < state.compatibility_threshold - - # when a genome is not assigned or the distance between its current center is bigger than this center - cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c) - # jax.debug.print("{}", o2p_distance) - mask = close_enough_mask & cacheable_mask - - # update species info - i2s = jnp.where(mask, si[i, 0], i2s) - - # update distance between centers - o2c = jnp.where(mask, o2p_distance, o2c) - - return i2s, o2c - - # update idx2specie - _, idx2specie, center_nodes, center_conns, species_info, _, next_species_key = jax.lax.while_loop( - cond_func, - body_func, - (0, idx2specie, center_nodes, center_conns, state.species_info, o2c_distances, state.next_species_key) - ) - - # if there are still some pop genomes not assigned to any species, add them to the last genome - # this condition can only happen when the number of species is reached species upper bounds - idx2specie = jnp.where(jnp.isnan(idx2specie), species_info[-1, 0], idx2specie) - - # update members count - def count_members(idx): - key = species_info[idx, 0] - count = jnp.sum(idx2specie == key) - count = jnp.where(jnp.isnan(key), jnp.nan, count) - return count - - species_member_counts = vmap(count_members)(jnp.arange(species_size)) - species_info = species_info.at[:, 3].set(species_member_counts) - - return state.update( - idx2species=idx2specie, - center_nodes=center_nodes, - center_conns=center_conns, - species_info=species_info, - next_species_key=next_species_key - ) - - def tell(state, fitness): - """ - Main update function in NEAT. - """ - - k1, k2, randkey = jax.random.split(state.randkey, 3) - - state = state.update( - generation=state.generation + 1, - randkey=randkey - ) - - state, winner, loser, elite_mask = update_species(state, k1, fitness) - - state = create_next_generation(state, k2, winner, loser, elite_mask) - - state = speciate(state) - - return state - - return tell - - -def argmin_with_mask(arr, mask): - masked_arr = jnp.where(mask, arr, jnp.inf) - min_idx = jnp.argmin(masked_arr) - return min_idx diff --git a/algorithm/neat/species/__init__.py b/algorithm/neat/species/__init__.py new file mode 100644 index 0000000..d5d058d --- /dev/null +++ b/algorithm/neat/species/__init__.py @@ -0,0 +1 @@ +from .operations import update_species, create_speciate diff --git a/algorithm/neat/genome/distance.py b/algorithm/neat/species/distance.py similarity index 82% rename from algorithm/neat/genome/distance.py rename to algorithm/neat/species/distance.py index 0bd4e5a..9667e5a 100644 --- a/algorithm/neat/genome/distance.py +++ b/algorithm/neat/species/distance.py @@ -1,11 +1,11 @@ -from typing import Dict, Type +from typing import Type from jax import Array, numpy as jnp, vmap -from ..gene import BaseGene +from core import Gene -def create_distance(config: Dict, gene_type: Type[BaseGene]): +def create_distance(gene_type: Type[Gene]): def node_distance(state, nodes1: Array, nodes2: Array): """ Calculate the distance between nodes of two genomes. @@ -35,8 +35,7 @@ def create_distance(config: Dict, gene_type: Type[BaseGene]): hnd = jnp.where(jnp.isnan(hnd), 0, hnd) homologous_distance = jnp.sum(hnd * intersect_mask) - val = non_homologous_cnt * config['compatibility_disjoint'] + homologous_distance * config[ - 'compatibility_weight'] + val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division @@ -64,13 +63,11 @@ def create_distance(config: Dict, gene_type: Type[BaseGene]): hcd = jnp.where(jnp.isnan(hcd), 0, hcd) homologous_distance = jnp.sum(hcd * intersect_mask) - val = non_homologous_cnt * config['compatibility_disjoint'] + homologous_distance * config[ - 'compatibility_weight'] + val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight return jnp.where(max_cnt == 0, 0, val / max_cnt) - def distance(state, nodes1, conns1, nodes2, conns2): - return node_distance(state, nodes1, nodes2) + connection_distance(state, conns1, conns2) + def distance(state, genome1, genome2): + return node_distance(state, genome1.nodes, genome2.nodes) + connection_distance(state, genome1.conns, genome2.conns) return distance - diff --git a/algorithm/neat/species/operations.py b/algorithm/neat/species/operations.py new file mode 100644 index 0000000..7921016 --- /dev/null +++ b/algorithm/neat/species/operations.py @@ -0,0 +1,334 @@ +from typing import Type + +import jax +from jax import numpy as jnp, vmap + +from core import Gene, Genome +from utils import rank_elements, fetch_first +from .distance import create_distance + + +def update_species(state, randkey, fitness): + # update the fitness of each species + species_fitness = update_species_fitness(state, fitness) + + # stagnation species + state, species_fitness = stagnation(state, species_fitness) + + # sort species_info by their fitness. (push nan to the end) + sort_indices = jnp.argsort(species_fitness)[::-1] + + center_nodes = state.center_genomes.nodes[sort_indices] + center_conns = state.center_genomes.conns[sort_indices] + + state = state.update( + species_keys=state.species_keys[sort_indices], + best_fitness=state.best_fitness[sort_indices], + last_improved=state.last_improved[sort_indices], + member_count=state.member_count[sort_indices], + center_genomes=Genome(center_nodes, center_conns), + ) + + # decide the number of members of each species by their fitness + spawn_number = cal_spawn_numbers(state) + + # crossover info + winner, loser, elite_mask = create_crossover_pair(state, randkey, spawn_number, fitness) + + return state, winner, loser, elite_mask + + +def update_species_fitness(state, fitness): + """ + obtain the fitness of the species by the fitness of each individual. + use max criterion. + """ + + def aux_func(idx): + s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf) + f = jnp.max(s_fitness) + return f + + return vmap(aux_func)(jnp.arange(state.species_keys.shape[0])) + + +def stagnation(state, species_fitness): + """ + stagnation species. + those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. + elitism species never stagnation + """ + + def aux_func(idx): + s_fitness = species_fitness[idx] + sk, bf, li = state.species_keys[idx], state.best_fitness[idx], state.last_improved[idx] + st = (s_fitness <= bf) & (state.generation - li > state.max_stagnation) + li = jnp.where(s_fitness > bf, state.generation, li) + bf = jnp.where(s_fitness > bf, s_fitness, bf) + + return st, sk, bf, li + + spe_st, species_keys, best_fitness, last_improved = vmap(aux_func)(jnp.arange(species_fitness.shape[0])) + + # elite species will not be stagnation + species_rank = rank_elements(species_fitness) + spe_st = jnp.where(species_rank < state.species_elitism, False, spe_st) # elitism never stagnation + + # set stagnation species to nan + species_keys = jnp.where(spe_st, jnp.nan, species_keys) + best_fitness = jnp.where(spe_st, jnp.nan, best_fitness) + last_improved = jnp.where(spe_st, jnp.nan, last_improved) + member_count = jnp.where(spe_st, jnp.nan, state.member_count) + species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness) + + center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.nodes) + center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.conns) + + state = state.update( + species_keys=species_keys, + best_fitness=best_fitness, + last_improved=last_improved, + member_count=member_count, + center_genomes=state.center_genomes.update(center_nodes, center_conns) + ) + + return state, species_fitness + + +def cal_spawn_numbers(state): + """ + decide the number of members of each species by their fitness rank. + the species with higher fitness will have more members + Linear ranking selection + e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] + """ + + is_species_valid = ~jnp.isnan(state.species_keys) + valid_species_num = jnp.sum(is_species_valid) + denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 + + rank_score = valid_species_num - jnp.arange(state.species_keys.shape[0]) # obtain [3, 2, 1] + spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] + spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 + + target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member + + # Avoid too much variation of numbers in a species + previous_size = state.member_count + spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate + # jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number) + spawn_number = spawn_number.astype(jnp.int32) + + # must control the sum of spawn_number to be equal to pop_size + error = state.P - jnp.sum(spawn_number) + spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number + + return spawn_number + + +def create_crossover_pair(state, randkey, spawn_number, fitness): + species_size = state.species_keys.shape[0] + pop_size = fitness.shape[0] + s_idx = jnp.arange(species_size) + p_idx = jnp.arange(pop_size) + + # def aux_func(key, idx): + def aux_func(key, idx): + members = state.idx2species == state.species_keys[idx] + members_num = jnp.sum(members) + + members_fitness = jnp.where(members, fitness, -jnp.inf) + sorted_member_indices = jnp.argsort(members_fitness)[::-1] + + elite_size = state.genome_elitism + survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32) + + select_pro = (p_idx < survive_size) / survive_size + fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro) + + # elite + fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa) + ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma) + elite = jnp.where(p_idx < elite_size, True, False) + return fa, ma, elite + + fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx) + + spawn_number_cum = jnp.cumsum(spawn_number) + + def aux_func(idx): + loc = jnp.argmax(idx < spawn_number_cum) + + # elite genomes are at the beginning of the species + idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx) + return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species] + + part1, part2, elite_mask = vmap(aux_func)(p_idx) + + is_part1_win = fitness[part1] >= fitness[part2] + winner = jnp.where(is_part1_win, part1, part2) + loser = jnp.where(is_part1_win, part2, part1) + + return winner, loser, elite_mask + + +def create_speciate(gene_type: Type[Gene]): + distance = create_distance(gene_type) + + def speciate(state): + pop_size, species_size = state.idx2species.shape[0], state.species_keys.shape[0] + + # prepare distance functions + o2p_distance_func = vmap(distance, in_axes=(None, None, 0)) # one to population + + # idx to specie key + idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species + + # the distance between genomes to its center genomes + o2c_distances = jnp.full((pop_size,), jnp.inf) + + # step 1: find new centers + def cond_func(carry): + i, i2s, cgs, o2c = carry + + return (i < species_size) & (~jnp.isnan(state.species_keys[i])) # current species is existing + + def body_func(carry): + i, i2s, cgs, o2c = carry + + distances = o2p_distance_func(state, Genome(cgs.nodes[i], cgs.conns[i]), state.pop_genomes) + + # find the closest one + closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) + # jax.debug.print("closest_idx: {}", closest_idx) + + i2s = i2s.at[closest_idx].set(state.species_keys[i]) + cn = cgs.nodes.at[i].set(state.pop_genomes.nodes[closest_idx]) + cc = cgs.conns.at[i].set(state.pop_genomes.conns[closest_idx]) + + # the genome with closest_idx will become the new center, thus its distance to center is 0. + o2c = o2c.at[closest_idx].set(0) + + return i + 1, i2s, Genome(cn, cc), o2c + + _, idx2species, center_genomes, o2c_distances = \ + jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances)) + + state = state.update( + idx2species=idx2species, + center_genomes=center_genomes, + ) + + # part 2: assign members to each species + def cond_func(carry): + i, i2s, cgs, sk, o2c, nsk = carry + + current_species_existed = ~jnp.isnan(sk[i]) + not_all_assigned = jnp.any(jnp.isnan(i2s)) + not_reach_species_upper_bounds = i < species_size + return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned) + + def body_func(carry): + i, i2s, cgs, sk, o2c, nsk = carry + + _, i2s, cgs, sk, o2c, nsk = jax.lax.cond( + jnp.isnan(sk[i]), # whether the current species is existing or not + create_new_species, # if not existing, create a new specie + update_exist_specie, # if existing, update the specie + (i, i2s, cgs, sk, o2c, nsk) + ) + + return i + 1, i2s, cgs, sk, o2c, nsk + + def create_new_species(carry): + i, i2s, cgs, sk, o2c, nsk = carry + + # pick the first one who has not been assigned to any species + idx = fetch_first(jnp.isnan(i2s)) + + # assign it to the new species + # [key, best score, last update generation, members_count] + sk = sk.at[i].set(nsk) + i2s = i2s.at[idx].set(nsk) + o2c = o2c.at[idx].set(0) + + # update center genomes + cn = cgs.nodes.at[i].set(state.pop_genomes.nodes[idx]) + cc = cgs.conns.at[i].set(state.pop_genomes.conns[idx]) + cgs = Genome(cn, cc) + + i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) + + # when a new species is created, it needs to be updated, thus do not change i + return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key + + def update_exist_specie(carry): + i, i2s, cgs, sk, o2c, nsk = carry + + i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) + + # turn to next species + return i + 1, i2s, cgs, sk, o2c, nsk + + def speciate_by_threshold(i, i2s, cgs, sk, o2c): + # distance between such center genome and ppo genomes + + center = Genome(cgs.nodes[i], cgs.conns[i]) + o2p_distance = o2p_distance_func(state, center, state.pop_genomes) + close_enough_mask = o2p_distance < state.compatibility_threshold + + # when a genome is not assigned or the distance between its current center is bigger than this center + cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c) + # jax.debug.print("{}", o2p_distance) + mask = close_enough_mask & cacheable_mask + + # update species info + i2s = jnp.where(mask, sk[i], i2s) + + # update distance between centers + o2c = jnp.where(mask, o2p_distance, o2c) + + return i2s, o2c + + # update idx2species + _, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop( + cond_func, + body_func, + (0, state.idx2species, state.center_genomes, state.species_keys, o2c_distances, state.next_species_key) + ) + + # if there are still some pop genomes not assigned to any species, add them to the last genome + # this condition can only happen when the number of species is reached species upper bounds + idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) + + # complete info of species which is created in this generation + new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness) + best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness) + last_improved = jnp.where(new_created_mask, state.generation, state.last_improved) + + # update members count + def count_members(idx): + key = species_keys[idx] + count = jnp.sum(idx2species == key) + count = jnp.where(jnp.isnan(key), jnp.nan, count) + return count + + member_count = vmap(count_members)(jnp.arange(species_size)) + + return state.update( + species_keys=species_keys, + best_fitness=best_fitness, + last_improved=last_improved, + members_count=member_count, + idx2species=idx2species, + center_genomes=center_genomes, + next_species_key=next_species_key + ) + + return speciate + + +def argmin_with_mask(arr, mask): + masked_arr = jnp.where(mask, arr, jnp.inf) + min_idx = jnp.argmin(masked_arr) + return min_idx diff --git a/config/__init__.py b/config/__init__.py index dfb91b6..473966f 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -1 +1,2 @@ -from .config import Configer +from .config import * + diff --git a/config/config.py b/config/config.py index 3ee9c7d..bc351f9 100644 --- a/config/config.py +++ b/config/config.py @@ -1,70 +1,103 @@ -import os -import warnings -import configparser - -import numpy as np +from dataclasses import dataclass +from typing import Union -class Configer: +@dataclass(frozen=True) +class BasicConfig: + seed: int = 42 + fitness_target: float = 1 + generation_limit: int = 1000 + num_inputs: int = 2 + num_outputs: int = 1 + pop_size: int = 100 - @classmethod - def __load_default_config(cls): - par_dir = os.path.dirname(os.path.abspath(__file__)) - default_config_path = os.path.join(par_dir, "default_config.ini") - return cls.__load_config(default_config_path) + def __post_init__(self): + assert self.num_inputs > 0, "the inputs number of the problem must be greater than 0" + assert self.num_outputs > 0, "the outputs number of the problem must be greater than 0" + assert self.pop_size > 0, "the population size must be greater than 0" - @classmethod - def __load_config(cls, config_path): - c = configparser.ConfigParser() - c.read(config_path) - config = {} - for section in c.sections(): - for key, value in c.items(section): - config[key] = eval(value) +@dataclass(frozen=True) +class NeatConfig: + network_type: str = "feedforward" + activate_times: Union[int, None] = None # None means the network is feedforward + maximum_nodes: int = 100 + maximum_conns: int = 50 + maximum_species: int = 10 - return config + # genome config + compatibility_disjoint: float = 1 + compatibility_weight: float = 0.5 + conn_add: float = 0.4 + conn_delete: float = 0.4 + node_add: float = 0.2 + node_delete: float = 0.2 - @classmethod - def __check_redundant_config(cls, default_config, config): - for key in config: - if key not in default_config: - warnings.warn(f"Redundant config: {key} in config!") + # species config + compatibility_threshold: float = 3.0 + species_elitism: int = 2 + max_stagnation: int = 15 + genome_elitism: int = 2 + survival_threshold: float = 0.2 + min_species_size: int = 1 + spawn_number_change_rate: float = 0.5 - @classmethod - def __complete_config(cls, default_config, config): - for key in default_config: - if key not in config: - config[key] = default_config[key] - - @classmethod - def load_config(cls, config_path=None): - default_config = cls.__load_default_config() - if config_path is None: - config = {} - elif not os.path.exists(config_path): - warnings.warn(f"config file {config_path} not exist!") - config = {} + def __post_init__(self): + assert self.network_type in ["feedforward", "recurrent"], "the network type must be feedforward or recurrent" + if self.network_type == "feedforward": + assert self.activate_times is None, "the activate times of feedforward network must be None" else: - config = cls.__load_config(config_path) + assert isinstance(self.activate_times, int), "the activate times of recurrent network must be int" + assert self.activate_times > 0, "the activate times of recurrent network must be greater than 0" - cls.__check_redundant_config(default_config, config) - cls.__complete_config(default_config, config) + assert self.maximum_nodes > 0, "the maximum nodes must be greater than 0" + assert self.maximum_conns > 0, "the maximum connections must be greater than 0" + assert self.maximum_species > 0, "the maximum species must be greater than 0" - cls.refactor_activation(config) - cls.refactor_aggregation(config) + assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0" + assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0" + assert self.conn_add > 0, "the connection add probability must be greater than 0" + assert self.conn_delete > 0, "the connection delete probability must be greater than 0" + assert self.node_add > 0, "the node add probability must be greater than 0" + assert self.node_delete > 0, "the node delete probability must be greater than 0" - config['input_idx'] = np.arange(config['num_inputs']) - config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs']) + assert self.compatibility_threshold > 0, "the compatibility threshold must be greater than 0" + assert self.species_elitism > 0, "the species elitism must be greater than 0" + assert self.max_stagnation > 0, "the max stagnation must be greater than 0" + assert self.genome_elitism > 0, "the genome elitism must be greater than 0" + assert self.survival_threshold > 0, "the survival threshold must be greater than 0" + assert self.min_species_size > 0, "the min species size must be greater than 0" + assert self.spawn_number_change_rate > 0, "the spawn number change rate must be greater than 0" - return config - @classmethod - def refactor_activation(cls, config): - config['activation_default'] = 0 - config['activation_options'] = np.arange(len(config['activation_option_names'])) +@dataclass(frozen=True) +class HyperNeatConfig: + below_threshold: float = 0.2 + max_weight: float = 3 + activation: str = "sigmoid" + aggregation: str = "sum" + activate_times: int = 5 - @classmethod - def refactor_aggregation(cls, config): - config['aggregation_default'] = 0 - config['aggregation_options'] = np.arange(len(config['aggregation_option_names'])) + def __post_init__(self): + assert self.below_threshold > 0, "the below threshold must be greater than 0" + assert self.max_weight > 0, "the max weight must be greater than 0" + assert self.activate_times > 0, "the activate times must be greater than 0" + + +@dataclass(frozen=True) +class GeneConfig: + pass + + +@dataclass(frozen=True) +class SubstrateConfig: + pass + + +@dataclass(frozen=True) +class Config: + basic: BasicConfig = BasicConfig() + neat: NeatConfig = NeatConfig() + hyper_neat: HyperNeatConfig = HyperNeatConfig() + gene: GeneConfig = GeneConfig() + substrate: SubstrateConfig = SubstrateConfig() diff --git a/config/default_config.ini b/config/default_config.ini index e8be118..921776d 100644 --- a/config/default_config.ini +++ b/config/default_config.ini @@ -1,8 +1,6 @@ [basic] random_seed = 0 generation_limit = 1000 - -[problem] fitness_threshold = 3.9999 num_inputs = 2 num_outputs = 1 @@ -14,6 +12,13 @@ maximum_nodes = 50 maximum_conns = 50 maximum_species = 10 +compatibility_disjoint = 1.0 +compatibility_weight = 0.5 +conn_add_prob = 0.4 +conn_delete_prob = 0 +node_add_prob = 0.2 +node_delete_prob = 0 + [hyperneat] below_threshold = 0.2 max_weight = 3 @@ -26,17 +31,6 @@ input_coors = [[-1, 1], [0, 1], [1, 1]] hidden_coors = [[-1, 0], [0, 0], [1, 0]] output_coors = [[0, -1]] -[population] -pop_size = 10 - -[genome] -compatibility_disjoint = 1.0 -compatibility_weight = 0.5 -conn_add_prob = 0.4 -conn_delete_prob = 0 -node_add_prob = 0.2 -node_delete_prob = 0 - [species] compatibility_threshold = 3.0 species_elitism = 2 diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..1bf1c8c --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,5 @@ +from .algorithm import Algorithm +from .state import State +from .genome import Genome +from .gene import Gene + diff --git a/core/algorithm.py b/core/algorithm.py new file mode 100644 index 0000000..51d2fec --- /dev/null +++ b/core/algorithm.py @@ -0,0 +1,28 @@ +from jax import Array +from .state import State +from .genome import Genome + +EMPTY = lambda *args: args + + +class Algorithm: + + def setup(self, randkey, state: State = State()): + """initialize the state of the algorithm""" + pass + + def ask(self, state: State): + """require the population to be evaluated""" + pass + + def tell(self, state: State, fitness): + """update the state of the algorithm""" + pass + + def forward(self, inputs: Array, transformed: Array): + """the forward function of a single forward transformation""" + pass + + def forward_transform(self, state: State, genome: Genome): + """create the forward transformation of a genome""" + pass diff --git a/core/gene.py b/core/gene.py new file mode 100644 index 0000000..03a9bfe --- /dev/null +++ b/core/gene.py @@ -0,0 +1,46 @@ +from jax import Array, numpy as jnp + +from config import GeneConfig +from .state import State +from .genome import Genome + + +class Gene: + node_attrs = [] + conn_attrs = [] + + @staticmethod + def setup(config: GeneConfig, state: State): + return state + + @staticmethod + def new_node_attrs(state: State): + return jnp.zeros(0) + + @staticmethod + def new_conn_attrs(state: State): + return jnp.zeros(0) + + @staticmethod + def mutate_node(state: State, attrs: Array, randkey: Array): + return attrs + + @staticmethod + def mutate_conn(state: State, attrs: Array, randkey: Array): + return attrs + + @staticmethod + def distance_node(state: State, node1: Array, node2: Array): + return node1 + + @staticmethod + def distance_conn(state: State, conn1: Array, conn2: Array): + return conn1 + + @staticmethod + def forward_transform(state: State, genome: Genome): + return jnp.zeros(0) # transformed + @staticmethod + def create_forward(state: State, config: GeneConfig): + return lambda *args: args # forward function + diff --git a/core/genome.py b/core/genome.py new file mode 100644 index 0000000..de5853b --- /dev/null +++ b/core/genome.py @@ -0,0 +1,77 @@ +from jax.tree_util import register_pytree_node_class +from jax import numpy as jnp + +from utils.tools import fetch_first + + +@register_pytree_node_class +class Genome: + + def __init__(self, nodes, conns): + self.nodes = nodes + self.conns = conns + + def update(self, nodes, conns): + return self.__class__(nodes, conns) + + def update_nodes(self, nodes): + return self.update(nodes, self.conns) + + def update_conns(self, conns): + return self.update(self.nodes, conns) + + def count(self): + """Count how many nodes and connections are in the genome.""" + nodes_cnt = jnp.sum(~jnp.isnan(self.nodes[:, 0])) + conns_cnt = jnp.sum(~jnp.isnan(self.conns[:, 0])) + return nodes_cnt, conns_cnt + + def add_node(self, new_key: int, attrs): + """ + Add a new node to the genome. + The new node will place at the first NaN row. + """ + exist_keys = self.nodes[:, 0] + pos = fetch_first(jnp.isnan(exist_keys)) + new_nodes = self.nodes.at[pos, 0].set(new_key) + new_nodes = new_nodes.at[pos, 1:].set(attrs) + return self.update_nodes(new_nodes) + + def delete_node_by_pos(self, pos): + """ + Delete a node from the genome. + Delete the node by its pos in nodes. + """ + nodes = self.nodes.at[pos].set(jnp.nan) + return self.update_nodes(nodes) + + def add_conn(self, i_key, o_key, enable: bool, attrs): + """ + Add a new connection to the genome. + The new connection will place at the first NaN row. + """ + con_keys = self.conns[:, 0] + pos = fetch_first(jnp.isnan(con_keys)) + new_conns = self.conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable])) + new_conns = new_conns.at[pos, 3:].set(attrs) + return self.update_conns(new_conns) + + def delete_conn_by_pos(self, pos): + """ + Delete a connection from the genome. + Delete the connection by its idx. + """ + conns = self.conns.at[pos].set(jnp.nan) + return self.update_conns(conns) + + def tree_flatten(self): + children = self.nodes, self.conns + aux_data = None + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children) + + def __repr__(self): + return f"Genome(nodes={self.nodes}, conns={self.conns})" diff --git a/algorithm/state.py b/core/state.py similarity index 93% rename from algorithm/state.py rename to core/state.py index b24ff62..9e3932e 100644 --- a/algorithm/state.py +++ b/core/state.py @@ -26,4 +26,4 @@ class State: @classmethod def tree_unflatten(cls, aux_data, children): - return cls(**dict(zip(aux_data, children))) + return cls(**dict(zip(aux_data, children))) \ No newline at end of file diff --git a/examples/a.py b/examples/a.py index 6ffb01d..5932ef6 100644 --- a/examples/a.py +++ b/examples/a.py @@ -1,11 +1,28 @@ -import numpy as np -import jax.numpy as jnp +import jax +from jax import numpy as jnp -a = jnp.zeros((5, 5)) -k1 = jnp.array([1, 2, 3]) -k2 = jnp.array([2, 3, 4]) -v = jnp.array([1, 1, 1]) +from config import Config +from core import Genome -a = a.at[k1, k2].set(v) +config = Config() +from dataclasses import asdict + +print(asdict(config)) + +pop_nodes = jnp.ones((Config.basic.pop_size, Config.neat.maximum_nodes, 3)) +pop_conns = jnp.ones((Config.basic.pop_size, Config.neat.maximum_conns, 5)) + +pop_genomes1 = jax.vmap(Genome)(pop_nodes, pop_conns) +pop_genomes2 = Genome(pop_nodes, pop_conns) + +print(pop_genomes) +print(pop_genomes[0]) + +@jax.vmap +def pop_cnts(genome): + return genome.count() + +cnts = pop_cnts(pop_genomes) + +print(cnts) -print(a) diff --git a/examples/b.py b/examples/b.py new file mode 100644 index 0000000..42cabf0 --- /dev/null +++ b/examples/b.py @@ -0,0 +1,19 @@ +from enum import Enum +from jax import jit + +class NetworkType(Enum): + ANN = 0 + SNN = 1 + LSTM = 2 + + + + +@jit +def func(d): + return d[0] + 1 + + +d = {0: 1, 1: NetworkType.ANN.value} + +print(func(d)) diff --git a/examples/rnn_forward_test.py b/examples/rnn_forward_test.py deleted file mode 100644 index 0d33f77..0000000 --- a/examples/rnn_forward_test.py +++ /dev/null @@ -1,44 +0,0 @@ -import jax -import jax.numpy as jnp -from jax.tree_util import register_pytree_node_class - - -@register_pytree_node_class -class Genome: - def __init__(self, nodes, conns): - self.nodes = nodes - self.conns = conns - - def update_nodes(self, nodes): - return Genome(nodes, self.conns) - - def update_conns(self, conns): - return Genome(self.nodes, conns) - - def tree_flatten(self): - children = self.nodes, self.conns - aux_data = None - return children, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(*children) - - def __repr__(self): - return f"Genome ({self.nodes}, \n\t{self.conns})" - - @jax.jit - def add_node(self, a: int): - nodes = self.nodes.at[0, :].set(a) - return self.update_nodes(nodes) - - -nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]]) -g = Genome(nodes, conns) -print(g) - -g = g.add_node(1) -print(g) - -g = jax.jit(g.add_node)(2) -print(g) diff --git a/examples/xor.ini b/examples/xor.ini deleted file mode 100644 index 9752677..0000000 --- a/examples/xor.ini +++ /dev/null @@ -1,12 +0,0 @@ -[basic] -activate_times = 5 -fitness_threshold = 4 - -[population] -pop_size = 1000 - -[neat] -network_type = "recurrent" -num_inputs = 4 -num_outputs = 1 - diff --git a/examples/xor.py b/examples/xor.py index 73c7228..a3cd409 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,10 +1,10 @@ import jax import numpy as np +from config import Config, BasicConfig from pipeline import Pipeline -from config import Configer -from algorithm import NEAT -from algorithm.neat import RecurrentGene +from algorithm.neat.gene import NormalGene, NormalGeneConfig +from algorithm.neat.neat import NEAT xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) @@ -21,13 +21,11 @@ def evaluate(forward_func): return fitnesses -def main(): - config = Configer.load_config("xor.ini") - algorithm = NEAT(config, RecurrentGene) - pipeline = Pipeline(config, algorithm) - best = pipeline.auto_run(evaluate) - print(best) - - if __name__ == '__main__': - main() + config = Config( + basic=BasicConfig(fitness_target=4), + gene=NormalGeneConfig() + ) + algorithm = NEAT(config, NormalGene) + pipeline = Pipeline(config, algorithm) + pipeline.auto_run(evaluate) diff --git a/examples/xor_hyperneat.py b/examples/xor_hyperneat.py deleted file mode 100644 index d0d70d6..0000000 --- a/examples/xor_hyperneat.py +++ /dev/null @@ -1,33 +0,0 @@ -import jax -import numpy as np - -from pipeline import Pipeline -from config import Configer -from algorithm import NEAT, HyperNEAT -from algorithm.neat import RecurrentGene -from algorithm.hyperneat import BaseSubstrate - -xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) -xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) - - -def evaluate(forward_func): - """ - :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) - :return: - """ - outs = forward_func(xor_inputs) - outs = jax.device_get(outs) - fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) - return fitnesses - - -def main(): - config = Configer.load_config("xor.ini") - algorithm = HyperNEAT(config, RecurrentGene, BaseSubstrate) - pipeline = Pipeline(config, algorithm) - pipeline.auto_run(evaluate) - - -if __name__ == '__main__': - main() diff --git a/pipeline.py b/pipeline.py index da57566..2456534 100644 --- a/pipeline.py +++ b/pipeline.py @@ -5,7 +5,8 @@ import jax from jax import vmap, jit import numpy as np -from algorithm import Algorithm +from config import Config +from core import Algorithm, Genome class Pipeline: @@ -13,11 +14,11 @@ class Pipeline: Neat algorithm pipeline. """ - def __init__(self, config, algorithm: Algorithm): + def __init__(self, config: Config, algorithm: Algorithm): self.config = config self.algorithm = algorithm - randkey = jax.random.PRNGKey(config['random_seed']) + randkey = jax.random.PRNGKey(config.basic.seed) self.state = algorithm.setup(randkey) self.best_genome = None @@ -29,18 +30,18 @@ class Pipeline: self.forward_func = jit(self.algorithm.forward) self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None))) self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0))) - self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0, 0))) + self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0))) self.tell_func = jit(self.algorithm.tell) def ask(self): - pop_transforms = self.forward_transform_func(self.state, self.state.pop_nodes, self.state.pop_conns) + pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes) return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms) def tell(self, fitness): self.state = self.tell_func(self.state, fitness) def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): - for _ in range(self.config['generation_limit']): + for _ in range(self.config.basic.generation_limit): forward_func = self.ask() fitnesses = fitness_func(forward_func) @@ -52,7 +53,7 @@ class Pipeline: assert callable(analysis), f"What the fuck you passed in? A {analysis}?" analysis(fitnesses) - if max(fitnesses) >= self.config['fitness_threshold']: + if max(fitnesses) >= self.config.basic.fitness_target: print("Fitness limit reached!") return self.best_genome @@ -70,11 +71,11 @@ class Pipeline: max_idx = np.argmax(fitnesses) if fitnesses[max_idx] > self.best_fitness: self.best_fitness = fitnesses[max_idx] - self.best_genome = (self.state.pop_nodes[max_idx], self.state.pop_conns[max_idx]) + self.best_genome = Genome(self.state.pop_genomes.nodes[max_idx], self.state.pop_genomes.conns[max_idx]) - member_count = jax.device_get(self.state.species_info[:, 3]) + member_count = jax.device_get(self.state.member_count) species_sizes = [int(i) for i in member_count if i > 0] print(f"Generation: {self.state.generation}", f"species: {len(species_sizes)}, {species_sizes}", - f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}") + f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}") \ No newline at end of file diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/test/unit/__init__.py b/test/unit/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/test/unit/test_cartesian_product.py b/test/unit/test_cartesian_product.py deleted file mode 100644 index 488eea0..0000000 --- a/test/unit/test_cartesian_product.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np - -from algorithm.hyperneat.substrate.tools import cartesian_product - - -def test01(): - keys1 = np.array([1, 2, 3]) - keys2 = np.array([4, 5, 6, 7]) - - coors1 = np.array([ - [1, 1, 1], - [2, 2, 2], - [3, 3, 3] - ]) - - coors2 = np.array([ - [4, 4, 4], - [5, 5, 5], - [6, 6, 6], - [7, 7, 7] - ]) - - target_coors = np.array([ - [1, 1, 1, 4, 4, 4], - [1, 1, 1, 5, 5, 5], - [1, 1, 1, 6, 6, 6], - [1, 1, 1, 7, 7, 7], - [2, 2, 2, 4, 4, 4], - [2, 2, 2, 5, 5, 5], - [2, 2, 2, 6, 6, 6], - [2, 2, 2, 7, 7, 7], - [3, 3, 3, 4, 4, 4], - [3, 3, 3, 5, 5, 5], - [3, 3, 3, 6, 6, 6], - [3, 3, 3, 7, 7, 7] - ]) - - target_keys = np.array([ - [1, 4], - [1, 5], - [1, 6], - [1, 7], - [2, 4], - [2, 5], - [2, 6], - [2, 7], - [3, 4], - [3, 5], - [3, 6], - [3, 7] - ]) - - new_coors, correspond_keys = cartesian_product(keys1, keys2, coors1, coors2) - - assert np.array_equal(new_coors, target_coors) - assert np.array_equal(correspond_keys, target_keys) diff --git a/test/unit/test_graphs.py b/test/unit/test_graphs.py deleted file mode 100644 index 5721d6e..0000000 --- a/test/unit/test_graphs.py +++ /dev/null @@ -1,32 +0,0 @@ -import jax.numpy as jnp - -from algorithm.neat.genome.graph import topological_sort, check_cycles -from algorithm.utils import I_INT - -nodes = jnp.array([ - [0], - [1], - [2], - [3], - [jnp.nan] -]) - -# {(0, 2), (1, 2), (1, 3), (2, 3)} -conns = jnp.array([ - [0, 0, 1, 0, 0], - [0, 0, 1, 1, 0], - [0, 0, 0, 1, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0] -]) - - -def test_topological_sort(): - assert jnp.all(topological_sort(nodes, conns) == jnp.array([0, 1, 2, 3, I_INT])) - - -def test_check_cycles(): - assert check_cycles(nodes, conns, 3, 2) - assert ~check_cycles(nodes, conns, 2, 3) - assert ~check_cycles(nodes, conns, 0, 3) - assert ~check_cycles(nodes, conns, 1, 0) diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py deleted file mode 100644 index 81cb6c5..0000000 --- a/test/unit/test_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -import jax.numpy as jnp -from algorithm.utils import unflatten_connections - - -def test_unflatten(): - nodes = jnp.array([ - [0, 0, 0, 0], - [1, 1, 1, 1], - [2, 2, 2, 2], - [3, 3, 3, 3], - [jnp.nan, jnp.nan, jnp.nan, jnp.nan] - ]) - - conns = jnp.array([ - [0, 1, True, 0.1, 0.11], - [0, 2, False, 0.2, 0.22], - [1, 2, True, 0.3, 0.33], - [1, 3, False, 0.4, 0.44], - ]) - - res = unflatten_connections(nodes, conns) - - assert jnp.all(res[:, 0, 1] == jnp.array([True, 0.1, 0.11])) - assert jnp.all(res[:, 0, 2] == jnp.array([False, 0.2, 0.22])) - assert jnp.all(res[:, 1, 2] == jnp.array([True, 0.3, 0.33])) - assert jnp.all(res[:, 1, 3] == jnp.array([False, 0.4, 0.44])) - - # Create a mask that excludes the indices we've already checked - mask = jnp.ones(res.shape, dtype=bool) - mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False) - - # Ensure all other places are jnp.nan - assert jnp.all(jnp.isnan(res[mask])) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..af946a4 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,4 @@ +from .activation import Activation +from .aggregation import Aggregation +from .tools import * +from .graph import * \ No newline at end of file diff --git a/algorithm/neat/gene/activation.py b/utils/activation.py similarity index 99% rename from algorithm/neat/gene/activation.py rename to utils/activation.py index ccbf655..6fdfaa1 100644 --- a/algorithm/neat/gene/activation.py +++ b/utils/activation.py @@ -88,6 +88,7 @@ class Activation: def cube_act(z): return z ** 3 + Activation.name2func = { 'sigmoid': Activation.sigmoid_act, 'tanh': Activation.tanh_act, @@ -107,4 +108,4 @@ Activation.name2func = { 'hat': Activation.hat_act, 'square': Activation.square_act, 'cube': Activation.cube_act, -} +} \ No newline at end of file diff --git a/algorithm/neat/gene/aggregation.py b/utils/aggregation.py similarity index 99% rename from algorithm/neat/gene/aggregation.py rename to utils/aggregation.py index be85ca4..6868b6b 100644 --- a/algorithm/neat/gene/aggregation.py +++ b/utils/aggregation.py @@ -60,4 +60,4 @@ Aggregation.name2func = { 'maxabs': Aggregation.maxabs_agg, 'median': Aggregation.median_agg, 'mean': Aggregation.mean_agg, -} \ No newline at end of file +} diff --git a/algorithm/neat/genome/graph.py b/utils/graph.py similarity index 95% rename from algorithm/neat/genome/graph.py rename to utils/graph.py index 72da89f..ef4eb19 100644 --- a/algorithm/neat/genome/graph.py +++ b/utils/graph.py @@ -6,13 +6,14 @@ Only used in feed-forward networks. import jax from jax import jit, Array, numpy as jnp -from algorithm.utils import fetch_first, I_INT +from .tools import fetch_first, I_INT @jit def topological_sort(nodes: Array, conns: Array) -> Array: """ a jit-able version of topological_sort! + conns: Array[N, N] """ in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0)) @@ -64,4 +65,4 @@ def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array: return visited_, new_visited_ _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited)) - return visited[from_idx] + return visited[from_idx] \ No newline at end of file diff --git a/algorithm/utils.py b/utils/tools.py similarity index 92% rename from algorithm/utils.py rename to utils/tools.py index a64735c..3f22490 100644 --- a/algorithm/utils.py +++ b/utils/tools.py @@ -9,12 +9,9 @@ EMPTY_NODE = np.full((1, 5), jnp.nan) EMPTY_CON = np.full((1, 4), jnp.nan) -@jit -def unflatten_connections(nodes: Array, conns: Array): +def unflatten_conns(nodes, conns): """ transform the (C, CL) connections to (CL-2, N, N) - :param nodes: (N, NL) - :param cons: (C, CL) :return: """ N = nodes.shape[0] @@ -69,4 +66,4 @@ def rank_elements(array, reverse=False): """ if not reverse: array = -array - return jnp.argsort(jnp.argsort(array)) + return jnp.argsort(jnp.argsort(array)) \ No newline at end of file