This commit is contained in:
wls2002
2023-05-13 20:58:03 +08:00
parent 90a9cc322d
commit 72c9d4167a
10 changed files with 372 additions and 529 deletions

View File

@@ -2,30 +2,34 @@
Lowers, compiles, and creates functions used in the NEAT pipeline. Lowers, compiles, and creates functions used in the NEAT pipeline.
""" """
from functools import partial from functools import partial
import time
import jax.random
import numpy as np import numpy as np
from jax import jit, vmap from jax import jit, vmap
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover from .genome import act_name2key, agg_name2key, initialize_genomes
from .genome import topological_sort, forward_single, unflatten_connections from .genome import topological_sort, forward_single, unflatten_connections
from .population import create_next_generation_then_speciate
class FunctionFactory: class FunctionFactory:
def __init__(self, config, debug=False): def __init__(self, config):
self.config = config self.config = config
self.debug = debug
self.init_N = config.basic.init_maximum_nodes
self.init_C = config.basic.init_maximum_connections
self.expand_coe = config.basic.expands_coe self.expand_coe = config.basic.expands_coe
self.precompile_times = config.basic.pre_compile_times self.precompile_times = config.basic.pre_compile_times
self.compiled_function = {} self.compiled_function = {}
self.time_cost = {}
self.load_config_vals(config) self.load_config_vals(config)
self.precompile()
self.create_topological_sort_with_args()
self.create_single_forward_with_args()
self.create_update_speciate_with_args()
def load_config_vals(self, config): def load_config_vals(self, config):
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
self.problem_batch = config.basic.problem_batch self.problem_batch = config.basic.problem_batch
self.pop_size = config.neat.population.pop_size self.pop_size = config.neat.population.pop_size
@@ -79,12 +83,12 @@ class FunctionFactory:
self.delete_connection_rate = genome.conn_delete_prob self.delete_connection_rate = genome.conn_delete_prob
self.single_structure_mutate = genome.single_structural_mutation self.single_structure_mutate = genome.single_structural_mutation
def create_initialize(self): def create_initialize(self, N, C):
func = partial( func = partial(
initialize_genomes, initialize_genomes,
pop_size=self.pop_size, pop_size=self.pop_size,
N=self.init_N, N=N,
C=self.init_C, C=C,
num_inputs=self.num_inputs, num_inputs=self.num_inputs,
num_outputs=self.num_outputs, num_outputs=self.num_outputs,
default_bias=self.bias_mean, default_bias=self.bias_mean,
@@ -93,166 +97,85 @@ class FunctionFactory:
default_agg=self.agg_default, default_agg=self.agg_default,
default_weight=self.weight_mean default_weight=self.weight_mean
) )
if self.debug:
def debug_initialize(*args):
return func(*args)
return debug_initialize
else:
return func return func
def precompile(self): def create_update_speciate_with_args(self):
self.create_mutate_with_args() species_kwargs = {
self.create_distance_with_args() "disjoint_coe": self.disjoint_coe,
self.create_crossover_with_args() "compatibility_coe": self.compatibility_coe,
self.create_topological_sort_with_args() "compatibility_threshold": self.compatibility_threshold
self.create_single_forward_with_args() }
#
# n, c = self.init_N, self.init_C
# print("start precompile")
# for _ in range(self.precompile_times):
# self.compile_mutate(n)
# self.compile_distance(n)
# self.compile_crossover(n)
# self.compile_topological_sort_batch(n)
# self.compile_pop_batch_forward(n)
# n = int(self.expand_coe * n)
#
# # precompile other functions used in jax
# key = jax.random.PRNGKey(0)
# _ = jax.random.split(key, 3)
# _ = jax.random.split(key, self.pop_size * 2)
# _ = jax.random.split(key, self.pop_size)
#
# print("end precompile")
def create_mutate_with_args(self): mutate_kwargs = {
func = partial( "input_idx": self.input_idx,
mutate, "output_idx": self.output_idx,
input_idx=self.input_idx, "bias_mean": self.bias_mean,
output_idx=self.output_idx, "bias_std": self.bias_std,
bias_mean=self.bias_mean, "bias_mutate_strength": self.bias_mutate_strength,
bias_std=self.bias_std, "bias_mutate_rate": self.bias_mutate_rate,
bias_mutate_strength=self.bias_mutate_strength, "bias_replace_rate": self.bias_replace_rate,
bias_mutate_rate=self.bias_mutate_rate, "response_mean": self.response_mean,
bias_replace_rate=self.bias_replace_rate, "response_std": self.response_std,
response_mean=self.response_mean, "response_mutate_strength": self.response_mutate_strength,
response_std=self.response_std, "response_mutate_rate": self.response_mutate_rate,
response_mutate_strength=self.response_mutate_strength, "response_replace_rate": self.response_replace_rate,
response_mutate_rate=self.response_mutate_rate, "weight_mean": self.weight_mean,
response_replace_rate=self.response_replace_rate, "weight_std": self.weight_std,
weight_mean=self.weight_mean, "weight_mutate_strength": self.weight_mutate_strength,
weight_std=self.weight_std, "weight_mutate_rate": self.weight_mutate_rate,
weight_mutate_strength=self.weight_mutate_strength, "weight_replace_rate": self.weight_replace_rate,
weight_mutate_rate=self.weight_mutate_rate, "act_default": self.act_default,
weight_replace_rate=self.weight_replace_rate, "act_list": self.act_list,
act_default=self.act_default, "act_replace_rate": self.act_replace_rate,
act_list=self.act_list, "agg_default": self.agg_default,
act_replace_rate=self.act_replace_rate, "agg_list": self.agg_list,
agg_default=self.agg_default, "agg_replace_rate": self.agg_replace_rate,
agg_list=self.agg_list, "enabled_reverse_rate": self.enabled_reverse_rate,
agg_replace_rate=self.agg_replace_rate, "add_node_rate": self.add_node_rate,
enabled_reverse_rate=self.enabled_reverse_rate, "delete_node_rate": self.delete_node_rate,
add_node_rate=self.add_node_rate, "add_connection_rate": self.add_connection_rate,
delete_node_rate=self.delete_node_rate, "delete_connection_rate": self.delete_connection_rate,
add_connection_rate=self.add_connection_rate, }
delete_connection_rate=self.delete_connection_rate,
single_structure_mutate=self.single_structure_mutate self.update_speciate_with_args = partial(
create_next_generation_then_speciate,
species_kwargs=species_kwargs,
mutate_kwargs=mutate_kwargs
) )
self.mutate_with_args = func
def compile_mutate(self, n, c): def create_update_speciate(self, N, C, S):
func = self.mutate_with_args key = ("update_speciate", N, C, S)
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, c, 4))
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
self.compiled_function[('mutate', n, c)] = batched_mutate_func
def create_mutate(self, n, c):
key = ('mutate', n, c)
if key not in self.compiled_function: if key not in self.compiled_function:
self.compile_mutate(n, c) self.compile_update_speciate(N, C, S)
if self.debug:
def debug_mutate(*args):
res_nodes, res_connections = self.compiled_function[key](*args)
return res_nodes.block_until_ready(), res_connections.block_until_ready()
return debug_mutate
else:
return self.compiled_function[key] return self.compiled_function[key]
def create_distance_with_args(self): def compile_update_speciate(self, N, C, S):
func = partial( func = self.update_speciate_with_args
distance, randkey_lower = np.zeros((2,), dtype=np.uint32)
disjoint_coe=self.disjoint_coe, pop_nodes_lower = np.zeros((self.pop_size, N, 5))
compatibility_coe=self.compatibility_coe pop_cons_lower = np.zeros((self.pop_size, C, 4))
) winner_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
self.distance_with_args = func loser_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
elite_mask_lower = np.zeros((self.pop_size,), dtype=bool)
def compile_distance(self, n, c): new_node_keys_start_lower = np.zeros((self.pop_size,), dtype=np.int32)
func = self.distance_with_args pre_spe_center_nodes_lower = np.zeros((S, N, 5))
o2o_nodes1_lower = np.zeros((n, 5)) pre_spe_center_cons_lower = np.zeros((S, C, 4))
o2o_connections1_lower = np.zeros((c, 4)) species_keys = np.zeros((S,), dtype=np.int32)
o2o_nodes2_lower = np.zeros((n, 5)) new_species_keys_lower = 0
o2o_connections2_lower = np.zeros((c, 4)) compiled_func = jit(func).lower(
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower, randkey_lower,
o2o_nodes2_lower, o2o_connections2_lower).compile() pop_nodes_lower,
pop_cons_lower,
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5)) winner_part_lower,
o2m_connections2_lower = np.zeros((self.pop_size, c, 4)) loser_part_lower,
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower, elite_mask_lower,
o2m_nodes2_lower, new_node_keys_start_lower,
o2m_connections2_lower).compile() pre_spe_center_nodes_lower,
pre_spe_center_cons_lower,
self.compiled_function[('o2o_distance', n, c)] = o2o_distance species_keys,
self.compiled_function[('o2m_distance', n, c)] = o2m_distance new_species_keys_lower,
).compile()
def create_distance(self, n, c): self.compiled_function[("update_speciate", N, C, S)] = compiled_func
key1, key2 = ('o2o_distance', n, c), ('o2m_distance', n, c)
if key1 not in self.compiled_function:
self.compile_distance(n, c)
if self.debug:
def debug_o2o_distance(*args):
return self.compiled_function[key1](*args).block_until_ready()
def debug_o2m_distance(*args):
return self.compiled_function[key2](*args).block_until_ready()
return debug_o2o_distance, debug_o2m_distance
else:
return self.compiled_function[key1], self.compiled_function[key2]
def create_crossover_with_args(self):
self.crossover_with_args = crossover
def compile_crossover(self, n, c):
func = self.crossover_with_args
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes1_lower = np.zeros((self.pop_size, n, 5))
connections1_lower = np.zeros((self.pop_size, c, 4))
nodes2_lower = np.zeros((self.pop_size, n, 5))
connections2_lower = np.zeros((self.pop_size, c, 4))
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
self.compiled_function[('crossover', n, c)] = func
def create_crossover(self, n, c):
key = ('crossover', n, c)
if key not in self.compiled_function:
self.compile_crossover(n, c)
if self.debug:
def debug_crossover(*args):
res_nodes, res_connections = self.compiled_function[key](*args)
return res_nodes.block_until_ready(), res_connections.block_until_ready()
return debug_crossover
else:
return self.compiled_function[key]
def create_topological_sort_with_args(self): def create_topological_sort_with_args(self):
self.topological_sort_with_args = topological_sort self.topological_sort_with_args = topological_sort

