diff --git a/algorithms/neat/population.py b/algorithms/neat/population.py index aa49b0f..a521a9e 100644 --- a/algorithms/neat/population.py +++ b/algorithms/neat/population.py @@ -1,6 +1,10 @@ """ -contains operations on the population: creating the next generation and population speciation. +Contains operations on the population: creating the next generation and population speciation. +These im..... """ + +# TODO: Complete python doc + import jax from jax import jit, vmap, Array, numpy as jnp @@ -13,8 +17,8 @@ def update_species(randkey, fitness, species_info, idx2species, center_nodes, ce args: randkey: random key fitness: Array[(pop_size,), float], the fitness of each individual - species_keys: Array[(species_size, 3), float], the information of each species - [species_key, best_score, last_update] + species_keys: Array[(species_size, 4), float], the information of each species + [species_key, best_score, last_update, members_count] idx2species: Array[(pop_size,), int], map the individual to its species center_nodes: Array[(species_size, N, 4), float], the center nodes of each species center_cons: Array[(species_size, C, 4), float], the center connections of each species @@ -68,12 +72,12 @@ def stagnation(species_fitness, species_info, center_nodes, center_cons, generat def aux_func(idx): s_fitness = species_fitness[idx] - species_key, best_score, last_update = species_info[idx] + species_key, best_score, last_update, members_count = species_info[idx] st = (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation']) last_update = jnp.where(s_fitness > best_score, generation, last_update) best_score = jnp.where(s_fitness > best_score, s_fitness, best_score) # stagnation condition - return st, jnp.array([species_key, best_score, last_update]) + return st, jnp.array([species_key, best_score, last_update, members_count]) spe_st, species_info = vmap(aux_func)(jnp.arange(species_info.shape[0])) @@ -116,7 +120,6 @@ def cal_spawn_numbers(species_info, jit_config): def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config): - species_size = species_info.shape[0] pop_size = fitness.shape[0] s_idx = jnp.arange(species_size) @@ -257,7 +260,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener idx = fetch_first(jnp.isnan(i2s)) # assign it to the new species - si = si.at[i].set(jnp.array([ck, -jnp.inf, generation])) # [key, best score, last update generation] + # [key, best score, last update generation, members_count] + si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0])) i2s = i2s.at[idx].set(ck) # update center genomes @@ -296,6 +300,17 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener # if there are still some pop genomes not assigned to any species, add them to the last genome # this condition seems to be only happened when the number of species is reached species upper bounds idx2specie = jnp.where(idx2specie == I_INT, species_info[-1, 0], idx2specie) + + # update members count + def count_members(idx): + key = species_info[idx, 0] + count = jnp.sum(idx2specie == key) + count = jnp.where(jnp.isnan(key), jnp.nan, count) + return count + + species_member_counts = vmap(count_members)(jnp.arange(species_size)) + species_info = species_info.at[:, 3].set(species_member_counts) + return idx2specie, center_nodes, center_cons, species_info diff --git a/configs/default_config.ini b/configs/default_config.ini index 7bb4244..b85cba9 100644 --- a/configs/default_config.ini +++ b/configs/default_config.ini @@ -13,7 +13,7 @@ batch_size = 4 fitness_threshold = 100000 generation_limit = 1000 fitness_criterion = "max" -pop_size = 1500 +pop_size = 150 [genome] compatibility_disjoint = 1.0 diff --git a/pipeline.py b/pipeline.py index 8b1675f..8e9bf72 100644 --- a/pipeline.py +++ b/pipeline.py @@ -30,8 +30,8 @@ class Pipeline: self.best_genome = None self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config) - self.species_info = np.full((self.S, 3), np.nan) - self.species_info[0, :] = 0, -np.inf, 0 + self.species_info = np.full((self.S, 4), np.nan) + self.species_info[0, :] = 0, -np.inf, 0, self.P self.idx2species = np.zeros(self.P, dtype=np.float32) self.center_nodes = np.full((self.S, self.N, 5), np.nan) self.center_cons = np.full((self.S, self.C, 4), np.nan) @@ -128,5 +128,8 @@ class Pipeline: self.best_fitness = fitnesses[max_idx] self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx]) + species_sizes = [int(i) for i in self.species_info[:, 3] if i > 0] + print(f"Generation: {self.generation}", + f"species: {len(species_sizes)}, {species_sizes}", f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")