modify method cal_spawn_numbers
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||
"""
|
||||
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
|
||||
from .population import update_species, create_next_generation, speciate
|
||||
from .population import update_species, create_next_generation, speciate, tell
|
||||
|
||||
from .genome.activations import act_name2func
|
||||
from .genome.aggregations import agg_name2func
|
||||
|
||||
@@ -100,4 +100,4 @@ def create_forward_function(config):
|
||||
elif config['forward_way'] == 'common':
|
||||
return jit(common_forward)
|
||||
|
||||
return forward
|
||||
return jit(forward)
|
||||
|
||||
@@ -11,6 +11,28 @@ from jax import jit, vmap, Array, numpy as jnp
|
||||
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
|
||||
|
||||
|
||||
@jit
|
||||
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
|
||||
jit_config):
|
||||
|
||||
generation += 1
|
||||
|
||||
k1, k2, randkey = jax.random.split(randkey, 3)
|
||||
|
||||
species_info, center_nodes, center_cons, winner, loser, elite_mask = \
|
||||
update_species(k1, fitness, species_info, idx2species, center_nodes,
|
||||
center_cons, generation, jit_config)
|
||||
|
||||
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
||||
elite_mask, generation, jit_config)
|
||||
|
||||
idx2species, center_nodes, center_cons, species_info = speciate(
|
||||
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
|
||||
jit_config)
|
||||
|
||||
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
|
||||
|
||||
|
||||
@jit
|
||||
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
@@ -110,7 +132,13 @@ def cal_spawn_numbers(species_info, jit_config):
|
||||
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||
|
||||
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
|
||||
target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member
|
||||
|
||||
# Avoid too much variation of numbers in a species
|
||||
previous_size = species_info[:, 3].astype(jnp.int32)
|
||||
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
|
||||
|
||||
spawn_number = spawn_number.astype(jnp.int32)
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
||||
|
||||
@@ -44,7 +44,8 @@ jit_config_keys = [
|
||||
"pop_size",
|
||||
"genome_elitism",
|
||||
"survival_threshold",
|
||||
"species_elitism"
|
||||
"species_elitism",
|
||||
"spawn_number_move_rate"
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
num_inputs = 2
|
||||
num_outputs = 1
|
||||
init_maximum_nodes = 50
|
||||
init_maximum_connections = 200
|
||||
init_maximum_connections = 50
|
||||
init_maximum_species = 10
|
||||
expand_coe = 1.5
|
||||
pre_expand_threshold = 0.75
|
||||
@@ -13,7 +13,7 @@ batch_size = 4
|
||||
fitness_threshold = 100000
|
||||
generation_limit = 1000
|
||||
fitness_criterion = "max"
|
||||
pop_size = 150
|
||||
pop_size = 2000
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
@@ -31,6 +31,7 @@ max_stagnation = 15
|
||||
genome_elitism = 2
|
||||
survival_threshold = 0.2
|
||||
min_species_size = 1
|
||||
spawn_number_move_rate = 0.5
|
||||
|
||||
[gene-bias]
|
||||
bias_init_mean = 0.0
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
@@ -14,8 +15,9 @@ def evaluate(forward_func):
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return np.array(fitnesses) # returns a list
|
||||
return fitnesses
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
37
pipeline.py
37
pipeline.py
@@ -48,6 +48,26 @@ class Pipeline:
|
||||
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
||||
self.forward = neat.create_forward_function(config)
|
||||
|
||||
# fitness_lower = np.zeros(self.P, dtype=np.float32)
|
||||
# randkey_lower = np.zeros(2, dtype=np.uint32)
|
||||
# pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32)
|
||||
# pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32)
|
||||
# species_info_lower = np.zeros((self.S, 4), dtype=np.float32)
|
||||
# idx2species_lower = np.zeros(self.P, dtype=np.float32)
|
||||
# center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32)
|
||||
# center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32)
|
||||
#
|
||||
# self.tell_func = jit(neat.tell).lower(fitness_lower,
|
||||
# randkey_lower,
|
||||
# pop_nodes_lower,
|
||||
# pop_cons_lower,
|
||||
# species_info_lower,
|
||||
# idx2species_lower,
|
||||
# center_nodes_lower,
|
||||
# center_cons_lower,
|
||||
# 0,
|
||||
# self.jit_config).compile()
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
Creates a function that receives a genome and returns a forward function.
|
||||
@@ -75,22 +95,13 @@ class Pipeline:
|
||||
assert self.config['forward_way'] == 'common'
|
||||
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
|
||||
def tell(self, fitnesses):
|
||||
self.generation += 1
|
||||
def tell(self, fitness):
|
||||
|
||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
||||
|
||||
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
|
||||
neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
||||
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
|
||||
self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons,
|
||||
self.species_info, self.idx2species, self.center_nodes,
|
||||
self.center_cons, self.generation, self.jit_config)
|
||||
|
||||
self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
|
||||
elite_mask, self.generation, self.jit_config)
|
||||
|
||||
self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate(
|
||||
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
|
||||
self.jit_config)
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config['generation_limit']):
|
||||
forward_func = self.ask()
|
||||
|
||||
Reference in New Issue
Block a user