View File

@@ -11,7 +11,8 @@ from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_
from .graph import check_cycles from .graph import check_cycles
@partial(jit, static_argnames=('single_structure_mutate',)) # TODO: Temporally delete single_structural_mutation, for i need to run it as soon as possible.
@jit
def mutate(rand_key: Array, def mutate(rand_key: Array,
nodes: Array, nodes: Array,
connections: Array, connections: Array,
@@ -44,7 +45,7 @@ def mutate(rand_key: Array,
delete_node_rate: float = 0.2, delete_node_rate: float = 0.2,
add_connection_rate: float = 0.4, add_connection_rate: float = 0.4,
delete_connection_rate: float = 0.4, delete_connection_rate: float = 0.4,
single_structure_mutate: bool = True): ):
""" """
:param output_idx: :param output_idx:
:param input_idx: :param input_idx:
@@ -78,44 +79,15 @@ def mutate(rand_key: Array,
:param delete_node_rate: :param delete_node_rate:
:param add_connection_rate: :param add_connection_rate:
:param delete_connection_rate: :param delete_connection_rate:
:param single_structure_mutate: a genome is structurally mutate at most once
:return: :return:
""" """
# mutate_structure
def nothing(rk, n, c):
return n, c
def m_add_node(rk, n, c): def m_add_node(rk, n, c):
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default) return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_idx, output_idx)
def m_add_connection(rk, n, c): def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_idx, output_idx) return mutate_add_connection(rk, n, c, input_idx, output_idx)
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
mutate_structure_li = [nothing, m_add_node, m_delete_node, m_add_connection, m_delete_connection]
if single_structure_mutate:
r1, r2, rand_key = jax.random.split(rand_key, 3)
d = jnp.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate)
# shorten variable names for beauty
anr, dnr = add_node_rate / d, delete_node_rate / d
acr, dcr = add_connection_rate / d, delete_connection_rate / d
r = rand(r1)
branch = 0
branch = jnp.where(r <= anr, 1, branch)
branch = jnp.where((anr < r) & (r <= anr + dnr), 2, branch)
branch = jnp.where((anr + dnr < r) & (r <= anr + dnr + acr), 3, branch)
branch = jnp.where((anr + dnr + acr) < r & r <= (anr + dnr + acr + dcr), 4, branch)
nodes, connections = jax.lax.switch(branch, mutate_structure_li, (r2, nodes, connections))
else:
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5) r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# mutate add node # mutate add node
@@ -123,21 +95,11 @@ def mutate(rand_key: Array,
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes) nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections) connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
# mutate delete node
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
# mutate add connection # mutate add connection
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections) aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes) nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections) connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
# mutate delete connection
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength, nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std, bias_mutate_rate, bias_replace_rate, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate, response_mutate_strength, response_mutate_rate, response_replace_rate,

