diff --git a/algorithm/__init__.py b/algorithm/__init__.py index e69de29..68e966c 100644 --- a/algorithm/__init__.py +++ b/algorithm/__init__.py @@ -0,0 +1,2 @@ +from .neat import * +from .hyper_neat import * \ No newline at end of file diff --git a/algorithm/hyper_neat/__init__.py b/algorithm/hyper_neat/__init__.py new file mode 100644 index 0000000..4227bc4 --- /dev/null +++ b/algorithm/hyper_neat/__init__.py @@ -0,0 +1,2 @@ +from .hyper_neat import HyperNEAT +from .substrate import NormalSubstrate, NormalSubstrateConfig \ No newline at end of file diff --git a/algorithm/hyper_neat/hyper_neat.py b/algorithm/hyper_neat/hyper_neat.py new file mode 100644 index 0000000..6008489 --- /dev/null +++ b/algorithm/hyper_neat/hyper_neat.py @@ -0,0 +1,122 @@ +from typing import Type + +import jax +from jax import numpy as jnp, Array, vmap +import numpy as np + +from config import Config, HyperNeatConfig +from core import Algorithm, Substrate, State, Genome +from utils import Activation, Aggregation +from algorithm.neat import NEAT +from .substrate import analysis_substrate + +class HyperNEAT(Algorithm): + + def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]): + self.config = config + self.neat = neat + self.substrate = substrate + + self.forward_func = None + + def setup(self, randkey, state=State()): + neat_key, randkey = jax.random.split(randkey) + state = state.update( + below_threshold=self.config.hyper_neat.below_threshold, + max_weight=self.config.hyper_neat.max_weight, + ) + state = self.neat.setup(neat_key, state) + state = self.substrate.setup(self.config.substrate, state) + + assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias + assert self.config.hyper_neat.outputs == state.output_coors.shape[0] + + h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state) + h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis] + h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32) + h_conns[:, 0:2] = correspond_keys + + state = state.update( + h_input_idx=h_input_idx, + h_output_idx=h_output_idx, + h_hidden_idx=h_hidden_idx, + h_nodes=h_nodes, + h_conns=h_conns, + query_coors=query_coors, + ) + + self.forward_func = HyperNEATGene.create_forward(self.config.hyper_neat, state) + + return state + def ask(self, state: State): + return state.pop_genomes + + def tell(self, state: State, fitness): + return self.neat.tell(state, fitness) + + def forward(self, inputs: Array, transformed: Array): + return self.forward_func(inputs, transformed) + + def forward_transform(self, state: State, genome: Genome): + t = self.neat.forward_transform(state, genome) + query_res = vmap(self.neat.forward, in_axes=(0, None))(state.query_coors, t) + + # mute the connection with weight below threshold + query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res) + + # make query res in range [-max_weight, max_weight] + query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res) + query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res) + query_res = query_res / (1 - state.below_threshold) * state.max_weight + + h_conns = state.h_conns.at[:, 2:].set(query_res) + return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns)) + + +class HyperNEATGene: + node_attrs = [] # no node attributes + conn_attrs = ['weight'] + + @staticmethod + def forward_transform(genome: Genome): + N = genome.nodes.shape[0] + u_conns = jnp.zeros((N, N), dtype=jnp.float32) + + in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32) + out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32) + weights = genome.conns[:, 2] + + u_conns = u_conns.at[in_keys, out_keys].set(weights) + return genome.nodes, u_conns + + @staticmethod + def create_forward(config: HyperNeatConfig, state: State): + + act = Activation.name2func[config.activation] + agg = Aggregation.name2func[config.aggregation] + + batch_act, batch_agg = jax.vmap(act), jax.vmap(agg) + + def forward(inputs, transform): + + inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0) + nodes, weights = transform + + input_idx = state.h_input_idx + output_idx = state.h_output_idx + + N = nodes.shape[0] + vals = jnp.full((N,), 0.) + + def body_func(i, values): + values = values.at[input_idx].set(inputs_with_bias) + nodes_ins = values * weights.T + values = batch_agg(nodes_ins) # z = agg(ins) + values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias + values = batch_act(values) # z = act(z) + return values + + vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals) + return vals[output_idx] + + return forward \ No newline at end of file diff --git a/algorithm/hyper_neat/substrate/__init__.py b/algorithm/hyper_neat/substrate/__init__.py new file mode 100644 index 0000000..a0378ba --- /dev/null +++ b/algorithm/hyper_neat/substrate/__init__.py @@ -0,0 +1,2 @@ +from .normal import NormalSubstrate, NormalSubstrateConfig +from .tools import analysis_substrate diff --git a/algorithm/hyper_neat/substrate/normal.py b/algorithm/hyper_neat/substrate/normal.py new file mode 100644 index 0000000..e16eedd --- /dev/null +++ b/algorithm/hyper_neat/substrate/normal.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from typing import Tuple + +import numpy as np + +from core import Substrate, State +from config import SubstrateConfig + + +@dataclass(frozen=True) +class NormalSubstrateConfig(SubstrateConfig): + input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1)) + hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0)) + output_coors: Tuple[Tuple[float]] = ((0, 1), ) + + +class NormalSubstrate(Substrate): + + @staticmethod + def setup(config: NormalSubstrateConfig, state: State = State()): + return state.update( + input_coors=np.asarray(config.input_coors, dtype=np.float32), + output_coors=np.asarray(config.output_coors, dtype=np.float32), + hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32), + ) diff --git a/algorithm/hyper_neat/substrate/tools.py b/algorithm/hyper_neat/substrate/tools.py new file mode 100644 index 0000000..21413be --- /dev/null +++ b/algorithm/hyper_neat/substrate/tools.py @@ -0,0 +1,50 @@ +from typing import Type + +import numpy as np + +def analysis_substrate(state): + cd = state.input_coors.shape[1] # coordinate dimensions + si = state.input_coors.shape[0] # input coordinate size + so = state.output_coors.shape[0] # output coordinate size + sh = state.hidden_coors.shape[0] # hidden coordinate size + + input_idx = np.arange(si) + output_idx = np.arange(si, si + so) + hidden_idx = np.arange(si + so, si + so + sh) + + total_conns = si * sh + sh * sh + sh * so + query_coors = np.zeros((total_conns, cd * 2)) + correspond_keys = np.zeros((total_conns, 2)) + + # connect input to hidden + aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors) + query_coors[0: si * sh, :] = aux_coors + correspond_keys[0: si * sh, :] = aux_keys + + # connect hidden to hidden + aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors) + query_coors[si * sh: si * sh + sh * sh, :] = aux_coors + correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys + + # connect hidden to output + aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors) + query_coors[si * sh + sh * sh:, :] = aux_coors + correspond_keys[si * sh + sh * sh:, :] = aux_keys + + return input_idx, output_idx, hidden_idx, query_coors, correspond_keys + + +def cartesian_product(keys1, keys2, coors1, coors2): + len1 = keys1.shape[0] + len2 = keys2.shape[0] + + repeated_coors1 = np.repeat(coors1, len2, axis=0) + repeated_keys1 = np.repeat(keys1, len2) + + tiled_coors2 = np.tile(coors2, (len1, 1)) + tiled_keys2 = np.tile(keys2, len1) + + new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1) + correspond_keys = np.column_stack((repeated_keys1, tiled_keys2)) + + return new_coors, correspond_keys \ No newline at end of file diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py index e69de29..d6bb53c 100644 --- a/algorithm/neat/__init__.py +++ b/algorithm/neat/__init__.py @@ -0,0 +1,2 @@ +from .neat import NEAT +from .gene import * diff --git a/algorithm/neat/gene/__init__.py b/algorithm/neat/gene/__init__.py index 02af6ce..9c3c16d 100644 --- a/algorithm/neat/gene/__init__.py +++ b/algorithm/neat/gene/__init__.py @@ -1 +1,2 @@ from .normal import NormalGene, NormalGeneConfig +from .recurrent import RecurrentGene, RecurrentGeneConfig diff --git a/algorithm/neat/gene/recurrent.py b/algorithm/neat/gene/recurrent.py new file mode 100644 index 0000000..1d3942b --- /dev/null +++ b/algorithm/neat/gene/recurrent.py @@ -0,0 +1,84 @@ +from dataclasses import dataclass + +import jax +from jax import Array, numpy as jnp, vmap + +from .normal import NormalGene, NormalGeneConfig +from core import State, Genome +from utils import Activation, Aggregation, unflatten_conns + + +@dataclass(frozen=True) +class RecurrentGeneConfig(NormalGeneConfig): + activate_times: int = 10 + + def __post_init__(self): + super().__post_init__() + assert self.activate_times > 0 + + +class RecurrentGene(NormalGene): + + @staticmethod + def forward_transform(state: State, genome: Genome): + u_conns = unflatten_conns(genome.nodes, genome.conns) + + # remove un-enable connections and remove enable attr + conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) + u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) + + return genome.nodes, u_conns + + @staticmethod + def create_forward(state: State, config: RecurrentGeneConfig): + activation_funcs = [Activation.name2func[name] for name in config.activation_options] + aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options] + + def act(idx, z): + """ + calculate activation function for each node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + # change idx from float to int + res = jax.lax.switch(idx, activation_funcs, z) + return res + + def agg(idx, z): + """ + calculate activation function for inputs of node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + + def all_nan(): + return 0. + + def not_all_nan(): + return jax.lax.switch(idx, aggregation_funcs, z) + + return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) + + batch_act, batch_agg = vmap(act), vmap(agg) + + def forward(inputs, transform) -> Array: + nodes, cons = transform + + input_idx = state.input_idx + output_idx = state.output_idx + + N = nodes.shape[0] + vals = jnp.full((N,), 0.) + + weights = cons[0, :] + + def body_func(i, values): + values = values.at[input_idx].set(inputs) + nodes_ins = values * weights.T + values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins) + values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias + values = batch_act(nodes[:, 3], values) # z = act(z) + return values + + vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals) + return vals[output_idx] + + return forward diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py index 4e5ef1f..e928287 100644 --- a/algorithm/neat/neat.py +++ b/algorithm/neat/neat.py @@ -7,7 +7,7 @@ import numpy as np from config import Config from core import Algorithm, State, Gene, Genome from .ga import crossover, create_mutate -from .species import update_species, create_speciate +from .species import SpeciesInfo, update_species, create_speciate class NEAT(Algorithm): @@ -22,9 +22,9 @@ class NEAT(Algorithm): def setup(self, randkey, state: State = State()): """initialize the state of the algorithm""" - input_idx = np.arange(self.config.basic.num_inputs) - output_idx = np.arange(self.config.basic.num_inputs, - self.config.basic.num_inputs + self.config.basic.num_outputs) + input_idx = np.arange(self.config.neat.inputs) + output_idx = np.arange(self.config.neat.inputs, + self.config.neat.inputs + self.config.neat.outputs) state = state.update( P=self.config.basic.pop_size, @@ -49,22 +49,13 @@ class NEAT(Algorithm): state = self.gene_type.setup(self.config.gene, state) pop_genomes = self._initialize_genomes(state) - species_keys = np.full((state.S,), np.nan, dtype=np.float32) - best_fitness = np.full((state.S,), np.nan, dtype=np.float32) - last_improved = np.full((state.S,), np.nan, dtype=np.float32) - member_count = np.full((state.S,), np.nan, dtype=np.float32) + species_info = SpeciesInfo.initialize(state) idx2species = jnp.zeros(state.P, dtype=jnp.float32) - species_keys[0] = 0 - best_fitness[0] = -np.inf - last_improved[0] = 0 - member_count[0] = state.P - center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32) center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32) - center_nodes = center_nodes.at[0, :, :].set(pop_genomes.nodes[0, :, :]) - center_conns = center_conns.at[0, :, :].set(pop_genomes.conns[0, :, :]) - center_genomes = vmap(Genome)(center_nodes, center_conns) + center_genomes = Genome(center_nodes, center_conns) + center_genomes = center_genomes.set(0, pop_genomes[0]) generation = 0 next_node_key = max(*state.input_idx, *state.output_idx) + 2 @@ -73,10 +64,7 @@ class NEAT(Algorithm): state = state.update( randkey=randkey, pop_genomes=pop_genomes, - species_keys=species_keys, - best_fitness=best_fitness, - last_improved=last_improved, - member_count=member_count, + species_info=species_info, idx2species=idx2species, center_genomes=center_genomes, @@ -135,7 +123,7 @@ class NEAT(Algorithm): pop_nodes = np.tile(o_nodes, (state.P, 1, 1)) pop_conns = np.tile(o_conns, (state.P, 1, 1)) - return vmap(Genome)(pop_nodes, pop_conns) + return Genome(pop_nodes, pop_conns) def _create_tell(self): mutate = create_mutate(self.config.neat, self.gene_type) diff --git a/algorithm/neat/species/__init__.py b/algorithm/neat/species/__init__.py index d5d058d..8717c2e 100644 --- a/algorithm/neat/species/__init__.py +++ b/algorithm/neat/species/__init__.py @@ -1 +1,2 @@ from .operations import update_species, create_speciate +from .species_info import SpeciesInfo diff --git a/algorithm/neat/species/operations.py b/algorithm/neat/species/operations.py index 7921016..ce3401a 100644 --- a/algorithm/neat/species/operations.py +++ b/algorithm/neat/species/operations.py @@ -6,6 +6,7 @@ from jax import numpy as jnp, vmap from core import Gene, Genome from utils import rank_elements, fetch_first from .distance import create_distance +from .species_info import SpeciesInfo def update_species(state, randkey, fitness): @@ -18,15 +19,9 @@ def update_species(state, randkey, fitness): # sort species_info by their fitness. (push nan to the end) sort_indices = jnp.argsort(species_fitness)[::-1] - center_nodes = state.center_genomes.nodes[sort_indices] - center_conns = state.center_genomes.conns[sort_indices] - state = state.update( - species_keys=state.species_keys[sort_indices], - best_fitness=state.best_fitness[sort_indices], - last_improved=state.last_improved[sort_indices], - member_count=state.member_count[sort_indices], - center_genomes=Genome(center_nodes, center_conns), + species_info=state.species_info[sort_indices], + center_genomes=state.center_genomes[sort_indices], ) # decide the number of members of each species by their fitness @@ -45,11 +40,11 @@ def update_species_fitness(state, fitness): """ def aux_func(idx): - s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf) + s_fitness = jnp.where(state.idx2species == state.species_info.species_keys[idx], fitness, -jnp.inf) f = jnp.max(s_fitness) return f - return vmap(aux_func)(jnp.arange(state.species_keys.shape[0])) + return vmap(aux_func)(jnp.arange(state.species_info.size())) def stagnation(state, species_fitness): @@ -61,7 +56,7 @@ def stagnation(state, species_fitness): def aux_func(idx): s_fitness = species_fitness[idx] - sk, bf, li = state.species_keys[idx], state.best_fitness[idx], state.last_improved[idx] + sk, bf, li, _ = state.species_info.get(idx) st = (s_fitness <= bf) & (state.generation - li > state.max_stagnation) li = jnp.where(s_fitness > bf, state.generation, li) bf = jnp.where(s_fitness > bf, s_fitness, bf) @@ -78,18 +73,19 @@ def stagnation(state, species_fitness): species_keys = jnp.where(spe_st, jnp.nan, species_keys) best_fitness = jnp.where(spe_st, jnp.nan, best_fitness) last_improved = jnp.where(spe_st, jnp.nan, last_improved) - member_count = jnp.where(spe_st, jnp.nan, state.member_count) + member_count = jnp.where(spe_st, jnp.nan, state.species_info.member_count) + species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness) + species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count) + + # TODO: Simplify the coded center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.nodes) center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.conns) state = state.update( - species_keys=species_keys, - best_fitness=best_fitness, - last_improved=last_improved, - member_count=member_count, - center_genomes=state.center_genomes.update(center_nodes, center_conns) + species_info=species_info, + center_genomes=Genome(center_nodes, center_conns) ) return state, species_fitness @@ -103,18 +99,20 @@ def cal_spawn_numbers(state): e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] """ - is_species_valid = ~jnp.isnan(state.species_keys) + species_keys = state.species_info.species_keys + + is_species_valid = ~jnp.isnan(species_keys) valid_species_num = jnp.sum(is_species_valid) denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 - rank_score = valid_species_num - jnp.arange(state.species_keys.shape[0]) # obtain [3, 2, 1] + rank_score = valid_species_num - jnp.arange(species_keys.shape[0]) # obtain [3, 2, 1] spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member # Avoid too much variation of numbers in a species - previous_size = state.member_count + previous_size = state.species_info.member_count spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate # jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number) spawn_number = spawn_number.astype(jnp.int32) @@ -127,14 +125,14 @@ def cal_spawn_numbers(state): def create_crossover_pair(state, randkey, spawn_number, fitness): - species_size = state.species_keys.shape[0] + species_size = state.species_info.size() pop_size = fitness.shape[0] s_idx = jnp.arange(species_size) p_idx = jnp.arange(pop_size) # def aux_func(key, idx): def aux_func(key, idx): - members = state.idx2species == state.species_keys[idx] + members = state.idx2species == state.species_info.species_keys[idx] members_num = jnp.sum(members) members_fitness = jnp.where(members, fitness, -jnp.inf) @@ -176,7 +174,7 @@ def create_speciate(gene_type: Type[Gene]): distance = create_distance(gene_type) def speciate(state): - pop_size, species_size = state.idx2species.shape[0], state.species_keys.shape[0] + pop_size, species_size = state.idx2species.shape[0], state.species_info.size() # prepare distance functions o2p_distance_func = vmap(distance, in_axes=(None, None, 0)) # one to population @@ -191,25 +189,23 @@ def create_speciate(gene_type: Type[Gene]): def cond_func(carry): i, i2s, cgs, o2c = carry - return (i < species_size) & (~jnp.isnan(state.species_keys[i])) # current species is existing + return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing def body_func(carry): i, i2s, cgs, o2c = carry - distances = o2p_distance_func(state, Genome(cgs.nodes[i], cgs.conns[i]), state.pop_genomes) + distances = o2p_distance_func(state, cgs[i], state.pop_genomes) # find the closest one closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) - # jax.debug.print("closest_idx: {}", closest_idx) - i2s = i2s.at[closest_idx].set(state.species_keys[i]) - cn = cgs.nodes.at[i].set(state.pop_genomes.nodes[closest_idx]) - cc = cgs.conns.at[i].set(state.pop_genomes.conns[closest_idx]) + i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i]) + cgs = cgs.set(i, state.pop_genomes[closest_idx]) # the genome with closest_idx will become the new center, thus its distance to center is 0. o2c = o2c.at[closest_idx].set(0) - return i + 1, i2s, Genome(cn, cc), o2c + return i + 1, i2s, cgs, o2c _, idx2species, center_genomes, o2c_distances = \ jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances)) @@ -247,15 +243,13 @@ def create_speciate(gene_type: Type[Gene]): idx = fetch_first(jnp.isnan(i2s)) # assign it to the new species - # [key, best score, last update generation, members_count] + # [key, best score, last update generation, member_count] sk = sk.at[i].set(nsk) i2s = i2s.at[idx].set(nsk) o2c = o2c.at[idx].set(0) # update center genomes - cn = cgs.nodes.at[i].set(state.pop_genomes.nodes[idx]) - cc = cgs.conns.at[i].set(state.pop_genomes.conns[idx]) - cgs = Genome(cn, cc) + cgs = cgs.set(i, state.pop_genomes[idx]) i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) @@ -273,8 +267,7 @@ def create_speciate(gene_type: Type[Gene]): def speciate_by_threshold(i, i2s, cgs, sk, o2c): # distance between such center genome and ppo genomes - center = Genome(cgs.nodes[i], cgs.conns[i]) - o2p_distance = o2p_distance_func(state, center, state.pop_genomes) + o2p_distance = o2p_distance_func(state, cgs[i], state.pop_genomes) close_enough_mask = o2p_distance < state.compatibility_threshold # when a genome is not assigned or the distance between its current center is bigger than this center @@ -294,32 +287,31 @@ def create_speciate(gene_type: Type[Gene]): _, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop( cond_func, body_func, - (0, state.idx2species, state.center_genomes, state.species_keys, o2c_distances, state.next_species_key) + (0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances, state.next_species_key) ) + # if there are still some pop genomes not assigned to any species, add them to the last genome # this condition can only happen when the number of species is reached species upper bounds idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) # complete info of species which is created in this generation - new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness) - best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness) - last_improved = jnp.where(new_created_mask, state.generation, state.last_improved) + new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness) + best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness) + last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved) # update members count def count_members(idx): key = species_keys[idx] - count = jnp.sum(idx2species == key) + count = jnp.sum(idx2species == key, dtype=jnp.float32) count = jnp.where(jnp.isnan(key), jnp.nan, count) + return count member_count = vmap(count_members)(jnp.arange(species_size)) return state.update( - species_keys=species_keys, - best_fitness=best_fitness, - last_improved=last_improved, - members_count=member_count, + species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count), idx2species=idx2species, center_genomes=center_genomes, next_species_key=next_species_key diff --git a/algorithm/neat/species/species_info.py b/algorithm/neat/species/species_info.py new file mode 100644 index 0000000..d2e4788 --- /dev/null +++ b/algorithm/neat/species/species_info.py @@ -0,0 +1,55 @@ +from jax.tree_util import register_pytree_node_class +import numpy as np +import jax.numpy as jnp + +@register_pytree_node_class +class SpeciesInfo: + + def __init__(self, species_keys, best_fitness, last_improved, member_count): + self.species_keys = species_keys + self.best_fitness = best_fitness + self.last_improved = last_improved + self.member_count = member_count + + @classmethod + def initialize(cls, state): + species_keys = np.full((state.S,), np.nan, dtype=np.float32) + best_fitness = np.full((state.S,), np.nan, dtype=np.float32) + last_improved = np.full((state.S,), np.nan, dtype=np.float32) + member_count = np.full((state.S,), np.nan, dtype=np.float32) + + species_keys[0] = 0 + best_fitness[0] = -np.inf + last_improved[0] = 0 + member_count[0] = state.P + + return cls(species_keys, best_fitness, last_improved, member_count) + + def __getitem__(self, i): + return SpeciesInfo(self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i]) + + def get(self, i): + return self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i] + + def set(self, idx, value): + species_keys = self.species_keys.at[idx].set(value[0]) + best_fitness = self.best_fitness.at[idx].set(value[1]) + last_improved = self.last_improved.at[idx].set(value[2]) + member_count = self.member_count.at[idx].set(value[3]) + return SpeciesInfo(species_keys, best_fitness, last_improved, member_count) + + def remove(self, idx): + return self.set(idx, jnp.array([jnp.nan] * 4)) + + def size(self): + return self.species_keys.shape[0] + + + def tree_flatten(self): + children = self.species_keys, self.best_fitness, self.last_improved, self.member_count + aux_data = None + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children) diff --git a/config/config.py b/config/config.py index bc351f9..ba54ff7 100644 --- a/config/config.py +++ b/config/config.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Union @dataclass(frozen=True) @@ -7,34 +6,31 @@ class BasicConfig: seed: int = 42 fitness_target: float = 1 generation_limit: int = 1000 - num_inputs: int = 2 - num_outputs: int = 1 pop_size: int = 100 def __post_init__(self): - assert self.num_inputs > 0, "the inputs number of the problem must be greater than 0" - assert self.num_outputs > 0, "the outputs number of the problem must be greater than 0" assert self.pop_size > 0, "the population size must be greater than 0" @dataclass(frozen=True) class NeatConfig: network_type: str = "feedforward" - activate_times: Union[int, None] = None # None means the network is feedforward - maximum_nodes: int = 100 - maximum_conns: int = 50 + inputs: int = 2 + outputs: int = 1 + maximum_nodes: int = 50 + maximum_conns: int = 100 maximum_species: int = 10 # genome config compatibility_disjoint: float = 1 compatibility_weight: float = 0.5 conn_add: float = 0.4 - conn_delete: float = 0.4 + conn_delete: float = 0 node_add: float = 0.2 - node_delete: float = 0.2 + node_delete: float = 0 # species config - compatibility_threshold: float = 3.0 + compatibility_threshold: float = 3.5 species_elitism: int = 2 max_stagnation: int = 15 genome_elitism: int = 2 @@ -44,11 +40,9 @@ class NeatConfig: def __post_init__(self): assert self.network_type in ["feedforward", "recurrent"], "the network type must be feedforward or recurrent" - if self.network_type == "feedforward": - assert self.activate_times is None, "the activate times of feedforward network must be None" - else: - assert isinstance(self.activate_times, int), "the activate times of recurrent network must be int" - assert self.activate_times > 0, "the activate times of recurrent network must be greater than 0" + + assert self.inputs > 0, "the inputs number of neat must be greater than 0" + assert self.outputs > 0, "the outputs number of neat must be greater than 0" assert self.maximum_nodes > 0, "the maximum nodes must be greater than 0" assert self.maximum_conns > 0, "the maximum connections must be greater than 0" @@ -56,10 +50,10 @@ class NeatConfig: assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0" assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0" - assert self.conn_add > 0, "the connection add probability must be greater than 0" - assert self.conn_delete > 0, "the connection delete probability must be greater than 0" - assert self.node_add > 0, "the node add probability must be greater than 0" - assert self.node_delete > 0, "the node delete probability must be greater than 0" + assert self.conn_add >= 0, "the connection add probability must be greater than 0" + assert self.conn_delete >= 0, "the connection delete probability must be greater than 0" + assert self.node_add >= 0, "the node add probability must be greater than 0" + assert self.node_delete >= 0, "the node delete probability must be greater than 0" assert self.compatibility_threshold > 0, "the compatibility threshold must be greater than 0" assert self.species_elitism > 0, "the species elitism must be greater than 0" @@ -77,18 +71,21 @@ class HyperNeatConfig: activation: str = "sigmoid" aggregation: str = "sum" activate_times: int = 5 + inputs: int = 2 + outputs: int = 1 def __post_init__(self): assert self.below_threshold > 0, "the below threshold must be greater than 0" assert self.max_weight > 0, "the max weight must be greater than 0" assert self.activate_times > 0, "the activate times must be greater than 0" + assert self.inputs > 0, "the inputs number of hyper neat must be greater than 0" + assert self.outputs > 0, "the outputs number of hyper neat must be greater than 0" @dataclass(frozen=True) class GeneConfig: pass - @dataclass(frozen=True) class SubstrateConfig: pass diff --git a/core/__init__.py b/core/__init__.py index 1bf1c8c..8e16999 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -2,4 +2,4 @@ from .algorithm import Algorithm from .state import State from .genome import Genome from .gene import Gene - +from .substrate import Substrate diff --git a/core/gene.py b/core/gene.py index 03a9bfe..b1b2704 100644 --- a/core/gene.py +++ b/core/gene.py @@ -40,7 +40,7 @@ class Gene: @staticmethod def forward_transform(state: State, genome: Genome): return jnp.zeros(0) # transformed + @staticmethod def create_forward(state: State, config: GeneConfig): return lambda *args: args # forward function - diff --git a/core/genome.py b/core/genome.py index de5853b..75d3267 100644 --- a/core/genome.py +++ b/core/genome.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from jax.tree_util import register_pytree_node_class from jax import numpy as jnp @@ -11,6 +13,15 @@ class Genome: self.nodes = nodes self.conns = conns + def __repr__(self): + return f"Genome(nodes={self.nodes}, conns={self.conns})" + + def __getitem__(self, idx): + return self.__class__(self.nodes[idx], self.conns[idx]) + + def set(self, idx, value: Genome): + return self.__class__(self.nodes.at[idx].set(value.nodes), self.conns.at[idx].set(value.conns)) + def update(self, nodes, conns): return self.__class__(nodes, conns) @@ -73,5 +84,4 @@ class Genome: def tree_unflatten(cls, aux_data, children): return cls(*children) - def __repr__(self): - return f"Genome(nodes={self.nodes}, conns={self.conns})" + diff --git a/core/substrate.py b/core/substrate.py new file mode 100644 index 0000000..e9694d1 --- /dev/null +++ b/core/substrate.py @@ -0,0 +1,8 @@ +from config import SubstrateConfig + + +class Substrate: + + @staticmethod + def setup(state, config: SubstrateConfig): + return state diff --git a/examples/a.py b/examples/a.py index 5932ef6..94c7478 100644 --- a/examples/a.py +++ b/examples/a.py @@ -12,11 +12,10 @@ print(asdict(config)) pop_nodes = jnp.ones((Config.basic.pop_size, Config.neat.maximum_nodes, 3)) pop_conns = jnp.ones((Config.basic.pop_size, Config.neat.maximum_conns, 5)) -pop_genomes1 = jax.vmap(Genome)(pop_nodes, pop_conns) -pop_genomes2 = Genome(pop_nodes, pop_conns) +pop_genomes = Genome(pop_nodes, pop_conns) print(pop_genomes) -print(pop_genomes[0]) +print(pop_genomes[0: 20]) @jax.vmap def pop_cnts(genome): diff --git a/examples/b.py b/examples/b.py index 42cabf0..deb7808 100644 --- a/examples/b.py +++ b/examples/b.py @@ -15,5 +15,9 @@ def func(d): d = {0: 1, 1: NetworkType.ANN.value} +n = None + +print(n or d) +print(d) print(func(d)) diff --git a/examples/xor.py b/examples/xor.py index a3cd409..7b00f74 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -1,10 +1,9 @@ import jax import numpy as np -from config import Config, BasicConfig +from config import Config, BasicConfig, NeatConfig from pipeline import Pipeline -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from algorithm.neat.neat import NEAT +from algorithm import NEAT, NormalGene, NormalGeneConfig xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) @@ -23,8 +22,14 @@ def evaluate(forward_func): if __name__ == '__main__': config = Config( - basic=BasicConfig(fitness_target=4), - gene=NormalGeneConfig() + basic=BasicConfig( + fitness_target=3.99999, + pop_size=10000 + ), + neat=NeatConfig( + maximum_nodes=50, + maximum_conns=100, + ) ) algorithm = NEAT(config, NormalGene) pipeline = Pipeline(config, algorithm) diff --git a/examples/xor_hyperNEAT.py b/examples/xor_hyperNEAT.py new file mode 100644 index 0000000..ec642e9 --- /dev/null +++ b/examples/xor_hyperNEAT.py @@ -0,0 +1,49 @@ +import jax +import numpy as np + +from config import Config, BasicConfig, NeatConfig +from pipeline import Pipeline +from algorithm import NEAT, RecurrentGene, RecurrentGeneConfig +from algorithm import HyperNEAT, NormalSubstrate, NormalSubstrateConfig + + +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) +xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) + + +def evaluate(forward_func): + """ + :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) + :return: + """ + outs = forward_func(xor_inputs) + outs = jax.device_get(outs) + fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) + return fitnesses + + +if __name__ == '__main__': + config = Config( + basic=BasicConfig( + fitness_target=3.99999, + pop_size=1000 + ), + neat=NeatConfig( + network_type="recurrent", + maximum_nodes=50, + maximum_conns=100, + inputs=4, + outputs=1 + + ), + gene=RecurrentGeneConfig( + activation_default="tanh", + activation_options=("tanh", ), + ), + substrate=NormalSubstrateConfig(), + ) + neat = NEAT(config, RecurrentGene) + hyperNEAT = HyperNEAT(config, neat, NormalSubstrate) + + pipeline = Pipeline(config, hyperNEAT) + pipeline.auto_run(evaluate) diff --git a/examples/xor_recurrent.py b/examples/xor_recurrent.py new file mode 100644 index 0000000..bfe6e20 --- /dev/null +++ b/examples/xor_recurrent.py @@ -0,0 +1,39 @@ +import jax +import numpy as np + +from config import Config, BasicConfig, NeatConfig +from pipeline import Pipeline +from algorithm import NEAT, RecurrentGene, RecurrentGeneConfig + + +xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) +xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) + + +def evaluate(forward_func): + """ + :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size) + :return: + """ + outs = forward_func(xor_inputs) + outs = jax.device_get(outs) + fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) + return fitnesses + + +if __name__ == '__main__': + config = Config( + basic=BasicConfig( + fitness_target=3.99999, + pop_size=10000 + ), + neat=NeatConfig( + network_type="recurrent", + maximum_nodes=50, + maximum_conns=100 + ), + gene=RecurrentGeneConfig() + ) + algorithm = NEAT(config, RecurrentGene) + pipeline = Pipeline(config, algorithm) + pipeline.auto_run(evaluate) diff --git a/pipeline.py b/pipeline.py index 2456534..75bc24d 100644 --- a/pipeline.py +++ b/pipeline.py @@ -11,7 +11,7 @@ from core import Algorithm, Genome class Pipeline: """ - Neat algorithm pipeline. + Simple pipeline. """ def __init__(self, config: Config, algorithm: Algorithm): @@ -38,7 +38,9 @@ class Pipeline: return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms) def tell(self, fitness): - self.state = self.tell_func(self.state, fitness) + # self.state = self.tell_func(self.state, fitness) + new_state = self.tell_func(self.state, fitness) + self.state = new_state def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): for _ in range(self.config.basic.generation_limit): @@ -73,9 +75,9 @@ class Pipeline: self.best_fitness = fitnesses[max_idx] self.best_genome = Genome(self.state.pop_genomes.nodes[max_idx], self.state.pop_genomes.conns[max_idx]) - member_count = jax.device_get(self.state.member_count) + member_count = jax.device_get(self.state.species_info.member_count) species_sizes = [int(i) for i in member_count if i > 0] print(f"Generation: {self.state.generation}", f"species: {len(species_sizes)}, {species_sizes}", - f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}") \ No newline at end of file + f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") \ No newline at end of file