add column "members_count" to species_info
This commit is contained in:
@@ -1,6 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
contains operations on the population: creating the next generation and population speciation.
|
Contains operations on the population: creating the next generation and population speciation.
|
||||||
|
These im.....
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TODO: Complete python doc
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import jit, vmap, Array, numpy as jnp
|
from jax import jit, vmap, Array, numpy as jnp
|
||||||
|
|
||||||
@@ -13,8 +17,8 @@ def update_species(randkey, fitness, species_info, idx2species, center_nodes, ce
|
|||||||
args:
|
args:
|
||||||
randkey: random key
|
randkey: random key
|
||||||
fitness: Array[(pop_size,), float], the fitness of each individual
|
fitness: Array[(pop_size,), float], the fitness of each individual
|
||||||
species_keys: Array[(species_size, 3), float], the information of each species
|
species_keys: Array[(species_size, 4), float], the information of each species
|
||||||
[species_key, best_score, last_update]
|
[species_key, best_score, last_update, members_count]
|
||||||
idx2species: Array[(pop_size,), int], map the individual to its species
|
idx2species: Array[(pop_size,), int], map the individual to its species
|
||||||
center_nodes: Array[(species_size, N, 4), float], the center nodes of each species
|
center_nodes: Array[(species_size, N, 4), float], the center nodes of each species
|
||||||
center_cons: Array[(species_size, C, 4), float], the center connections of each species
|
center_cons: Array[(species_size, C, 4), float], the center connections of each species
|
||||||
@@ -68,12 +72,12 @@ def stagnation(species_fitness, species_info, center_nodes, center_cons, generat
|
|||||||
|
|
||||||
def aux_func(idx):
|
def aux_func(idx):
|
||||||
s_fitness = species_fitness[idx]
|
s_fitness = species_fitness[idx]
|
||||||
species_key, best_score, last_update = species_info[idx]
|
species_key, best_score, last_update, members_count = species_info[idx]
|
||||||
st = (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
|
st = (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
|
||||||
last_update = jnp.where(s_fitness > best_score, generation, last_update)
|
last_update = jnp.where(s_fitness > best_score, generation, last_update)
|
||||||
best_score = jnp.where(s_fitness > best_score, s_fitness, best_score)
|
best_score = jnp.where(s_fitness > best_score, s_fitness, best_score)
|
||||||
# stagnation condition
|
# stagnation condition
|
||||||
return st, jnp.array([species_key, best_score, last_update])
|
return st, jnp.array([species_key, best_score, last_update, members_count])
|
||||||
|
|
||||||
spe_st, species_info = vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
spe_st, species_info = vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
||||||
|
|
||||||
@@ -116,7 +120,6 @@ def cal_spawn_numbers(species_info, jit_config):
|
|||||||
|
|
||||||
|
|
||||||
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
|
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
|
||||||
|
|
||||||
species_size = species_info.shape[0]
|
species_size = species_info.shape[0]
|
||||||
pop_size = fitness.shape[0]
|
pop_size = fitness.shape[0]
|
||||||
s_idx = jnp.arange(species_size)
|
s_idx = jnp.arange(species_size)
|
||||||
@@ -257,7 +260,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
idx = fetch_first(jnp.isnan(i2s))
|
idx = fetch_first(jnp.isnan(i2s))
|
||||||
|
|
||||||
# assign it to the new species
|
# assign it to the new species
|
||||||
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation])) # [key, best score, last update generation]
|
# [key, best score, last update generation, members_count]
|
||||||
|
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0]))
|
||||||
i2s = i2s.at[idx].set(ck)
|
i2s = i2s.at[idx].set(ck)
|
||||||
|
|
||||||
# update center genomes
|
# update center genomes
|
||||||
@@ -296,6 +300,17 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||||
idx2specie = jnp.where(idx2specie == I_INT, species_info[-1, 0], idx2specie)
|
idx2specie = jnp.where(idx2specie == I_INT, species_info[-1, 0], idx2specie)
|
||||||
|
|
||||||
|
# update members count
|
||||||
|
def count_members(idx):
|
||||||
|
key = species_info[idx, 0]
|
||||||
|
count = jnp.sum(idx2specie == key)
|
||||||
|
count = jnp.where(jnp.isnan(key), jnp.nan, count)
|
||||||
|
return count
|
||||||
|
|
||||||
|
species_member_counts = vmap(count_members)(jnp.arange(species_size))
|
||||||
|
species_info = species_info.at[:, 3].set(species_member_counts)
|
||||||
|
|
||||||
return idx2specie, center_nodes, center_cons, species_info
|
return idx2specie, center_nodes, center_cons, species_info
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ batch_size = 4
|
|||||||
fitness_threshold = 100000
|
fitness_threshold = 100000
|
||||||
generation_limit = 1000
|
generation_limit = 1000
|
||||||
fitness_criterion = "max"
|
fitness_criterion = "max"
|
||||||
pop_size = 1500
|
pop_size = 150
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
|
|||||||
@@ -30,8 +30,8 @@ class Pipeline:
|
|||||||
self.best_genome = None
|
self.best_genome = None
|
||||||
|
|
||||||
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
|
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
|
||||||
self.species_info = np.full((self.S, 3), np.nan)
|
self.species_info = np.full((self.S, 4), np.nan)
|
||||||
self.species_info[0, :] = 0, -np.inf, 0
|
self.species_info[0, :] = 0, -np.inf, 0, self.P
|
||||||
self.idx2species = np.zeros(self.P, dtype=np.float32)
|
self.idx2species = np.zeros(self.P, dtype=np.float32)
|
||||||
self.center_nodes = np.full((self.S, self.N, 5), np.nan)
|
self.center_nodes = np.full((self.S, self.N, 5), np.nan)
|
||||||
self.center_cons = np.full((self.S, self.C, 4), np.nan)
|
self.center_cons = np.full((self.S, self.C, 4), np.nan)
|
||||||
@@ -128,5 +128,8 @@ class Pipeline:
|
|||||||
self.best_fitness = fitnesses[max_idx]
|
self.best_fitness = fitnesses[max_idx]
|
||||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||||
|
|
||||||
|
species_sizes = [int(i) for i in self.species_info[:, 3] if i > 0]
|
||||||
|
|
||||||
print(f"Generation: {self.generation}",
|
print(f"Generation: {self.generation}",
|
||||||
|
f"species: {len(species_sizes)}, {species_sizes}",
|
||||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
|
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
|
||||||
|
|||||||
Reference in New Issue
Block a user