View File

@@ -1,109 +0,0 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax import Array
from .genome import distance
from .genome.utils import I_INT, fetch_first, argmin_with_mask
@jit
def jitable_speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
):
"""
args:
pop_nodes: (pop_size, N, 5)
pop_cons: (pop_size, C, 4)
spe_center_nodes: (species_size, N, 5)
spe_center_cons: (species_size, C, 4)
"""
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
# prepare distance functions
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
s2p_distance_func = vmap(
o2p_distance_func, in_axes=(0, 0, None, None)
)
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
# the distance between each species' center and each genome in population
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
def continue_execute_while(carry):
i, i2s, scn, scc = carry
not_all_assigned = ~jnp.all(i2s != I_INT)
not_reach_species_upper_bounds = i < species_size
return not_all_assigned & not_reach_species_upper_bounds
def deal_with_each_center_genome(carry):
i, i2s, scn, scc = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
i2s, scn, scc = jax.lax.cond(
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
create_new_specie, # if not valid, create a new specie
update_exist_specie, # if valid, update the specie
(i, i2s, scn, scc)
)
return i + 1, i2s, scn, scc
def create_new_specie(carry):
i, i2s, scn, scc = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(i2s == I_INT)
# assign it to new specie
i2s = i2s.at[idx].set(i)
# update center genomes
scn = scn.at[i].set(pop_nodes[idx])
scc = scc.at[i].set(pop_cons[idx])
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc))
return i2s, scn, scc
def update_exist_specie(carry):
i, i2s, scn, scc = carry
# find new center
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
# update new center
i2s = i2s.at[idx].set(i)
# update center genomes
scn = scn.at[i].set(pop_nodes[idx])
scc = scc.at[i].set(pop_cons[idx])
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc))
return i2s, scn, scc
def speciate_by_threshold(carry):
i, i2s, scn, scc = carry
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
close_enough_mask = o2p_distance < compatibility_threshold
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
i2s = jnp.where(close_enough_mask & (i2s == I_INT), i, i2s)
return i2s, scn, scc
# update idx2specie
_, idx2specie, spe_center_nodes, spe_center_cons = jax.lax.while_loop(
continue_execute_while,
deal_with_each_center_genome,
(0, idx2specie, spe_center_nodes, spe_center_cons)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition seems to be only happened when the number of species is reached species upper bounds
idx2specie = jnp.where(idx2specie == I_INT, species_size - 1, idx2specie)
return idx2specie, spe_center_nodes, spe_center_cons

View File

@@ -7,8 +7,9 @@ import numpy as np
from .species import SpeciesController from .species import SpeciesController
from .genome import expand, expand_single from .genome import expand, expand_single
from .function_factory import FunctionFactory from .function_factory import FunctionFactory
from .genome.genome import count
from .genome.debug.tools import check_array_valid from .population import *
class Pipeline: class Pipeline:
""" """
@@ -17,7 +18,7 @@ class Pipeline:
def __init__(self, config, seed=42): def __init__(self, config, seed=42):
self.time_dict = {} self.time_dict = {}
self.function_factory = FunctionFactory(config, debug=True) self.function_factory = FunctionFactory(config)
self.randkey = jax.random.PRNGKey(seed) self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed) np.random.seed(seed)
@@ -25,17 +26,18 @@ class Pipeline:
self.config = config self.config = config
self.N = config.basic.init_maximum_nodes self.N = config.basic.init_maximum_nodes
self.C = config.basic.init_maximum_connections self.C = config.basic.init_maximum_connections
self.S = config.basic.init_maximum_species
self.expand_coe = config.basic.expands_coe self.expand_coe = config.basic.expands_coe
self.pop_size = config.neat.population.pop_size self.pop_size = config.neat.population.pop_size
self.species_controller = SpeciesController(config) self.species_controller = SpeciesController(config)
self.initialize_func = self.function_factory.create_initialize() self.initialize_func = self.function_factory.create_initialize()
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func()
self.compile_functions(debug=True) self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
self.generation = 0 self.generation = 0
self.species_controller.init_speciate(self.pop_nodes, self.pop_connections) self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
self.best_fitness = float('-inf') self.best_fitness = float('-inf')
self.best_genome = None self.best_genome = None
@@ -47,22 +49,26 @@ class Pipeline:
:return: :return:
Algorithm gives the population a forward function, then environment gives back the fitnesses. Algorithm gives the population a forward function, then environment gives back the fitnesses.
""" """
return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_connections) return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_cons)
def tell(self, fitnesses): def tell(self, fitnesses):
self.generation += 1 self.generation += 1
self.species_controller.update_species_fitnesses(fitnesses) winner_part, loser_part, elite_mask, pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start = self.species_controller.ask(
fitnesses,
self.generation,
self.S, self.N, self.C)
winner_part, loser_part, elite_mask = self.species_controller.reproduce(fitnesses, self.generation) new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = self.create_and_speciate(
self.randkey, self.pop_nodes, self.pop_cons, winner_part, loser_part, elite_mask,
new_node_keys,
pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start)
self.update_next_generation(winner_part, loser_part, elite_mask) idx2specie, new_center_nodes, new_center_cons, new_species_keys = jax.device_get([idx2specie, new_center_nodes, new_center_cons, new_species_keys])
# pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx) self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation)
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
self.o2o_distance, self.o2m_distance)
self.expand() self.expand()
@@ -86,49 +92,6 @@ class Pipeline:
print("Generation limit reached!") print("Generation limit reached!")
return self.best_genome return self.best_genome
def update_next_generation(self, winner_part, loser_part, elite_mask) -> None:
"""
create next generation
:param winner_part:
:param loser_part:
:param elite_mask:
:return:
"""
assert self.pop_nodes.shape[0] == self.pop_size
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
crossover_rand_keys = jax.random.split(k1, self.pop_size)
mutate_rand_keys = jax.random.split(k2, self.pop_size)
# batch crossover
wpn = self.pop_nodes[winner_part] # winner pop nodes
wpc = self.pop_connections[winner_part] # winner pop connections
lpn = self.pop_nodes[loser_part] # loser pop nodes
lpc = self.pop_connections[loser_part] # loser pop connections
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
lpc) # new pop nodes, new pop connections
# for i in range(self.pop_size):
# n, c = np.array(npn[i]), np.array(npc[i])
# check_array_valid(n, c, self.input_idx, self.output_idx)
# mutate
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# for i in range(self.pop_size):
# n, c = np.array(m_npn[i]), np.array(m_npc[i])
# check_array_valid(n, c, self.input_idx, self.output_idx)
# elitism don't mutate
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn)
self.pop_connections = np.where(elite_mask[:, None, None], npc, m_npc)
def expand(self): def expand(self):
""" """
Expand the population if needed. Expand the population if needed.
@@ -142,38 +105,29 @@ class Pipeline:
if max_node_size >= self.N: if max_node_size >= self.N:
self.N = int(self.N * self.expand_coe) self.N = int(self.N * self.expand_coe)
print(f"node expand to {self.N}!") print(f"node expand to {self.N}!")
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C) self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
# don't forget to expand representation genome in species # don't forget to expand representation genome in species
for s in self.species_controller.species.values(): for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N, self.C) s.representative = expand_single(*s.representative, self.N, self.C)
# update functions
self.compile_functions(debug=True)
pop_con_keys = self.pop_cons[:, :, 0]
pop_con_keys = self.pop_connections[:, :, 0]
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1) pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
max_con_size = np.max(pop_node_sizes) max_con_size = np.max(pop_node_sizes)
if max_con_size >= self.C: if max_con_size >= self.C:
self.C = int(self.C * self.expand_coe) self.C = int(self.C * self.expand_coe)
print(f"connections expand to {self.C}!") print(f"connections expand to {self.C}!")
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C) self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
# don't forget to expand representation genome in species # don't forget to expand representation genome in species
for s in self.species_controller.species.values(): for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N, self.C) s.representative = expand_single(*s.representative, self.N, self.C)
# update functions self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
self.compile_functions(debug=True)
def compile_functions(self, debug=False):
self.mutate_func = self.function_factory.create_mutate(self.N, self.C)
self.crossover_func = self.function_factory.create_crossover(self.N, self.C)
self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N, self.C)
def default_analysis(self, fitnesses): def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
species_sizes = [len(s.members) for s in self.species_controller.species.values()] species_sizes = [len(s.members) for s in self.species_controller.species.values()]
@@ -185,7 +139,7 @@ class Pipeline:
max_idx = np.argmax(fitnesses) max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness: if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx] self.best_fitness = fitnesses[max_idx]
self.best_genome = (self.pop_nodes[max_idx], self.pop_connections[max_idx]) self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
print(f"Generation: {self.generation}", print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}") f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")

