add column "members_count" to species_info
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
"""
|
||||
contains operations on the population: creating the next generation and population speciation.
|
||||
Contains operations on the population: creating the next generation and population speciation.
|
||||
These im.....
|
||||
"""
|
||||
|
||||
# TODO: Complete python doc
|
||||
|
||||
import jax
|
||||
from jax import jit, vmap, Array, numpy as jnp
|
||||
|
||||
@@ -13,8 +17,8 @@ def update_species(randkey, fitness, species_info, idx2species, center_nodes, ce
|
||||
args:
|
||||
randkey: random key
|
||||
fitness: Array[(pop_size,), float], the fitness of each individual
|
||||
species_keys: Array[(species_size, 3), float], the information of each species
|
||||
[species_key, best_score, last_update]
|
||||
species_keys: Array[(species_size, 4), float], the information of each species
|
||||
[species_key, best_score, last_update, members_count]
|
||||
idx2species: Array[(pop_size,), int], map the individual to its species
|
||||
center_nodes: Array[(species_size, N, 4), float], the center nodes of each species
|
||||
center_cons: Array[(species_size, C, 4), float], the center connections of each species
|
||||
@@ -68,12 +72,12 @@ def stagnation(species_fitness, species_info, center_nodes, center_cons, generat
|
||||
|
||||
def aux_func(idx):
|
||||
s_fitness = species_fitness[idx]
|
||||
species_key, best_score, last_update = species_info[idx]
|
||||
species_key, best_score, last_update, members_count = species_info[idx]
|
||||
st = (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
|
||||
last_update = jnp.where(s_fitness > best_score, generation, last_update)
|
||||
best_score = jnp.where(s_fitness > best_score, s_fitness, best_score)
|
||||
# stagnation condition
|
||||
return st, jnp.array([species_key, best_score, last_update])
|
||||
return st, jnp.array([species_key, best_score, last_update, members_count])
|
||||
|
||||
spe_st, species_info = vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
||||
|
||||
@@ -116,7 +120,6 @@ def cal_spawn_numbers(species_info, jit_config):
|
||||
|
||||
|
||||
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
|
||||
|
||||
species_size = species_info.shape[0]
|
||||
pop_size = fitness.shape[0]
|
||||
s_idx = jnp.arange(species_size)
|
||||
@@ -257,7 +260,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
idx = fetch_first(jnp.isnan(i2s))
|
||||
|
||||
# assign it to the new species
|
||||
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation])) # [key, best score, last update generation]
|
||||
# [key, best score, last update generation, members_count]
|
||||
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0]))
|
||||
i2s = i2s.at[idx].set(ck)
|
||||
|
||||
# update center genomes
|
||||
@@ -296,6 +300,17 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||
idx2specie = jnp.where(idx2specie == I_INT, species_info[-1, 0], idx2specie)
|
||||
|
||||
# update members count
|
||||
def count_members(idx):
|
||||
key = species_info[idx, 0]
|
||||
count = jnp.sum(idx2specie == key)
|
||||
count = jnp.where(jnp.isnan(key), jnp.nan, count)
|
||||
return count
|
||||
|
||||
species_member_counts = vmap(count_members)(jnp.arange(species_size))
|
||||
species_info = species_info.at[:, 3].set(species_member_counts)
|
||||
|
||||
return idx2specie, center_nodes, center_cons, species_info
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ batch_size = 4
|
||||
fitness_threshold = 100000
|
||||
generation_limit = 1000
|
||||
fitness_criterion = "max"
|
||||
pop_size = 1500
|
||||
pop_size = 150
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
|
||||
@@ -30,8 +30,8 @@ class Pipeline:
|
||||
self.best_genome = None
|
||||
|
||||
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
|
||||
self.species_info = np.full((self.S, 3), np.nan)
|
||||
self.species_info[0, :] = 0, -np.inf, 0
|
||||
self.species_info = np.full((self.S, 4), np.nan)
|
||||
self.species_info[0, :] = 0, -np.inf, 0, self.P
|
||||
self.idx2species = np.zeros(self.P, dtype=np.float32)
|
||||
self.center_nodes = np.full((self.S, self.N, 5), np.nan)
|
||||
self.center_cons = np.full((self.S, self.C, 4), np.nan)
|
||||
@@ -128,5 +128,8 @@ class Pipeline:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||
|
||||
species_sizes = [int(i) for i in self.species_info[:, 3] if i > 0]
|
||||
|
||||
print(f"Generation: {self.generation}",
|
||||
f"species: {len(species_sizes)}, {species_sizes}",
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
|
||||
|
||||
Reference in New Issue
Block a user