modify method cal_spawn_numbers
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||||
"""
|
"""
|
||||||
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
|
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
|
||||||
from .population import update_species, create_next_generation, speciate
|
from .population import update_species, create_next_generation, speciate, tell
|
||||||
|
|
||||||
from .genome.activations import act_name2func
|
from .genome.activations import act_name2func
|
||||||
from .genome.aggregations import agg_name2func
|
from .genome.aggregations import agg_name2func
|
||||||
|
|||||||
@@ -100,4 +100,4 @@ def create_forward_function(config):
|
|||||||
elif config['forward_way'] == 'common':
|
elif config['forward_way'] == 'common':
|
||||||
return jit(common_forward)
|
return jit(common_forward)
|
||||||
|
|
||||||
return forward
|
return jit(forward)
|
||||||
|
|||||||
@@ -11,6 +11,28 @@ from jax import jit, vmap, Array, numpy as jnp
|
|||||||
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
|
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
|
||||||
|
jit_config):
|
||||||
|
|
||||||
|
generation += 1
|
||||||
|
|
||||||
|
k1, k2, randkey = jax.random.split(randkey, 3)
|
||||||
|
|
||||||
|
species_info, center_nodes, center_cons, winner, loser, elite_mask = \
|
||||||
|
update_species(k1, fitness, species_info, idx2species, center_nodes,
|
||||||
|
center_cons, generation, jit_config)
|
||||||
|
|
||||||
|
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
||||||
|
elite_mask, generation, jit_config)
|
||||||
|
|
||||||
|
idx2species, center_nodes, center_cons, species_info = speciate(
|
||||||
|
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
|
||||||
|
jit_config)
|
||||||
|
|
||||||
|
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
||||||
"""
|
"""
|
||||||
@@ -110,7 +132,13 @@ def cal_spawn_numbers(species_info, jit_config):
|
|||||||
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||||
|
|
||||||
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
|
target_spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']) # calculate member
|
||||||
|
|
||||||
|
# Avoid too much variation of numbers in a species
|
||||||
|
previous_size = species_info[:, 3].astype(jnp.int32)
|
||||||
|
spawn_number = previous_size + (target_spawn_number - previous_size) * jit_config['spawn_number_move_rate']
|
||||||
|
|
||||||
|
spawn_number = spawn_number.astype(jnp.int32)
|
||||||
|
|
||||||
# must control the sum of spawn_number to be equal to pop_size
|
# must control the sum of spawn_number to be equal to pop_size
|
||||||
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
||||||
|
|||||||
@@ -44,7 +44,8 @@ jit_config_keys = [
|
|||||||
"pop_size",
|
"pop_size",
|
||||||
"genome_elitism",
|
"genome_elitism",
|
||||||
"survival_threshold",
|
"survival_threshold",
|
||||||
"species_elitism"
|
"species_elitism",
|
||||||
|
"spawn_number_move_rate"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
num_inputs = 2
|
num_inputs = 2
|
||||||
num_outputs = 1
|
num_outputs = 1
|
||||||
init_maximum_nodes = 50
|
init_maximum_nodes = 50
|
||||||
init_maximum_connections = 200
|
init_maximum_connections = 50
|
||||||
init_maximum_species = 10
|
init_maximum_species = 10
|
||||||
expand_coe = 1.5
|
expand_coe = 1.5
|
||||||
pre_expand_threshold = 0.75
|
pre_expand_threshold = 0.75
|
||||||
@@ -13,7 +13,7 @@ batch_size = 4
|
|||||||
fitness_threshold = 100000
|
fitness_threshold = 100000
|
||||||
generation_limit = 1000
|
generation_limit = 1000
|
||||||
fitness_criterion = "max"
|
fitness_criterion = "max"
|
||||||
pop_size = 150
|
pop_size = 2000
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
@@ -31,6 +31,7 @@ max_stagnation = 15
|
|||||||
genome_elitism = 2
|
genome_elitism = 2
|
||||||
survival_threshold = 0.2
|
survival_threshold = 0.2
|
||||||
min_species_size = 1
|
min_species_size = 1
|
||||||
|
spawn_number_move_rate = 0.5
|
||||||
|
|
||||||
[gene-bias]
|
[gene-bias]
|
||||||
bias_init_mean = 0.0
|
bias_init_mean = 0.0
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import jax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from configs import Configer
|
from configs import Configer
|
||||||
@@ -14,8 +15,9 @@ def evaluate(forward_func):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
outs = forward_func(xor_inputs)
|
outs = forward_func(xor_inputs)
|
||||||
|
outs = jax.device_get(outs)
|
||||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||||
return np.array(fitnesses) # returns a list
|
return fitnesses
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
39
pipeline.py
39
pipeline.py
@@ -48,6 +48,26 @@ class Pipeline:
|
|||||||
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
||||||
self.forward = neat.create_forward_function(config)
|
self.forward = neat.create_forward_function(config)
|
||||||
|
|
||||||
|
# fitness_lower = np.zeros(self.P, dtype=np.float32)
|
||||||
|
# randkey_lower = np.zeros(2, dtype=np.uint32)
|
||||||
|
# pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32)
|
||||||
|
# pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32)
|
||||||
|
# species_info_lower = np.zeros((self.S, 4), dtype=np.float32)
|
||||||
|
# idx2species_lower = np.zeros(self.P, dtype=np.float32)
|
||||||
|
# center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32)
|
||||||
|
# center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32)
|
||||||
|
#
|
||||||
|
# self.tell_func = jit(neat.tell).lower(fitness_lower,
|
||||||
|
# randkey_lower,
|
||||||
|
# pop_nodes_lower,
|
||||||
|
# pop_cons_lower,
|
||||||
|
# species_info_lower,
|
||||||
|
# idx2species_lower,
|
||||||
|
# center_nodes_lower,
|
||||||
|
# center_cons_lower,
|
||||||
|
# 0,
|
||||||
|
# self.jit_config).compile()
|
||||||
|
|
||||||
def ask(self):
|
def ask(self):
|
||||||
"""
|
"""
|
||||||
Creates a function that receives a genome and returns a forward function.
|
Creates a function that receives a genome and returns a forward function.
|
||||||
@@ -75,21 +95,12 @@ class Pipeline:
|
|||||||
assert self.config['forward_way'] == 'common'
|
assert self.config['forward_way'] == 'common'
|
||||||
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||||
|
|
||||||
def tell(self, fitnesses):
|
def tell(self, fitness):
|
||||||
self.generation += 1
|
|
||||||
|
|
||||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
|
||||||
|
self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons,
|
||||||
self.species_info, self.center_nodes, self.center_cons, winner, loser, elite_mask = \
|
self.species_info, self.idx2species, self.center_nodes,
|
||||||
neat.update_species(k1, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
self.center_cons, self.generation, self.jit_config)
|
||||||
self.center_cons, self.generation, self.jit_config)
|
|
||||||
|
|
||||||
self.pop_nodes, self.pop_cons = neat.create_next_generation(k2, self.pop_nodes, self.pop_cons, winner, loser,
|
|
||||||
elite_mask, self.generation, self.jit_config)
|
|
||||||
|
|
||||||
self.idx2species, self.center_nodes, self.center_cons, self.species_info = neat.speciate(
|
|
||||||
self.pop_nodes, self.pop_cons, self.species_info, self.center_nodes, self.center_cons, self.generation,
|
|
||||||
self.jit_config)
|
|
||||||
|
|
||||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||||
for _ in range(self.config['generation_limit']):
|
for _ in range(self.config['generation_limit']):
|
||||||
|
|||||||
Reference in New Issue
Block a user