View File

@@ -0,0 +1,168 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax import Array
from .genome import distance, mutate, crossover
from .genome.utils import I_INT, fetch_first, argmin_with_mask
@jit
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
new_node_keys,
pre_spe_center_nodes, pre_spe_center_cons, species_keys, new_species_key_start,
species_kwargs, mutate_kwargs):
# create next generation
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
new_node_keys, **mutate_kwargs)
# speciate
idx2specie, spe_center_nodes, spe_center_cons, species_keys = speciate(pop_nodes, pop_cons, pre_spe_center_nodes,
pre_spe_center_cons, species_keys,
new_species_key_start, **species_kwargs)
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
@jit
def speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
species_keys, new_species_key_start,
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
):
"""
args:
pop_nodes: (pop_size, N, 5)
pop_cons: (pop_size, C, 4)
spe_center_nodes: (species_size, N, 5)
spe_center_cons: (species_size, C, 4)
"""
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
# prepare distance functions
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
s2p_distance_func = vmap(
o2p_distance_func, in_axes=(0, 0, None, None)
)
# idx to specie key
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
# part 1: find new centers
# the distance between each species' center and each genome in population
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
def find_new_centers(i, carry):
i2s, scn, scc = carry
# find new center
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
# check species[i] exist or not
# if not exist, set idx and i to I_INT, jax will not do array value assignment
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
i2s = i2s.at[idx].set(species_keys[i])
scn = scn.at[i].set(pop_nodes[idx])
scc = scc.at[i].set(pop_cons[idx])
return i2s, scn, scc
idx2specie, spe_center_nodes, spe_center_cons = jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, spe_center_nodes, spe_center_cons))
def continue_execute_while(carry):
i, i2s, scn, scc, sk, ck = carry # sk is short for species_keys, ck is short for current key
not_all_assigned = ~jnp.all(i2s != I_INT)
not_reach_species_upper_bounds = i < species_size
return not_all_assigned & not_reach_species_upper_bounds
def deal_with_each_center_genome(carry):
i, i2s, scn, scc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
i2s, scn, scc, sk, ck = jax.lax.cond(
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
create_new_specie, # if not valid, create a new specie
update_exist_specie, # if valid, update the specie
(i, i2s, scn, scc, sk, ck)
)
return i + 1, i2s, scn, scc, sk, ck
def create_new_specie(carry):
i, i2s, scn, scc, sk, ck = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(i2s == I_INT)
# assign it to new specie
sk = sk.at[i].set(ck)
i2s = i2s.at[idx].set(ck)
# update center genomes
scn = scn.at[i].set(pop_nodes[idx])
scc = scc.at[i].set(pop_cons[idx])
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
return i2s, scn, scc, sk, ck + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, scn, scc, sk, ck = carry
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
return i2s, scn, scc, sk, ck
def speciate_by_threshold(carry):
i, i2s, scn, scc, sk = carry
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
close_enough_mask = o2p_distance < compatibility_threshold
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
return i2s, scn, scc
current_new_key = new_species_key_start
# update idx2specie
_, idx2specie, spe_center_nodes, spe_center_cons, species_keys, new_species_key_start = jax.lax.while_loop(
continue_execute_while,
deal_with_each_center_genome,
(0, idx2specie, spe_center_nodes, spe_center_cons, species_keys, current_new_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition seems to be only happened when the number of species is reached species upper bounds
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
return idx2specie, spe_center_nodes, spe_center_cons, species_keys
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, new_node_keys,
**mutate_kwargs):
# prepare functions
batch_crossover = vmap(crossover)
mutate_with_args = vmap(partial(mutate, **mutate_kwargs))
pop_size = pop_nodes.shape[0]
k1, k2 = jax.random.split(rand_key, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn = pop_nodes[winner_part] # winner pop nodes
wpc = pop_cons[winner_part] # winner pop connections
lpn = pop_nodes[loser_part] # loser pop nodes
lpc = pop_cons[loser_part] # loser pop connections
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
m_npn, m_npc = mutate_with_args(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
return pop_nodes, pop_cons

View File

@@ -5,6 +5,8 @@ import jax
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from .genome.utils import I_INT
class Species(object): class Species(object):
@@ -12,7 +14,7 @@ class Species(object):
self.key = key self.key = key
self.created = generation self.created = generation
self.last_improved = generation self.last_improved = generation
self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections) self.representative: Tuple[NDArray, NDArray] = (None, None) # (center_nodes, center_connections)
self.members: NDArray = None # idx in pop_nodes, pop_connections, self.members: NDArray = None # idx in pop_nodes, pop_connections,
self.fitness = None self.fitness = None
self.member_fitnesses = None self.member_fitnesses = None
@@ -34,7 +36,7 @@ class SpeciesController:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
self.species_elitism = self.config.neat.species.species_elitism self.species_elitism = self.config.neat.species.species_elitism
self.pop_size = self.config.neat.population.pop_size self.pop_size = self.config.neat.population.pop_size
self.max_stagnation = self.config.neat.species.max_stagnation self.max_stagnation = self.config.neat.species.max_stagnation
@@ -59,97 +61,7 @@ class SpeciesController:
s.update((pop_nodes[0], pop_connections[0]), members) s.update((pop_nodes[0], pop_connections[0]), members)
self.species[species_id] = s self.species[species_id] = s
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int, def __update_species_fitnesses(self, fitnesses):
o2o_distance: Callable, o2m_distance: Callable) -> None:
"""
:param pop_nodes:
:param pop_connections:
:param generation: use to flag the created time of new species
:param o2o_distance: distance function for one-to-one comparison
:param o2m_distance: distance function for one-to-many comparison
:return:
"""
unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool)
previous_species_list = list(self.species.keys())
# Find the best representatives for each existing species.
new_representatives = {}
new_members = {}
total_distances = jax.device_get([
o2m_distance(*self.species[sid].representative, pop_nodes, pop_connections)
for sid in previous_species_list
])
# TODO: Use jit to wrapper function find_min_with_mask to accelerate this part
for i, sid in enumerate(previous_species_list):
distances = total_distances[i]
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx
new_members[sid] = [min_idx]
unspeciated[min_idx] = False
# Partition population into species based on genetic similarity.
# First, fast match the population to previous species
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = jax.device_get([
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list
])
pop_res_distance = np.stack(res_pop_distance, axis=0).T
for i in range(pop_res_distance.shape[0]):
if not unspeciated[i]:
continue
min_idx = np.argmin(pop_res_distance[i])
min_val = pop_res_distance[i, min_idx]
if min_val <= self.compatibility_threshold:
species_id = previous_species_list[min_idx]
new_members[species_id].append(i)
unspeciated[i] = False
# Second, slowly match the lonely population to new-created species.s
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
# the new representatives.
for i in range(pop_nodes.shape[0]):
if not unspeciated[i]:
continue
unspeciated[i] = False
if len(new_representatives) != 0:
# the representatives of new species
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = jax.device_get([
o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
])
distances = np.array(distances)
min_idx = np.argmin(distances)
min_val = distances[min_idx]
if min_val <= self.compatibility_threshold:
species_id = sid[min_idx]
new_members[species_id].append(i)
continue
# create a new species
species_id = next(self.species_idxer)
new_representatives[species_id] = i
new_members[species_id] = [i]
assert np.all(~unspeciated)
# Update species collection based on new speciation.
for sid, rid in new_representatives.items():
s = self.species.get(sid)
if s is None:
s = Species(sid, generation)
self.species[sid] = s
members = np.array(new_members[sid])
s.update((pop_nodes[rid], pop_connections[rid]), members)
def update_species_fitnesses(self, fitnesses):
""" """
update the fitness of each species update the fitness of each species
:param fitnesses: :param fitnesses:
@@ -163,7 +75,7 @@ class SpeciesController:
s.fitness_history.append(s.fitness) s.fitness_history.append(s.fitness)
s.adjusted_fitness = None s.adjusted_fitness = None
def stagnation(self, generation): def __stagnation(self, generation):
""" """
code modified from neat-python! code modified from neat-python!
:param generation: :param generation:
@@ -196,7 +108,7 @@ class SpeciesController:
result.append((sid, s, is_stagnant)) result.append((sid, s, is_stagnant))
return result return result
def reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]: def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
""" """
code modified from neat-python! code modified from neat-python!
:param fitnesses: :param fitnesses:
@@ -215,7 +127,7 @@ class SpeciesController:
max_fitness = -np.inf max_fitness = -np.inf
remaining_species = [] remaining_species = []
for stag_sid, stag_s, stagnant in self.stagnation(generation): for stag_sid, stag_s, stagnant in self.__stagnation(generation):
if not stagnant: if not stagnant:
min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses)) min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses))
max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses)) max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses))
@@ -285,6 +197,33 @@ class SpeciesController:
return winner_part, loser_part, np.array(elite_mask) return winner_part, loser_part, np.array(elite_mask)
def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation):
for idx, key in enumerate(species_keys):
if key == I_INT:
continue
members = np.where(idx2specie == key)[0]
assert len(members) > 0
if key not in self.species:
s = Species(key, generation)
self.species[key] = s
self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members)
def ask(self, fitnesses, generation, S, N, C):
self.__update_species_fitnesses(fitnesses)
winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation)
pre_spe_center_nodes = np.full((S, N, 5), np.nan)
pre_spe_center_cons = np.full((S, C, 4), np.nan)
species_keys = np.full((S,), I_INT)
for idx, (key, specie) in enumerate(self.species.items()):
pre_spe_center_nodes[idx] = specie.representative[0]
pre_spe_center_cons[idx] = specie.representative[1]
species_keys[idx] = key
next_new_specie_key = max(self.species.keys()) + 1
return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \
pre_spe_center_cons, species_keys, next_new_specie_key
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size): def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
""" """
@@ -326,12 +265,6 @@ def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
return spawn_amounts return spawn_amounts
def find_min_with_mask(arr: NDArray, mask: NDArray) -> int:
masked_arr = np.where(mask, arr, np.inf)
min_idx = np.argmin(masked_arr)
return min_idx
def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \ def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \
-> Tuple[NDArray, NDArray]: -> Tuple[NDArray, NDArray]:
sorted_idx = np.argsort(fitnesses)[::-1] sorted_idx = np.argsort(fitnesses)[::-1]

