modify method cal_spawn_numbers

spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
This commit is contained in:
wls2002
2023-07-01 13:36:19 +08:00
parent 896082900a
commit f6dcb97df8
7 changed files with 64 additions and 21 deletions

View File

@@ -2,7 +2,7 @@
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
from .population import update_species, create_next_generation, speciate
from .population import update_species, create_next_generation, speciate, tell
from .genome.activations import act_name2func
from .genome.aggregations import agg_name2func

View File

@@ -100,4 +100,4 @@ def create_forward_function(config):
elif config['forward_way'] == 'common':
return jit(common_forward)
return forward
return jit(forward)

View File

@@ -11,6 +11,28 @@ from jax import jit, vmap, Array, numpy as jnp
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
@jit
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
jit_config):
generation += 1
k1, k2, randkey = jax.random.split(randkey, 3)
species_info, center_nodes, center_cons, winner, loser, elite_mask = \
update_species(k1, fitness, species_info, idx2species, center_nodes,
center_cons, generation, jit_config)
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
elite_mask, generation, jit_config)
idx2species, center_nodes, center_cons, species_info = speciate(
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
jit_config)
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
@jit
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
"""
@@ -110,7 +132,13 @@ def cal_spawn_numbers(species_info, jit_config):
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member
# Avoid too much variation of numbers in a species
previous_size = species_info[:, 3].astype(jnp.int32)
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = jit_config['pop_size'] - jnp.sum(spawn_number)

View File

@@ -44,7 +44,8 @@ jit_config_keys = [
"pop_size",
"genome_elitism",
"survival_threshold",
"species_elitism"
"species_elitism",
"spawn_number_move_rate"
]

View File

@@ -2,7 +2,7 @@
num_inputs = 2
num_outputs = 1
init_maximum_nodes = 50
init_maximum_connections = 200
init_maximum_connections = 50
init_maximum_species = 10
expand_coe = 1.5
pre_expand_threshold = 0.75
@@ -13,7 +13,7 @@ batch_size = 4
fitness_threshold = 100000
generation_limit = 1000
fitness_criterion = "max"
pop_size = 150
pop_size = 2000
[genome]
compatibility_disjoint = 1.0
@@ -31,6 +31,7 @@ max_stagnation = 15
genome_elitism = 2
survival_threshold = 0.2
min_species_size = 1
spawn_number_move_rate = 0.5
[gene-bias]
bias_init_mean = 0.0

View File

@@ -1,3 +1,4 @@
import jax
import numpy as np
from configs import Configer
@@ -14,8 +15,9 @@ def evaluate(forward_func):
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return np.array(fitnesses) # returns a list
return fitnesses
def main():

View File

@@ -48,6 +48,26 @@ class Pipeline:
self.pop_topological_sort = jit(vmap(neat.topological_sort))
self.forward = neat.create_forward_function(config)
# fitness_lower = np.zeros(self.P, dtype=np.float32)
# randkey_lower = np.zeros(2, dtype=np.uint32)
# pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32)
# pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32)
# species_info_lower = np.zeros((self.S, 4), dtype=np.float32)
# idx2species_lower = np.zeros(self.P, dtype=np.float32)
# center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32)
# center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32)
#
# self.tell_func = jit(neat.tell).lower(fitness_lower,
# randkey_lower,
# pop_nodes_lower,
# pop_cons_lower,
# species_info_lower,
# idx2species_lower,
# center_nodes_lower,
# center_cons_lower,
# 0,
# self.jit_config).compile()
def ask(self):
"""
Creates a function that receives a genome and returns a forward function.
@@ -75,21 +95,12 @@ class Pipeline:
assert self.config['forward_way'] == 'common'
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
def tell(self, fitnesses):
self.generation += 1
def tell(self, fitness):
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config)
self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
elite_mask, self.generation, self.jit_config)
self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate(
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
self.jit_config)
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons,
self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']):