add column "members_count" to species_info

This commit is contained in:
wls2002
2023-06-29 10:39:50 +08:00
parent f5c1ce72f9
commit 896082900a
3 changed files with 28 additions and 10 deletions

View File

@@ -1,6 +1,10 @@
"""
contains operations on the population: creating the next generation and population speciation.
Contains operations on the population: creating the next generation and population speciation.
These im.....
"""
# TODO: Complete python doc
import jax
from jax import jit, vmap, Array, numpy as jnp
@@ -13,8 +17,8 @@ def update_species(randkey, fitness, species_info, idx2species, center_nodes, ce
args:
randkey: random key
fitness: Array[(pop_size,), float], the fitness of each individual
species_keys: Array[(species_size, 3), float], the information of each species
[species_key, best_score, last_update]
species_keys: Array[(species_size, 4), float], the information of each species
[species_key, best_score, last_update, members_count]
idx2species: Array[(pop_size,), int], map the individual to its species
center_nodes: Array[(species_size, N, 4), float], the center nodes of each species
center_cons: Array[(species_size, C, 4), float], the center connections of each species
@@ -68,12 +72,12 @@ def stagnation(species_fitness, species_info, center_nodes, center_cons, generat
def aux_func(idx):
s_fitness = species_fitness[idx]
species_key, best_score, last_update = species_info[idx]
species_key, best_score, last_update, members_count = species_info[idx]
st = (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
last_update = jnp.where(s_fitness > best_score, generation, last_update)
best_score = jnp.where(s_fitness > best_score, s_fitness, best_score)
# stagnation condition
return st, jnp.array([species_key, best_score, last_update])
return st, jnp.array([species_key, best_score, last_update, members_count])
spe_st, species_info = vmap(aux_func)(jnp.arange(species_info.shape[0]))
@@ -116,7 +120,6 @@ def cal_spawn_numbers(species_info, jit_config):
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
species_size = species_info.shape[0]
pop_size = fitness.shape[0]
s_idx = jnp.arange(species_size)
@@ -257,7 +260,8 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation])) # [key, best score, last update generation]
# [key, best score, last update generation, members_count]
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0]))
i2s = i2s.at[idx].set(ck)
# update center genomes
@@ -296,6 +300,17 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition seems to be only happened when the number of species is reached species upper bounds
idx2specie = jnp.where(idx2specie == I_INT, species_info[-1, 0], idx2specie)
# update members count
def count_members(idx):
key = species_info[idx, 0]
count = jnp.sum(idx2specie == key)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
return count
species_member_counts = vmap(count_members)(jnp.arange(species_size))
species_info = species_info.at[:, 3].set(species_member_counts)
return idx2specie, center_nodes, center_cons, species_info