View File

@@ -4,9 +4,10 @@ from jax import jit, vmap
from time_utils import using_cprofile from time_utils import using_cprofile
from time import time from time import time
# #
import numpy as np
@jit @jit
def fx(x, y): def fx(x):
return x + y return jnp.arange(x, x + 10)
# #
# #
# # @jit # # @jit
@@ -33,13 +34,15 @@ def fx(x, y):
# @using_cprofile # @using_cprofile
def main(): def main():
vmap_f = vmap(fx, in_axes=(None, 0)) print(fx(1))
vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
a = jnp.array([20,10,30]) # vmap_f = vmap(fx, in_axes=(None, 0))
b = jnp.array([6, 5, 4]) # vmap_vmap_f = vmap(vmap_f, in_axes=(0, None))
res = vmap_vmap_f(a, b) # a = jnp.array([20,10,30])
print(res) # b = jnp.array([6, 5, 4])
print(jnp.argmin(res, axis=1)) # res = vmap_vmap_f(a, b)
# print(res)
# print(jnp.argmin(res, axis=1))

View File

@@ -4,7 +4,7 @@ import numpy as np
from algorithms.neat.function_factory import FunctionFactory from algorithms.neat.function_factory import FunctionFactory
from algorithms.neat.genome.debug.tools import check_array_valid from algorithms.neat.genome.debug.tools import check_array_valid
from utils import Configer from utils import Configer
from algorithms.neat.jitable_speciate import jitable_speciate from algorithms.neat.population import speciate
from algorithms.neat.genome.crossover import crossover from algorithms.neat.genome.crossover import crossover
from algorithms.neat.genome.utils import I_INT from algorithms.neat.genome.utils import I_INT
from time import time from time import time
@@ -23,7 +23,9 @@ if __name__ == '__main__':
spe_center_connections = np.full((species_size, C, 4), np.nan) spe_center_connections = np.full((species_size, C, 4), np.nan)
spe_center_nodes[0] = pop_nodes[0] spe_center_nodes[0] = pop_nodes[0]
spe_center_connections[0] = pop_connections[0] spe_center_connections[0] = pop_connections[0]
spe_keys = np.full((species_size,), I_INT)
spe_keys[0] = 0
new_spe_key = 1
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
new_node_idx = 100 new_node_idx = 100
@@ -43,25 +45,31 @@ if __name__ == '__main__':
n1, c1 = pop_nodes[idx1], pop_connections[idx1] n1, c1 = pop_nodes[idx1], pop_connections[idx1]
n2, c2 = pop_nodes[idx2], pop_connections[idx2] n2, c2 = pop_nodes[idx2], pop_connections[idx2]
crossover_keys = jax.random.split(subkey, len(pop_nodes)) crossover_keys = jax.random.split(subkey, len(pop_nodes))
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
# for i in range(len(pop_nodes)): # for i in range(len(pop_nodes)):
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx) # check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
#speciate next generation #speciate next generation
idx2specie, spe_center_nodes, spe_center_cons = jitable_speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections, idx2specie, spe_center_nodes, spe_center_cons, spe_keys, new_spe_key = speciate(pop_nodes, pop_connections, spe_center_nodes, spe_center_connections,
compatibility_threshold=2.5) spe_keys, new_spe_key,
compatibility_threshold=3)
idx2specie = np.array(idx2specie) print(spe_keys, new_spe_key)
spe_dict = {}
for i in range(len(idx2specie)):
spe_idx = idx2specie[i]
if spe_idx not in spe_dict:
spe_dict[spe_idx] = 1
else:
spe_dict[spe_idx] += 1
print(spe_dict) #
assert np.all(idx2specie != I_INT) # idx2specie = np.array(idx2specie)
# spe_dict = {}
# for i in range(len(idx2specie)):
# spe_idx = idx2specie[i]
# if spe_idx not in spe_dict:
# spe_dict[spe_idx] = 1
# else:
# spe_dict[spe_idx] += 1
#
# print(spe_dict)
# assert np.all(idx2specie != I_INT)
print(time() - start_time) print(time() - start_time)
# print(idx2specie) # print(idx2specie)

