From 4a631f9464f429d7daffea3908aa56674c4eef61 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 11 Jul 2024 10:10:16 +0800 Subject: [PATCH] modify NEAT package; successfully run xor example --- Pipeline 20240711012327.pkl | 0 examples/func_fit/xor.py | 60 ++- network.svg | 120 ++--- tensorneat/algorithm/base.py | 18 +- tensorneat/algorithm/neat/neat.py | 164 +++++-- .../neat/{species/default.py => species.py} | 418 ++++++------------ tensorneat/algorithm/neat/species/__init__.py | 2 - tensorneat/algorithm/neat/species/base.py | 20 - tensorneat/genome/__init__.py | 2 + tensorneat/genome/gene/node/default.py | 16 +- .../genome/operations/distance/default.py | 2 +- tensorneat/genome/operations/mutation/base.py | 2 +- .../genome/operations/mutation/default.py | 42 +- tensorneat/pipeline.py | 56 +-- 14 files changed, 420 insertions(+), 502 deletions(-) create mode 100644 Pipeline 20240711012327.pkl rename tensorneat/algorithm/neat/{species/default.py => species.py} (52%) delete mode 100644 tensorneat/algorithm/neat/species/__init__.py delete mode 100644 tensorneat/algorithm/neat/species/base.py diff --git a/Pipeline 20240711012327.pkl b/Pipeline 20240711012327.pkl new file mode 100644 index 0000000..e69de29 diff --git a/examples/func_fit/xor.py b/examples/func_fit/xor.py index 1ab61b8..9894d63 100644 --- a/examples/func_fit/xor.py +++ b/examples/func_fit/xor.py @@ -1,43 +1,38 @@ -from pipeline import Pipeline -from algorithm.neat import * - -from problem.func_fit import XOR3d -from tensorneat.common import ACT_ALL, AGG_ALL, Act, Agg +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, DefaultNodeGene, DefaultMutation +from tensorneat.problem.func_fit import XOR3d +from tensorneat.common import Act, Agg if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( - species=DefaultSpecies( - genome=DenseInitialize( - num_inputs=3, - num_outputs=1, - max_nodes=50, - max_conns=100, - node_gene=DefaultNodeGene( - activation_default=Act.tanh, - # activation_options=(Act.tanh,), - activation_options=ACT_ALL, - aggregation_default=Agg.sum, - # aggregation_options=(Agg.sum,), - aggregation_options=AGG_ALL, - ), - output_transform=Act.standard_sigmoid, # the activation function for output node - mutation=DefaultMutation( - node_add=0.1, - conn_add=0.1, - node_delete=0, - conn_delete=0, - ), + pop_size=10000, + species_size=20, + compatibility_threshold=2, + survival_threshold=0.01, + genome=DefaultGenome( + num_inputs=3, + num_outputs=1, + init_hidden_layers=(), + node_gene=DefaultNodeGene( + activation_default=Act.tanh, + activation_options=Act.tanh, + aggregation_default=Agg.sum, + aggregation_options=Agg.sum, + ), + output_transform=Act.standard_sigmoid, # the activation function for output node + mutation=DefaultMutation( + node_add=0.1, + conn_add=0.1, + node_delete=0, + conn_delete=0, ), - pop_size=10000, - species_size=20, - compatibility_threshold=2, - survival_threshold=0.01, # magic ), ), problem=XOR3d(), - generation_limit=10000, - fitness_target=-1e-3, + generation_limit=500, + fitness_target=-1e-8, ) # initialize state @@ -47,4 +42,3 @@ if __name__ == "__main__": state, best = pipeline.auto_run(state) # show result pipeline.show(state, best) - pipeline.save(state=state) diff --git a/network.svg b/network.svg index d86c21f..b673fe3 100644 --- a/network.svg +++ b/network.svg @@ -6,7 +6,7 @@ - 2024-07-10T16:50:19.947855 + 2024-07-10T19:47:34.359228 image/svg+xml @@ -32,222 +32,222 @@ z +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p572566e0dc)" style="fill: #0000ff; stroke: #0000ff"/> - + diff --git a/tensorneat/algorithm/base.py b/tensorneat/algorithm/base.py index 557493f..b57dd17 100644 --- a/tensorneat/algorithm/base.py +++ b/tensorneat/algorithm/base.py @@ -14,13 +14,11 @@ class BaseAlgorithm(StatefulBaseClass): """transform the genome into a neural network""" raise NotImplementedError - def restore(self, state, transformed): - raise NotImplementedError - def forward(self, state, transformed, inputs): raise NotImplementedError - def update_by_batch(self, state, batch_input, transformed): + def show_details(self, state: State, fitness): + """Visualize the running details of the algorithm""" raise NotImplementedError @property @@ -30,15 +28,3 @@ class BaseAlgorithm(StatefulBaseClass): @property def num_outputs(self): raise NotImplementedError - - @property - def pop_size(self): - raise NotImplementedError - - def member_count(self, state: State): - # to analysis the species - raise NotImplementedError - - def generation(self, state: State): - # to analysis the algorithm - raise NotImplementedError diff --git a/tensorneat/algorithm/neat/neat.py b/tensorneat/algorithm/neat/neat.py index 32abd07..57826ca 100644 --- a/tensorneat/algorithm/neat/neat.py +++ b/tensorneat/algorithm/neat/neat.py @@ -1,40 +1,93 @@ -from tensorneat.common import State +import jax +from jax import vmap, numpy as jnp +import numpy as np + +from .species import SpeciesController from .. import BaseAlgorithm -from .species import * +from tensorneat.common import State +from tensorneat.genome import BaseGenome class NEAT(BaseAlgorithm): def __init__( self, - species: BaseSpecies, + genome: BaseGenome, + pop_size: int, + species_size: int = 10, + max_stagnation: int = 15, + species_elitism: int = 2, + spawn_number_change_rate: float = 0.5, + genome_elitism: int = 2, + survival_threshold: float = 0.2, + min_species_size: int = 1, + compatibility_threshold: float = 3.0, + species_fitness_func: callable = jnp.max, ): - self.species = species - self.genome = species.genome + self.genome = genome + self.pop_size = pop_size + self.species_controller = SpeciesController( + pop_size, + species_size, + max_stagnation, + species_elitism, + spawn_number_change_rate, + genome_elitism, + survival_threshold, + min_species_size, + compatibility_threshold, + species_fitness_func, + ) def setup(self, state=State()): - state = self.species.setup(state) + # setup state + state = self.genome.setup(state) + + k1, randkey = jax.random.split(state.randkey, 2) + + # initialize the population + initialize_keys = jax.random.split(k1, self.pop_size) + pop_nodes, pop_conns = vmap(self.genome.initialize, in_axes=(None, 0))( + state, initialize_keys + ) + + state = state.register( + pop_nodes=pop_nodes, + pop_conns=pop_conns, + generation=jnp.float32(0), + ) + + # initialize species state + state = self.species_controller.setup(state, pop_nodes[0], pop_conns[0]) + + return state.update(randkey=randkey) + + def ask(self, state): + return state.pop_nodes, state.pop_conns + + def tell(self, state, fitness): + state = state.update(generation=state.generation + 1) + + # tell fitness to species controller + state, winner, loser, elite_mask = self.species_controller.update_species( + state, + fitness, + ) + + # create next population + state = self._create_next_generation(state, winner, loser, elite_mask) + + # speciate the next population + state = self.species_controller.speciate(state, self.genome.execute_distance) + return state - def ask(self, state: State): - return self.species.ask(state) - - def tell(self, state: State, fitness): - return self.species.tell(state, fitness) - def transform(self, state, individual): - """transform the genome into a neural network""" nodes, conns = individual return self.genome.transform(state, nodes, conns) - def restore(self, state, transformed): - return self.genome.restore(state, transformed) - def forward(self, state, transformed, inputs): return self.genome.forward(state, transformed, inputs) - def update_by_batch(self, state, batch_input, transformed): - return self.genome.update_by_batch(state, batch_input, transformed) - @property def num_inputs(self): return self.genome.num_inputs @@ -43,13 +96,70 @@ class NEAT(BaseAlgorithm): def num_outputs(self): return self.genome.num_outputs - @property - def pop_size(self): - return self.species.pop_size + def _create_next_generation(self, state, winner, loser, elite_mask): - def member_count(self, state: State): - return state.member_count + # find next node key for mutation + all_nodes_keys = state.pop_nodes[:, :, 0] + max_node_key = jnp.max( + all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0 + ) + next_node_key = max_node_key + 1 + new_node_keys = jnp.arange(self.pop_size) + next_node_key - def generation(self, state: State): - # to analysis the algorithm - return state.generation + # prepare random keys + k1, k2, randkey = jax.random.split(state.randkey, 3) + crossover_randkeys = jax.random.split(k1, self.pop_size) + mutate_randkeys = jax.random.split(k2, self.pop_size) + + wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner] + lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser] + + # batch crossover + n_nodes, n_conns = vmap( + self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0) + )( + state, crossover_randkeys, wpn, wpc, lpn, lpc + ) # new_nodes, new_conns + + # batch mutation + m_n_nodes, m_n_conns = vmap( + self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0) + )( + state, mutate_randkeys, n_nodes, n_conns, new_node_keys + ) # mutated_new_nodes, mutated_new_conns + + # elitism don't mutate + pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes) + pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns) + + return state.update( + randkey=randkey, + pop_nodes=pop_nodes, + pop_conns=pop_conns, + ) + + def show_details(self, state, fitness): + member_count = jax.device_get(state.species.member_count) + species_sizes = [int(i) for i in member_count if i > 0] + + pop_nodes, pop_conns = jax.device_get([state.pop_nodes, state.pop_conns]) + nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,) + conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,) + + max_node_cnt, min_node_cnt, mean_node_cnt = ( + max(nodes_cnt), + min(nodes_cnt), + np.mean(nodes_cnt), + ) + + max_conn_cnt, min_conn_cnt, mean_conn_cnt = ( + max(conns_cnt), + min(conns_cnt), + np.mean(conns_cnt), + ) + + print( + f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n", + f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n", + f"\tspecies: {len(species_sizes)}, {species_sizes}\n", + ) diff --git a/tensorneat/algorithm/neat/species/default.py b/tensorneat/algorithm/neat/species.py similarity index 52% rename from tensorneat/algorithm/neat/species/default.py rename to tensorneat/algorithm/neat/species.py index 3c5e82a..0537b70 100644 --- a/tensorneat/algorithm/neat/species/default.py +++ b/tensorneat/algorithm/neat/species.py @@ -1,54 +1,35 @@ -import jax, jax.numpy as jnp +from typing import Callable + +import jax +from jax import vmap, numpy as jnp +import numpy as np -from .base import BaseSpecies from tensorneat.common import ( State, + StatefulBaseClass, rank_elements, argmin_with_mask, fetch_first, ) -from tensorneat.genome.utils import ( - extract_conn_attrs, - extract_node_attrs, -) -from tensorneat.genome import BaseGenome -""" -Core procedures of NEAT algorithm, contains the following steps: -1. Update the fitness of each species; -2. Decide which species will be stagnation; -3. Decide the number of members of each species in the next generation; -4. Choice the crossover pair for each species; -5. Divided the whole new population into different species; - -This class use tensor operation to imitate the behavior of NEAT algorithm which implemented in NEAT-python. -The code may be hard to understand. Fortunately, we don't need to overwrite it in most cases. -""" - - -class DefaultSpecies(BaseSpecies): +class SpeciesController(StatefulBaseClass): def __init__( self, - genome: BaseGenome, pop_size, species_size, - compatibility_disjoint: float = 1.0, - compatibility_weight: float = 0.4, - max_stagnation: int = 15, - species_elitism: int = 2, - spawn_number_change_rate: float = 0.5, - genome_elitism: int = 2, - survival_threshold: float = 0.2, - min_species_size: int = 1, - compatibility_threshold: float = 3.0, + max_stagnation, + species_elitism, + spawn_number_change_rate, + genome_elitism, + survival_threshold, + min_species_size, + compatibility_threshold, + species_fitness_func, ): - self.genome = genome self.pop_size = pop_size self.species_size = species_size - - self.compatibility_disjoint = compatibility_disjoint - self.compatibility_weight = compatibility_weight + self.species_arange = np.arange(self.species_size) self.max_stagnation = max_stagnation self.species_elitism = species_elitism self.spawn_number_change_rate = spawn_number_change_rate @@ -56,42 +37,33 @@ class DefaultSpecies(BaseSpecies): self.survival_threshold = survival_threshold self.min_species_size = min_species_size self.compatibility_threshold = compatibility_threshold + self.species_fitness_func = species_fitness_func - self.species_arange = jnp.arange(self.species_size) + def setup(self, state, first_nodes, first_conns): + # the unique index (primary key) for each species + species_keys = jnp.full((self.species_size,), jnp.nan) - def setup(self, state=State()): - state = self.genome.setup(state) - k1, randkey = jax.random.split(state.randkey, 2) + # the best fitness of each species + best_fitness = jnp.full((self.species_size,), jnp.nan) - # initialize the population - initialize_keys = jax.random.split(randkey, self.pop_size) - pop_nodes, pop_conns = jax.vmap(self.genome.initialize, in_axes=(None, 0))( - state, initialize_keys - ) + # the last 1 that the species improved + last_improved = jnp.full((self.species_size,), jnp.nan) - species_keys = jnp.full( - (self.species_size,), jnp.nan - ) # the unique index (primary key) for each species - best_fitness = jnp.full( - (self.species_size,), jnp.nan - ) # the best fitness of each species - last_improved = jnp.full( - (self.species_size,), jnp.nan - ) # the last 1 that the species improved - member_count = jnp.full( - (self.species_size,), jnp.nan - ) # the number of members of each species - idx2species = jnp.zeros(self.pop_size) # the species index of each individual + # the number of members of each species + member_count = jnp.full((self.species_size,), jnp.nan) + + # the species index of each individual + idx2species = jnp.zeros(self.pop_size) # nodes for each center genome of each species center_nodes = jnp.full( - (self.species_size, self.genome.max_nodes, self.genome.node_gene.length), + (self.species_size, *first_nodes.shape), jnp.nan, ) # connections for each center genome of each species center_conns = jnp.full( - (self.species_size, self.genome.max_conns, self.genome.conn_gene.length), + (self.species_size, *first_conns.shape), jnp.nan, ) @@ -99,16 +71,10 @@ class DefaultSpecies(BaseSpecies): best_fitness = best_fitness.at[0].set(-jnp.inf) last_improved = last_improved.at[0].set(0) member_count = member_count.at[0].set(self.pop_size) - center_nodes = center_nodes.at[0].set(pop_nodes[0]) - center_conns = center_conns.at[0].set(pop_conns[0]) + center_nodes = center_nodes.at[0].set(first_nodes) + center_conns = center_conns.at[0].set(first_conns) - pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns)) - - state = state.update(randkey=randkey) - - return state.register( - pop_nodes=pop_nodes, - pop_conns=pop_conns, + species_state = State( species_keys=species_keys, best_fitness=best_fitness, last_improved=last_improved, @@ -117,53 +83,50 @@ class DefaultSpecies(BaseSpecies): center_nodes=center_nodes, center_conns=center_conns, next_species_key=jnp.float32(1), # 0 is reserved for the first species - generation=jnp.float32(0), ) - def ask(self, state): - return state.pop_nodes, state.pop_conns - - def tell(self, state, fitness): - k1, k2, randkey = jax.random.split(state.randkey, 3) - - state = state.update(generation=state.generation + 1, randkey=randkey) - state, winner, loser, elite_mask = self.update_species(state, fitness) - state = self.create_next_generation(state, winner, loser, elite_mask) - state = self.speciate(state) - - return state + return state.register(species=species_state) def update_species(self, state, fitness): + species_state = state.species + # update the fitness of each species - state, species_fitness = self.update_species_fitness(state, fitness) + species_fitness = self._update_species_fitness(species_state, fitness) # stagnation species - state, species_fitness = self.stagnation(state, species_fitness) + species_state, species_fitness = self._stagnation( + species_state, species_fitness, state.generation + ) # sort species_info by their fitness. (also push nan to the end) - sort_indices = jnp.argsort(species_fitness)[::-1] + sort_indices = jnp.argsort(species_fitness)[::-1] # fitness from high to low - state = state.update( - species_keys=state.species_keys[sort_indices], - best_fitness=state.best_fitness[sort_indices], - last_improved=state.last_improved[sort_indices], - member_count=state.member_count[sort_indices], - center_nodes=state.center_nodes[sort_indices], - center_conns=state.center_conns[sort_indices], + species_state = species_state.update( + species_keys=species_state.species_keys[sort_indices], + best_fitness=species_state.best_fitness[sort_indices], + last_improved=species_state.last_improved[sort_indices], + member_count=species_state.member_count[sort_indices], + center_nodes=species_state.center_nodes[sort_indices], + center_conns=species_state.center_conns[sort_indices], ) # decide the number of members of each species by their fitness - state, spawn_number = self.cal_spawn_numbers(state) + spawn_number = self._cal_spawn_numbers(species_state) k1, k2 = jax.random.split(state.randkey) # crossover info - state, winner, loser, elite_mask = self.create_crossover_pair( - state, spawn_number, fitness + winner, loser, elite_mask = self._create_crossover_pair( + species_state, k1, spawn_number, fitness ) - return state.update(randkey=k2), winner, loser, elite_mask + return ( + state.update(randkey=k2, species=species_state), + winner, + loser, + elite_mask, + ) - def update_species_fitness(self, state, fitness): + def _update_species_fitness(self, species_state, fitness): """ obtain the fitness of the species by the fitness of each individual. use max criterion. @@ -171,14 +134,16 @@ class DefaultSpecies(BaseSpecies): def aux_func(idx): s_fitness = jnp.where( - state.idx2species == state.species_keys[idx], fitness, -jnp.inf + species_state.idx2species == species_state.species_keys[idx], + fitness, + -jnp.inf, ) - val = jnp.max(s_fitness) + val = self.species_fitness_func(s_fitness) return val - return state, jax.vmap(aux_func)(self.species_arange) + return vmap(aux_func)(self.species_arange) - def stagnation(self, state, species_fitness): + def _stagnation(self, species_state, species_fitness, generation): """ stagnation species. those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. @@ -187,28 +152,36 @@ class DefaultSpecies(BaseSpecies): def check_stagnation(idx): # determine whether the species stagnation - st = ( - species_fitness[idx] <= state.best_fitness[idx] - ) & ( # not better than the best fitness of the species - state.generation - state.last_improved[idx] > self.max_stagnation - ) # for a long time + + # not better than the best fitness of the species + # for a long time + st = (species_fitness[idx] <= species_state.best_fitness[idx]) & ( + generation - species_state.last_improved[idx] > self.max_stagnation + ) # update last_improved and best_fitness + # whether better than the best fitness of the species li, bf = jax.lax.cond( - species_fitness[idx] > state.best_fitness[idx], - lambda: (state.generation, species_fitness[idx]), # update + species_fitness[idx] > species_state.best_fitness[idx], + lambda: (generation, species_fitness[idx]), # update lambda: ( - state.last_improved[idx], - state.best_fitness[idx], + species_state.last_improved[idx], + species_state.best_fitness[idx], ), # not update ) return st, bf, li - spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)( + spe_st, best_fitness, last_improved = vmap(check_stagnation)( self.species_arange ) + # update species state + species_state = species_state.update( + best_fitness=best_fitness, + last_improved=last_improved, + ) + # elite species will not be stagnation species_rank = rank_elements(species_fitness) spe_st = jnp.where( @@ -224,18 +197,18 @@ class DefaultSpecies(BaseSpecies): jnp.nan, # best_fitness jnp.nan, # last_improved jnp.nan, # member_count + jnp.full_like(species_state.center_nodes[idx], jnp.nan), + jnp.full_like(species_state.center_conns[idx], jnp.nan), -jnp.inf, # species_fitness - jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes - jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns ), # stagnation species lambda: ( - state.species_keys[idx], - best_fitness[idx], - last_improved[idx], - state.member_count[idx], + species_state.species_keys[idx], + species_state.best_fitness[idx], + species_state.last_improved[idx], + species_state.member_count[idx], + species_state.center_nodes[idx], + species_state.center_conns[idx], species_fitness[idx], - state.center_nodes[idx], - state.center_conns[idx], ), # not stagnation species ) @@ -244,13 +217,13 @@ class DefaultSpecies(BaseSpecies): best_fitness, last_improved, member_count, - species_fitness, center_nodes, center_conns, - ) = jax.vmap(update_func)(self.species_arange) + species_fitness, + ) = vmap(update_func)(self.species_arange) return ( - state.update( + species_state.update( species_keys=species_keys, best_fitness=best_fitness, last_improved=last_improved, @@ -261,7 +234,7 @@ class DefaultSpecies(BaseSpecies): species_fitness, ) - def cal_spawn_numbers(self, state): + def _cal_spawn_numbers(self, species_state): """ decide the number of members of each species by their fitness rank. the species with higher fitness will have more members @@ -269,7 +242,7 @@ class DefaultSpecies(BaseSpecies): e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2] """ - species_keys = state.species_keys + species_keys = species_state.species_keys is_species_valid = ~jnp.isnan(species_keys) valid_species_num = jnp.sum(is_species_valid) @@ -288,7 +261,7 @@ class DefaultSpecies(BaseSpecies): ) # calculate member # Avoid too much variation of numbers for a species - previous_size = state.member_count + previous_size = species_state.member_count spawn_number = ( previous_size + (target_spawn_number - previous_size) * self.spawn_number_change_rate @@ -301,14 +274,17 @@ class DefaultSpecies(BaseSpecies): # add error to the first species to control the sum of spawn_number spawn_number = spawn_number.at[0].add(error) - return state, spawn_number + return spawn_number - def create_crossover_pair(self, state, spawn_number, fitness): + def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness): s_idx = self.species_arange p_idx = jnp.arange(self.pop_size) def aux_func(key, idx): - members = state.idx2species == state.species_keys[idx] + # choose parents from the in the same species + # key -> randkey, idx -> the idx of current species + + members = species_state.idx2species == species_state.species_keys[idx] members_num = jnp.sum(members) members_fitness = jnp.where(members, fitness, -jnp.inf) @@ -333,11 +309,16 @@ class DefaultSpecies(BaseSpecies): elite = jnp.where(p_idx < self.genome_elitism, True, False) return fa, ma, elite - randkey_, randkey = jax.random.split(state.randkey) - fas, mas, elites = jax.vmap(aux_func)( - jax.random.split(randkey_, self.species_size), s_idx + # choose parents to crossover in each species + # fas, mas, elites: (self.species_size, self.pop_size) + # fas -> father indices, mas -> mother indices, elites -> whether elite or not + fas, mas, elites = vmap(aux_func)( + jax.random.split(randkey, self.species_size), s_idx ) + # merge choosen parents from each species into one array + # winner, loser, elite_mask: (self.pop_size) + # winner -> winner indices, loser -> loser indices, elite_mask -> whether elite or not spawn_number_cum = jnp.cumsum(spawn_number) def aux_func(idx): @@ -351,18 +332,18 @@ class DefaultSpecies(BaseSpecies): elites[loc, idx_in_species], ) - part1, part2, elite_mask = jax.vmap(aux_func)(p_idx) + part1, part2, elite_mask = vmap(aux_func)(p_idx) is_part1_win = fitness[part1] >= fitness[part2] winner = jnp.where(is_part1_win, part1, part2) loser = jnp.where(is_part1_win, part2, part1) - return state.update(randkey=randkey), winner, loser, elite_mask + return winner, loser, elite_mask - def speciate(self, state): + def speciate(self, state, genome_distance_func: Callable): # prepare distance functions - o2p_distance_func = jax.vmap( - self.distance, in_axes=(None, None, None, 0, 0) + o2p_distance_func = vmap( + genome_distance_func, in_axes=(None, None, None, 0, 0) ) # one to population # idx to specie key @@ -379,7 +360,7 @@ class DefaultSpecies(BaseSpecies): i, i2s, cns, ccs, o2c = carry return (i < self.species_size) & ( - ~jnp.isnan(state.species_keys[i]) + ~jnp.isnan(state.species.species_keys[i]) ) # current species is existing def body_func(carry): @@ -392,7 +373,7 @@ class DefaultSpecies(BaseSpecies): # find the closest one closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) - i2s = i2s.at[closest_idx].set(state.species_keys[i]) + i2s = i2s.at[closest_idx].set(state.species.species_keys[i]) cns = cns.at[i].set(state.pop_nodes[closest_idx]) ccs = ccs.at[i].set(state.pop_conns[closest_idx]) @@ -404,13 +385,21 @@ class DefaultSpecies(BaseSpecies): _, idx2species, center_nodes, center_conns, o2c_distances = jax.lax.while_loop( cond_func, body_func, - (0, idx2species, state.center_nodes, state.center_conns, o2c_distances), + ( + 0, + idx2species, + state.species.center_nodes, + state.species.center_conns, + o2c_distances, + ), ) state = state.update( - idx2species=idx2species, - center_nodes=center_nodes, - center_conns=center_conns, + species=state.species.update( + idx2species=idx2species, + center_nodes=center_nodes, + center_conns=center_conns, + ), ) # part 2: assign members to each species @@ -500,12 +489,12 @@ class DefaultSpecies(BaseSpecies): body_func, ( 0, - state.idx2species, + state.species.idx2species, center_nodes, center_conns, - state.species_keys, + state.species.species_keys, o2c_distances, - state.next_species_key, + state.species.next_species_key, ), ) @@ -514,10 +503,10 @@ class DefaultSpecies(BaseSpecies): idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) # complete info of species which is created in this generation - new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness) - best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness) + new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness) + best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness) last_improved = jnp.where( - new_created_mask, state.generation, state.last_improved + new_created_mask, state.generation, state.species.last_improved ) # update members count @@ -530,9 +519,9 @@ class DefaultSpecies(BaseSpecies): ), # count members ) - member_count = jax.vmap(count_members)(self.species_arange) + member_count = vmap(count_members)(self.species_arange) - return state.update( + species_state = state.species.update( species_keys=species_keys, best_fitness=best_fitness, last_improved=last_improved, @@ -543,135 +532,6 @@ class DefaultSpecies(BaseSpecies): next_species_key=next_species_key, ) - def distance(self, state, nodes1, conns1, nodes2, conns2): - """ - The distance between two genomes - """ - d = self.node_distance(state, nodes1, nodes2) + self.conn_distance( - state, conns1, conns2 - ) - return d - - def node_distance(self, state, nodes1, nodes2): - """ - The distance of the nodes part for two genomes - """ - node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) - node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) - max_cnt = jnp.maximum(node_cnt1, node_cnt2) - - # align homologous nodes - # this process is similar to np.intersect1d. - nodes = jnp.concatenate((nodes1, nodes2), axis=0) - keys = nodes[:, 0] - sorted_indices = jnp.argsort(keys, axis=0) - nodes = nodes[sorted_indices] - nodes = jnp.concatenate( - [nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0 - ) # add a nan row to the end - fr, sr = nodes[:-1], nodes[1:] # first row, second row - - # flag location of homologous nodes - intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) - - # calculate the count of non_homologous of two genomes - non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) - - # calculate the distance of homologous nodes - fr_attrs = jax.vmap(extract_node_attrs)(fr) - sr_attrs = jax.vmap(extract_node_attrs)(sr) - hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))( - state, fr_attrs, sr_attrs - ) # homologous node distance - hnd = jnp.where(jnp.isnan(hnd), 0, hnd) - homologous_distance = jnp.sum(hnd * intersect_mask) - - val = ( - non_homologous_cnt * self.compatibility_disjoint - + homologous_distance * self.compatibility_weight - ) - - val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize - - return val - - def conn_distance(self, state, conns1, conns2): - """ - The distance of the conns part for two genomes - """ - con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0])) - con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0])) - max_cnt = jnp.maximum(con_cnt1, con_cnt2) - - cons = jnp.concatenate((conns1, conns2), axis=0) - keys = cons[:, :2] - sorted_indices = jnp.lexsort(keys.T[::-1]) - cons = cons[sorted_indices] - cons = jnp.concatenate( - [cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0 - ) # add a nan row to the end - fr, sr = cons[:-1], cons[1:] # first row, second row - - # both genome has such connection - intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) - - non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) - - fr_attrs = jax.vmap(extract_conn_attrs)(fr) - sr_attrs = jax.vmap(extract_conn_attrs)(sr) - hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))( - state, fr_attrs, sr_attrs - ) # homologous connection distance - hcd = jnp.where(jnp.isnan(hcd), 0, hcd) - homologous_distance = jnp.sum(hcd * intersect_mask) - - val = ( - non_homologous_cnt * self.compatibility_disjoint - + homologous_distance * self.compatibility_weight - ) - - val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize - - return val - - def create_next_generation(self, state, winner, loser, elite_mask): - - # find next node key - all_nodes_keys = state.pop_nodes[:, :, 0] - max_node_key = jnp.max( - all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0 - ) - next_node_key = max_node_key + 1 - new_node_keys = jnp.arange(self.pop_size) + next_node_key - - # prepare random keys - k1, k2, randkey = jax.random.split(state.randkey, 3) - crossover_randkeys = jax.random.split(k1, self.pop_size) - mutate_randkeys = jax.random.split(k2, self.pop_size) - - wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner] - lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser] - - # batch crossover - n_nodes, n_conns = jax.vmap( - self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0) - )( - state, crossover_randkeys, wpn, wpc, lpn, lpc - ) # new_nodes, new_conns - - # batch mutation - m_n_nodes, m_n_conns = jax.vmap( - self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0) - )( - state, mutate_randkeys, n_nodes, n_conns, new_node_keys - ) # mutated_new_nodes, mutated_new_conns - - # elitism don't mutate - pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes) - pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns) - return state.update( - randkey=randkey, - pop_nodes=pop_nodes, - pop_conns=pop_conns, + species=species_state, ) diff --git a/tensorneat/algorithm/neat/species/__init__.py b/tensorneat/algorithm/neat/species/__init__.py deleted file mode 100644 index f52a178..0000000 --- a/tensorneat/algorithm/neat/species/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .base import BaseSpecies -from .default import DefaultSpecies diff --git a/tensorneat/algorithm/neat/species/base.py b/tensorneat/algorithm/neat/species/base.py deleted file mode 100644 index cf03d4f..0000000 --- a/tensorneat/algorithm/neat/species/base.py +++ /dev/null @@ -1,20 +0,0 @@ -from tensorneat.common import State, StatefulBaseClass -from tensorneat.genome import BaseGenome - - -class BaseSpecies(StatefulBaseClass): - genome: BaseGenome - pop_size: int - species_size: int - - def ask(self, state: State): - raise NotImplementedError - - def tell(self, state: State, fitness): - raise NotImplementedError - - def update_species(self, state, fitness): - raise NotImplementedError - - def speciate(self, state): - raise NotImplementedError diff --git a/tensorneat/genome/__init__.py b/tensorneat/genome/__init__.py index 5c10584..6e00fa1 100644 --- a/tensorneat/genome/__init__.py +++ b/tensorneat/genome/__init__.py @@ -1,3 +1,5 @@ +from .gene import * +from .operations import * from .base import BaseGenome from .default import DefaultGenome from .recurrent import RecurrentGenome diff --git a/tensorneat/genome/gene/node/default.py b/tensorneat/genome/gene/node/default.py index 25b4193..10a0c32 100644 --- a/tensorneat/genome/gene/node/default.py +++ b/tensorneat/genome/gene/node/default.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Union, Sequence, Callable import numpy as np import jax, jax.numpy as jnp @@ -34,14 +34,20 @@ class DefaultNodeGene(BaseNodeGene): response_mutate_power: float = 0.5, response_mutate_rate: float = 0.7, response_replace_rate: float = 0.1, - aggregation_default: callable = Agg.sum, - aggregation_options: Tuple = (Agg.sum,), + aggregation_default: Callable = Agg.sum, + aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum, aggregation_replace_rate: float = 0.1, - activation_default: callable = Act.sigmoid, - activation_options: Tuple = (Act.sigmoid,), + activation_default: Callable = Act.sigmoid, + activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid, activation_replace_rate: float = 0.1, ): super().__init__() + + if isinstance(aggregation_options, Callable): + aggregation_options = [aggregation_options] + if isinstance(activation_options, Callable): + activation_options = [activation_options] + self.bias_init_mean = bias_init_mean self.bias_init_std = bias_init_std self.bias_mutate_power = bias_mutate_power diff --git a/tensorneat/genome/operations/distance/default.py b/tensorneat/genome/operations/distance/default.py index 3e78916..d992d2e 100644 --- a/tensorneat/genome/operations/distance/default.py +++ b/tensorneat/genome/operations/distance/default.py @@ -13,7 +13,7 @@ class DefaultDistance(BaseDistance): self.compatibility_disjoint = compatibility_disjoint self.compatibility_weight = compatibility_weight - def __call__(self, state, nodes1, nodes2, conns1, conns2): + def __call__(self, state, nodes1, conns1, nodes2, conns2): """ The distance between two genomes """ diff --git a/tensorneat/genome/operations/mutation/base.py b/tensorneat/genome/operations/mutation/base.py index 15c0d4a..2d138af 100644 --- a/tensorneat/genome/operations/mutation/base.py +++ b/tensorneat/genome/operations/mutation/base.py @@ -8,5 +8,5 @@ class BaseMutation(StatefulBaseClass): self.genome = genome return state - def __call__(self, state, randkey, genome, nodes, conns, new_node_key): + def __call__(self, state, randkey, nodes, conns, new_node_key): raise NotImplementedError diff --git a/tensorneat/genome/operations/mutation/default.py b/tensorneat/genome/operations/mutation/default.py index e7100bc..efcf765 100644 --- a/tensorneat/genome/operations/mutation/default.py +++ b/tensorneat/genome/operations/mutation/default.py @@ -33,17 +33,17 @@ class DefaultMutation(BaseMutation): self.node_add = node_add self.node_delete = node_delete - def __call__(self, state, randkey, genome, nodes, conns, new_node_key): + def __call__(self, state, randkey, nodes, conns, new_node_key): k1, k2 = jax.random.split(randkey) nodes, conns = self.mutate_structure( - state, k1, genome, nodes, conns, new_node_key + state, k1, nodes, conns, new_node_key ) - nodes, conns = self.mutate_values(state, k2, genome, nodes, conns) + nodes, conns = self.mutate_values(state, k2, nodes, conns) return nodes, conns - def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key): + def mutate_structure(self, state, randkey, nodes, conns, new_node_key): def mutate_add_node(key_, nodes_, conns_): """ add a node while do not influence the output of the network @@ -62,7 +62,7 @@ class DefaultMutation(BaseMutation): # add a new node with identity attrs new_nodes = add_node( - nodes_, new_node_key, genome.node_gene.new_identity_attrs(state) + nodes_, new_node_key, self.genome.node_gene.new_identity_attrs(state) ) # add two new connections @@ -71,7 +71,7 @@ class DefaultMutation(BaseMutation): new_conns, i_key, new_node_key, - genome.conn_gene.new_identity_attrs(state), + self.genome.conn_gene.new_identity_attrs(state), ) # second is with the origin attrs new_conns = add_conn( @@ -97,8 +97,8 @@ class DefaultMutation(BaseMutation): key, idx = self.choose_node_key( key_, nodes_, - genome.input_idx, - genome.output_idx, + self.genome.input_idx, + self.genome.output_idx, allow_input_keys=False, allow_output_keys=False, ) @@ -136,8 +136,8 @@ class DefaultMutation(BaseMutation): i_key, from_idx = self.choose_node_key( k1_, nodes_, - genome.input_idx, - genome.output_idx, + self.genome.input_idx, + self.genome.output_idx, allow_input_keys=True, allow_output_keys=True, ) @@ -146,8 +146,8 @@ class DefaultMutation(BaseMutation): o_key, to_idx = self.choose_node_key( k2_, nodes_, - genome.input_idx, - genome.output_idx, + self.genome.input_idx, + self.genome.output_idx, allow_input_keys=False, allow_output_keys=True, ) @@ -161,10 +161,10 @@ class DefaultMutation(BaseMutation): def successful(): # add a connection with zero attrs return nodes_, add_conn( - conns_, i_key, o_key, genome.conn_gene.new_zero_attrs(state) + conns_, i_key, o_key, self.genome.conn_gene.new_zero_attrs(state) ) - if genome.network_type == "feedforward": + if self.genome.network_type == "feedforward": u_conns = unflatten_conns(nodes_, conns_) conns_exist = u_conns != I_INF is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx) @@ -175,7 +175,7 @@ class DefaultMutation(BaseMutation): successful, ) - elif genome.network_type == "recurrent": + elif self.genome.network_type == "recurrent": return jax.lax.cond( is_already_exist | (remain_conn_space < 1), nothing, @@ -183,7 +183,7 @@ class DefaultMutation(BaseMutation): ) else: - raise ValueError(f"Invalid network type: {genome.network_type}") + raise ValueError(f"Invalid network type: {self.genome.network_type}") def mutate_delete_conn(key_, nodes_, conns_): # randomly choose a connection @@ -223,19 +223,19 @@ class DefaultMutation(BaseMutation): return nodes, conns - def mutate_values(self, state, randkey, genome, nodes, conns): + def mutate_values(self, state, randkey, nodes, conns): k1, k2 = jax.random.split(randkey) - nodes_randkeys = jax.random.split(k1, num=genome.max_nodes) - conns_randkeys = jax.random.split(k2, num=genome.max_conns) + nodes_randkeys = jax.random.split(k1, num=self.genome.max_nodes) + conns_randkeys = jax.random.split(k2, num=self.genome.max_conns) node_attrs = vmap(extract_node_attrs)(nodes) - new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))( + new_node_attrs = vmap(self.genome.node_gene.mutate, in_axes=(None, 0, 0))( state, nodes_randkeys, node_attrs ) new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs) conn_attrs = vmap(extract_conn_attrs)(conns) - new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))( + new_conn_attrs = vmap(self.genome.conn_gene.mutate, in_axes=(None, 0, 0))( state, conns_randkeys, conn_attrs ) new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs) diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index 5c2bae3..3de5e9f 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -5,10 +5,10 @@ import jax, jax.numpy as jnp import datetime, time import numpy as np -from algorithm import BaseAlgorithm -from problem import BaseProblem -from problem.rl_env import RLEnv -from problem.func_fit import FuncFit +from tensorneat.algorithm import BaseAlgorithm +from tensorneat.problem import BaseProblem +from tensorneat.problem.rl_env import RLEnv +from tensorneat.problem.func_fit import FuncFit from tensorneat.common import State, StatefulBaseClass @@ -187,7 +187,7 @@ class Pipeline(StatefulBaseClass): print("Fitness limit reached!") break - if self.algorithm.generation(state) >= self.generation_limit: + if int(state.generation) >= self.generation_limit: print("Generation limit reached!") if self.is_save: @@ -203,6 +203,8 @@ class Pipeline(StatefulBaseClass): return state, self.best_genome def analysis(self, state, pop, fitnesses): + + generation = int(state.generation) valid_fitnesses = fitnesses[~np.isinf(fitnesses)] @@ -223,8 +225,12 @@ class Pipeline(StatefulBaseClass): self.best_genome = pop[0][max_idx], pop[1][max_idx] if self.is_save: + # save best best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx])) - with open(os.path.join(self.genome_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f: + file_name = os.path.join( + self.genome_dir, f"{generation}.npz" + ) + with open(file_name, "wb") as f: np.savez( f, nodes=best_genome[0], @@ -232,42 +238,18 @@ class Pipeline(StatefulBaseClass): fitness=self.best_fitness, ) - # save best if save path is not None - - member_count = jax.device_get(self.algorithm.member_count(state)) - species_sizes = [int(i) for i in member_count if i > 0] - - pop = jax.device_get(pop) - pop_nodes, pop_conns = pop # (P, N, NL), (P, C, CL) - nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1) # (P,) - conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1) # (P,) - - max_node_cnt, min_node_cnt, mean_node_cnt = ( - max(nodes_cnt), - min(nodes_cnt), - np.mean(nodes_cnt), - ) - - max_conn_cnt, min_conn_cnt, mean_conn_cnt = ( - max(conns_cnt), - min(conns_cnt), - np.mean(conns_cnt), - ) + # append log + with open(os.path.join(self.save_dir, "log.txt"), "a") as f: + f.write( + f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n" + ) print( - f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n", - f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n", - f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n", - f"\tspecies: {len(species_sizes)}, {species_sizes}\n", + f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n", f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n", ) - # append log - if self.is_save: - with open(os.path.join(self.save_dir, "log.txt"), "a") as f: - f.write( - f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n" - ) + self.algorithm.show_details(state, fitnesses) def show(self, state, best, *args, **kwargs): transformed = self.algorithm.transform(state, best)