View File

@@ -12,7 +12,7 @@ def main():
config = Configer.load_config() config = Configer.load_config()
problem = Xor() problem = Xor()
problem.refactor_config(config) problem.refactor_config(config)
pipeline = Pipeline(config, seed=0) pipeline = Pipeline(config, seed=1)
pipeline.auto_run(problem.evaluate) pipeline.auto_run(problem.evaluate)

View File

@@ -3,8 +3,9 @@
"num_inputs": 2, "num_inputs": 2,
"num_outputs": 1, "num_outputs": 1,
"problem_batch": 4, "problem_batch": 4,
"init_maximum_nodes": 50, "init_maximum_nodes": 20,
"init_maximum_connections": 50, "init_maximum_connections": 20,
"init_maximum_species": 10,
"expands_coe": 2, "expands_coe": 2,
"pre_compile_times": 3, "pre_compile_times": 3,
"forward_way": "pop_batch" "forward_way": "pop_batch"
@@ -14,7 +15,7 @@
"fitness_criterion": "max", "fitness_criterion": "max",
"fitness_threshold": -0.001, "fitness_threshold": -0.001,
"generation_limit": 1000, "generation_limit": 1000,
"pop_size": 10000, "pop_size": 1000,
"reset_on_extinction": "False" "reset_on_extinction": "False"
}, },
"gene": { "gene": {
@@ -58,12 +59,12 @@
"compatibility_weight_coefficient": 0.5, "compatibility_weight_coefficient": 0.5,
"single_structural_mutation": "False", "single_structural_mutation": "False",
"conn_add_prob": 0.5, "conn_add_prob": 0.5,
"conn_delete_prob": 0.5, "conn_delete_prob": 0,
"node_add_prob": 0.2, "node_add_prob": 0.2,
"node_delete_prob": 0.2 "node_delete_prob": 0
}, },
"species": { "species": {
"compatibility_threshold": 2.5, "compatibility_threshold": 3,
"species_fitness_func": "max", "species_fitness_func": "max",
"max_stagnation": 20, "max_stagnation": 20,
"species_elitism": 2, "species_elitism": 2,