FAST!
This commit is contained in:
@@ -2,30 +2,34 @@
|
||||
Lowers, compiles, and creates functions used in the NEAT pipeline.
|
||||
"""
|
||||
from functools import partial
|
||||
import time
|
||||
|
||||
import jax.random
|
||||
import numpy as np
|
||||
from jax import jit, vmap
|
||||
|
||||
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
|
||||
from .genome import act_name2key, agg_name2key, initialize_genomes
|
||||
from .genome import topological_sort, forward_single, unflatten_connections
|
||||
from .population import create_next_generation_then_speciate
|
||||
|
||||
|
||||
class FunctionFactory:
|
||||
def __init__(self, config, debug=False):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.debug = debug
|
||||
|
||||
self.init_N = config.basic.init_maximum_nodes
|
||||
self.init_C = config.basic.init_maximum_connections
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.precompile_times = config.basic.pre_compile_times
|
||||
self.compiled_function = {}
|
||||
self.time_cost = {}
|
||||
|
||||
self.load_config_vals(config)
|
||||
self.precompile()
|
||||
|
||||
self.create_topological_sort_with_args()
|
||||
self.create_single_forward_with_args()
|
||||
self.create_update_speciate_with_args()
|
||||
|
||||
def load_config_vals(self, config):
|
||||
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
|
||||
|
||||
self.problem_batch = config.basic.problem_batch
|
||||
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
@@ -79,12 +83,12 @@ class FunctionFactory:
|
||||
self.delete_connection_rate = genome.conn_delete_prob
|
||||
self.single_structure_mutate = genome.single_structural_mutation
|
||||
|
||||
def create_initialize(self):
|
||||
def create_initialize(self, N, C):
|
||||
func = partial(
|
||||
initialize_genomes,
|
||||
pop_size=self.pop_size,
|
||||
N=self.init_N,
|
||||
C=self.init_C,
|
||||
N=N,
|
||||
C=C,
|
||||
num_inputs=self.num_inputs,
|
||||
num_outputs=self.num_outputs,
|
||||
default_bias=self.bias_mean,
|
||||
@@ -93,166 +97,85 @@ class FunctionFactory:
|
||||
default_agg=self.agg_default,
|
||||
default_weight=self.weight_mean
|
||||
)
|
||||
if self.debug:
|
||||
def debug_initialize(*args):
|
||||
return func(*args)
|
||||
return func
|
||||
|
||||
return debug_initialize
|
||||
else:
|
||||
return func
|
||||
def create_update_speciate_with_args(self):
|
||||
species_kwargs = {
|
||||
"disjoint_coe": self.disjoint_coe,
|
||||
"compatibility_coe": self.compatibility_coe,
|
||||
"compatibility_threshold": self.compatibility_threshold
|
||||
}
|
||||
|
||||
def precompile(self):
|
||||
self.create_mutate_with_args()
|
||||
self.create_distance_with_args()
|
||||
self.create_crossover_with_args()
|
||||
self.create_topological_sort_with_args()
|
||||
self.create_single_forward_with_args()
|
||||
#
|
||||
# n, c = self.init_N, self.init_C
|
||||
# print("start precompile")
|
||||
# for _ in range(self.precompile_times):
|
||||
# self.compile_mutate(n)
|
||||
# self.compile_distance(n)
|
||||
# self.compile_crossover(n)
|
||||
# self.compile_topological_sort_batch(n)
|
||||
# self.compile_pop_batch_forward(n)
|
||||
# n = int(self.expand_coe * n)
|
||||
#
|
||||
# # precompile other functions used in jax
|
||||
# key = jax.random.PRNGKey(0)
|
||||
# _ = jax.random.split(key, 3)
|
||||
# _ = jax.random.split(key, self.pop_size * 2)
|
||||
# _ = jax.random.split(key, self.pop_size)
|
||||
#
|
||||
# print("end precompile")
|
||||
mutate_kwargs = {
|
||||
"input_idx": self.input_idx,
|
||||
"output_idx": self.output_idx,
|
||||
"bias_mean": self.bias_mean,
|
||||
"bias_std": self.bias_std,
|
||||
"bias_mutate_strength": self.bias_mutate_strength,
|
||||
"bias_mutate_rate": self.bias_mutate_rate,
|
||||
"bias_replace_rate": self.bias_replace_rate,
|
||||
"response_mean": self.response_mean,
|
||||
"response_std": self.response_std,
|
||||
"response_mutate_strength": self.response_mutate_strength,
|
||||
"response_mutate_rate": self.response_mutate_rate,
|
||||
"response_replace_rate": self.response_replace_rate,
|
||||
"weight_mean": self.weight_mean,
|
||||
"weight_std": self.weight_std,
|
||||
"weight_mutate_strength": self.weight_mutate_strength,
|
||||
"weight_mutate_rate": self.weight_mutate_rate,
|
||||
"weight_replace_rate": self.weight_replace_rate,
|
||||
"act_default": self.act_default,
|
||||
"act_list": self.act_list,
|
||||
"act_replace_rate": self.act_replace_rate,
|
||||
"agg_default": self.agg_default,
|
||||
"agg_list": self.agg_list,
|
||||
"agg_replace_rate": self.agg_replace_rate,
|
||||
"enabled_reverse_rate": self.enabled_reverse_rate,
|
||||
"add_node_rate": self.add_node_rate,
|
||||
"delete_node_rate": self.delete_node_rate,
|
||||
"add_connection_rate": self.add_connection_rate,
|
||||
"delete_connection_rate": self.delete_connection_rate,
|
||||
}
|
||||
|
||||
def create_mutate_with_args(self):
|
||||
func = partial(
|
||||
mutate,
|
||||
input_idx=self.input_idx,
|
||||
output_idx=self.output_idx,
|
||||
bias_mean=self.bias_mean,
|
||||
bias_std=self.bias_std,
|
||||
bias_mutate_strength=self.bias_mutate_strength,
|
||||
bias_mutate_rate=self.bias_mutate_rate,
|
||||
bias_replace_rate=self.bias_replace_rate,
|
||||
response_mean=self.response_mean,
|
||||
response_std=self.response_std,
|
||||
response_mutate_strength=self.response_mutate_strength,
|
||||
response_mutate_rate=self.response_mutate_rate,
|
||||
response_replace_rate=self.response_replace_rate,
|
||||
weight_mean=self.weight_mean,
|
||||
weight_std=self.weight_std,
|
||||
weight_mutate_strength=self.weight_mutate_strength,
|
||||
weight_mutate_rate=self.weight_mutate_rate,
|
||||
weight_replace_rate=self.weight_replace_rate,
|
||||
act_default=self.act_default,
|
||||
act_list=self.act_list,
|
||||
act_replace_rate=self.act_replace_rate,
|
||||
agg_default=self.agg_default,
|
||||
agg_list=self.agg_list,
|
||||
agg_replace_rate=self.agg_replace_rate,
|
||||
enabled_reverse_rate=self.enabled_reverse_rate,
|
||||
add_node_rate=self.add_node_rate,
|
||||
delete_node_rate=self.delete_node_rate,
|
||||
add_connection_rate=self.add_connection_rate,
|
||||
delete_connection_rate=self.delete_connection_rate,
|
||||
single_structure_mutate=self.single_structure_mutate
|
||||
self.update_speciate_with_args = partial(
|
||||
create_next_generation_then_speciate,
|
||||
species_kwargs=species_kwargs,
|
||||
mutate_kwargs=mutate_kwargs
|
||||
)
|
||||
self.mutate_with_args = func
|
||||
|
||||
def compile_mutate(self, n, c):
|
||||
func = self.mutate_with_args
|
||||
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections_lower = np.zeros((self.pop_size, c, 4))
|
||||
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
|
||||
connections_lower, new_node_key_lower).compile()
|
||||
self.compiled_function[('mutate', n, c)] = batched_mutate_func
|
||||
|
||||
def create_mutate(self, n, c):
|
||||
key = ('mutate', n, c)
|
||||
def create_update_speciate(self, N, C, S):
|
||||
key = ("update_speciate", N, C, S)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_mutate(n, c)
|
||||
if self.debug:
|
||||
def debug_mutate(*args):
|
||||
res_nodes, res_connections = self.compiled_function[key](*args)
|
||||
return res_nodes.block_until_ready(), res_connections.block_until_ready()
|
||||
self.compile_update_speciate(N, C, S)
|
||||
return self.compiled_function[key]
|
||||
|
||||
return debug_mutate
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
|
||||
def create_distance_with_args(self):
|
||||
func = partial(
|
||||
distance,
|
||||
disjoint_coe=self.disjoint_coe,
|
||||
compatibility_coe=self.compatibility_coe
|
||||
)
|
||||
self.distance_with_args = func
|
||||
|
||||
def compile_distance(self, n, c):
|
||||
func = self.distance_with_args
|
||||
o2o_nodes1_lower = np.zeros((n, 5))
|
||||
o2o_connections1_lower = np.zeros((c, 4))
|
||||
o2o_nodes2_lower = np.zeros((n, 5))
|
||||
o2o_connections2_lower = np.zeros((c, 4))
|
||||
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
|
||||
o2o_nodes2_lower, o2o_connections2_lower).compile()
|
||||
|
||||
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
|
||||
o2m_connections2_lower = np.zeros((self.pop_size, c, 4))
|
||||
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
|
||||
o2m_nodes2_lower,
|
||||
o2m_connections2_lower).compile()
|
||||
|
||||
self.compiled_function[('o2o_distance', n, c)] = o2o_distance
|
||||
self.compiled_function[('o2m_distance', n, c)] = o2m_distance
|
||||
|
||||
def create_distance(self, n, c):
|
||||
key1, key2 = ('o2o_distance', n, c), ('o2m_distance', n, c)
|
||||
if key1 not in self.compiled_function:
|
||||
self.compile_distance(n, c)
|
||||
if self.debug:
|
||||
|
||||
def debug_o2o_distance(*args):
|
||||
return self.compiled_function[key1](*args).block_until_ready()
|
||||
|
||||
def debug_o2m_distance(*args):
|
||||
return self.compiled_function[key2](*args).block_until_ready()
|
||||
|
||||
return debug_o2o_distance, debug_o2m_distance
|
||||
else:
|
||||
return self.compiled_function[key1], self.compiled_function[key2]
|
||||
|
||||
def create_crossover_with_args(self):
|
||||
self.crossover_with_args = crossover
|
||||
|
||||
def compile_crossover(self, n, c):
|
||||
func = self.crossover_with_args
|
||||
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
|
||||
nodes1_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections1_lower = np.zeros((self.pop_size, c, 4))
|
||||
nodes2_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections2_lower = np.zeros((self.pop_size, c, 4))
|
||||
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
|
||||
nodes2_lower, connections2_lower).compile()
|
||||
self.compiled_function[('crossover', n, c)] = func
|
||||
|
||||
def create_crossover(self, n, c):
|
||||
key = ('crossover', n, c)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_crossover(n, c)
|
||||
if self.debug:
|
||||
|
||||
def debug_crossover(*args):
|
||||
res_nodes, res_connections = self.compiled_function[key](*args)
|
||||
return res_nodes.block_until_ready(), res_connections.block_until_ready()
|
||||
|
||||
return debug_crossover
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
def compile_update_speciate(self, N, C, S):
|
||||
func = self.update_speciate_with_args
|
||||
randkey_lower = np.zeros((2,), dtype=np.uint32)
|
||||
pop_nodes_lower = np.zeros((self.pop_size, N, 5))
|
||||
pop_cons_lower = np.zeros((self.pop_size, C, 4))
|
||||
winner_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
loser_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
elite_mask_lower = np.zeros((self.pop_size,), dtype=bool)
|
||||
new_node_keys_start_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
pre_spe_center_nodes_lower = np.zeros((S, N, 5))
|
||||
pre_spe_center_cons_lower = np.zeros((S, C, 4))
|
||||
species_keys = np.zeros((S,), dtype=np.int32)
|
||||
new_species_keys_lower = 0
|
||||
compiled_func = jit(func).lower(
|
||||
randkey_lower,
|
||||
pop_nodes_lower,
|
||||
pop_cons_lower,
|
||||
winner_part_lower,
|
||||
loser_part_lower,
|
||||
elite_mask_lower,
|
||||
new_node_keys_start_lower,
|
||||
pre_spe_center_nodes_lower,
|
||||
pre_spe_center_cons_lower,
|
||||
species_keys,
|
||||
new_species_keys_lower,
|
||||
).compile()
|
||||
self.compiled_function[("update_speciate", N, C, S)] = compiled_func
|
||||
|
||||
def create_topological_sort_with_args(self):
|
||||
self.topological_sort_with_args = topological_sort
|
||||
|
||||
@@ -11,7 +11,8 @@ from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_
|
||||
from .graph import check_cycles
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('single_structure_mutate',))
|
||||
# TODO: Temporally delete single_structural_mutation, for i need to run it as soon as possible.
|
||||
@jit
|
||||
def mutate(rand_key: Array,
|
||||
nodes: Array,
|
||||
connections: Array,
|
||||
@@ -44,7 +45,7 @@ def mutate(rand_key: Array,
|
||||
delete_node_rate: float = 0.2,
|
||||
add_connection_rate: float = 0.4,
|
||||
delete_connection_rate: float = 0.4,
|
||||
single_structure_mutate: bool = True):
|
||||
):
|
||||
"""
|
||||
:param output_idx:
|
||||
:param input_idx:
|
||||
@@ -78,65 +79,26 @@ def mutate(rand_key: Array,
|
||||
:param delete_node_rate:
|
||||
:param add_connection_rate:
|
||||
:param delete_connection_rate:
|
||||
:param single_structure_mutate: a genome is structurally mutate at most once
|
||||
:return:
|
||||
"""
|
||||
|
||||
# mutate_structure
|
||||
def nothing(rk, n, c):
|
||||
return n, c
|
||||
|
||||
def m_add_node(rk, n, c):
|
||||
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
|
||||
|
||||
def m_delete_node(rk, n, c):
|
||||
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_add_connection(rk, n, c):
|
||||
return mutate_add_connection(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_delete_connection(rk, n, c):
|
||||
return mutate_delete_connection(rk, n, c)
|
||||
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
|
||||
|
||||
mutate_structure_li = [nothing, m_add_node, m_delete_node, m_add_connection, m_delete_connection]
|
||||
# mutate add node
|
||||
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
|
||||
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
|
||||
|
||||
if single_structure_mutate:
|
||||
r1, r2, rand_key = jax.random.split(rand_key, 3)
|
||||
d = jnp.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate)
|
||||
|
||||
# shorten variable names for beauty
|
||||
anr, dnr = add_node_rate / d, delete_node_rate / d
|
||||
acr, dcr = add_connection_rate / d, delete_connection_rate / d
|
||||
|
||||
r = rand(r1)
|
||||
branch = 0
|
||||
branch = jnp.where(r <= anr, 1, branch)
|
||||
branch = jnp.where((anr < r) & (r <= anr + dnr), 2, branch)
|
||||
branch = jnp.where((anr + dnr < r) & (r <= anr + dnr + acr), 3, branch)
|
||||
branch = jnp.where((anr + dnr + acr) < r & r <= (anr + dnr + acr + dcr), 4, branch)
|
||||
nodes, connections = jax.lax.switch(branch, mutate_structure_li, (r2, nodes, connections))
|
||||
else:
|
||||
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
|
||||
|
||||
# mutate add node
|
||||
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
|
||||
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
|
||||
|
||||
# mutate delete node
|
||||
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
|
||||
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
|
||||
|
||||
# mutate add connection
|
||||
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
|
||||
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
|
||||
|
||||
# mutate delete connection
|
||||
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
|
||||
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
|
||||
# mutate add connection
|
||||
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
|
||||
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
|
||||
|
||||
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
|
||||
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
|
||||
@@ -379,9 +341,9 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
|
||||
# randomly choose two nodes
|
||||
k1, k2 = jax.random.split(rand_key, num=2)
|
||||
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
|
||||
allow_input_keys=True, allow_output_keys=True)
|
||||
allow_input_keys=True, allow_output_keys=True)
|
||||
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
|
||||
allow_input_keys=False, allow_output_keys=True)
|
||||
allow_input_keys=False, allow_output_keys=True)
|
||||
|
||||
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
||||
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit, vmap
|
||||
|
||||
from jax import Array
|
||||
|
||||
from .genome import distance
|
||||
from .genome.utils import I_INT, fetch_first, argmin_with_mask
|
||||
|
||||
|
||||
@jit
|
||||
def jitable_speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
|
||||
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
|
||||
):
|
||||
"""
|
||||
args:
|
||||
pop_nodes: (pop_size, N, 5)
|
||||
pop_cons: (pop_size, C, 4)
|
||||
spe_center_nodes: (species_size, N, 5)
|
||||
spe_center_cons: (species_size, C, 4)
|
||||
"""
|
||||
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
|
||||
|
||||
# prepare distance functions
|
||||
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
|
||||
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||
s2p_distance_func = vmap(
|
||||
o2p_distance_func, in_axes=(0, 0, None, None)
|
||||
)
|
||||
|
||||
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
|
||||
|
||||
# the distance between each species' center and each genome in population
|
||||
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
|
||||
|
||||
def continue_execute_while(carry):
|
||||
i, i2s, scn, scc = carry
|
||||
not_all_assigned = ~jnp.all(i2s != I_INT)
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_all_assigned & not_reach_species_upper_bounds
|
||||
|
||||
def deal_with_each_center_genome(carry):
|
||||
i, i2s, scn, scc = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
|
||||
|
||||
i2s, scn, scc = jax.lax.cond(
|
||||
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
|
||||
create_new_specie, # if not valid, create a new specie
|
||||
update_exist_specie, # if valid, update the specie
|
||||
(i, i2s, scn, scc)
|
||||
)
|
||||
|
||||
return i + 1, i2s, scn, scc
|
||||
|
||||
def create_new_specie(carry):
|
||||
i, i2s, scn, scc = carry
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(i2s == I_INT)
|
||||
|
||||
# assign it to new specie
|
||||
i2s = i2s.at[idx].set(i)
|
||||
|
||||
# update center genomes
|
||||
scn = scn.at[i].set(pop_nodes[idx])
|
||||
scc = scc.at[i].set(pop_cons[idx])
|
||||
|
||||
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc))
|
||||
return i2s, scn, scc
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, scn, scc = carry
|
||||
|
||||
# find new center
|
||||
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
|
||||
|
||||
# update new center
|
||||
i2s = i2s.at[idx].set(i)
|
||||
|
||||
# update center genomes
|
||||
scn = scn.at[i].set(pop_nodes[idx])
|
||||
scc = scc.at[i].set(pop_cons[idx])
|
||||
|
||||
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc))
|
||||
return i2s, scn, scc
|
||||
|
||||
def speciate_by_threshold(carry):
|
||||
i, i2s, scn, scc = carry
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
|
||||
close_enough_mask = o2p_distance < compatibility_threshold
|
||||
|
||||
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
||||
i2s = jnp.where(close_enough_mask & (i2s == I_INT), i, i2s)
|
||||
return i2s, scn, scc
|
||||
|
||||
# update idx2specie
|
||||
_, idx2specie, spe_center_nodes, spe_center_cons = jax.lax.while_loop(
|
||||
continue_execute_while,
|
||||
deal_with_each_center_genome,
|
||||
(0, idx2specie, spe_center_nodes, spe_center_cons)
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||
idx2specie = jnp.where(idx2specie == I_INT, species_size - 1, idx2specie)
|
||||
|
||||
return idx2specie, spe_center_nodes, spe_center_cons
|
||||
@@ -7,8 +7,9 @@ import numpy as np
|
||||
from .species import SpeciesController
|
||||
from .genome import expand, expand_single
|
||||
from .function_factory import FunctionFactory
|
||||
from .genome.genome import count
|
||||
from .genome.debug.tools import check_array_valid
|
||||
|
||||
from .population import *
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
@@ -17,7 +18,7 @@ class Pipeline:
|
||||
|
||||
def __init__(self, config, seed=42):
|
||||
self.time_dict = {}
|
||||
self.function_factory = FunctionFactory(config, debug=True)
|
||||
self.function_factory = FunctionFactory(config)
|
||||
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
@@ -25,17 +26,18 @@ class Pipeline:
|
||||
self.config = config
|
||||
self.N = config.basic.init_maximum_nodes
|
||||
self.C = config.basic.init_maximum_connections
|
||||
self.S = config.basic.init_maximum_species
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
|
||||
self.species_controller = SpeciesController(config)
|
||||
self.initialize_func = self.function_factory.create_initialize()
|
||||
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
||||
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func()
|
||||
|
||||
self.compile_functions(debug=True)
|
||||
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
|
||||
|
||||
self.generation = 0
|
||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_connections)
|
||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
self.best_genome = None
|
||||
@@ -47,22 +49,26 @@ class Pipeline:
|
||||
:return:
|
||||
Algorithm gives the population a forward function, then environment gives back the fitnesses.
|
||||
"""
|
||||
return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_connections)
|
||||
return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_cons)
|
||||
|
||||
def tell(self, fitnesses):
|
||||
|
||||
self.generation += 1
|
||||
|
||||
self.species_controller.update_species_fitnesses(fitnesses)
|
||||
winner_part, loser_part, elite_mask, pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start = self.species_controller.ask(
|
||||
fitnesses,
|
||||
self.generation,
|
||||
self.S, self.N, self.C)
|
||||
|
||||
winner_part, loser_part, elite_mask = self.species_controller.reproduce(fitnesses, self.generation)
|
||||
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
|
||||
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = self.create_and_speciate(
|
||||
self.randkey, self.pop_nodes, self.pop_cons, winner_part, loser_part, elite_mask,
|
||||
new_node_keys,
|
||||
pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start)
|
||||
|
||||
self.update_next_generation(winner_part, loser_part, elite_mask)
|
||||
idx2specie, new_center_nodes, new_center_cons, new_species_keys = jax.device_get([idx2specie, new_center_nodes, new_center_cons, new_species_keys])
|
||||
|
||||
# pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
|
||||
|
||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
|
||||
self.o2o_distance, self.o2m_distance)
|
||||
self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation)
|
||||
|
||||
self.expand()
|
||||
|
||||
@@ -86,49 +92,6 @@ class Pipeline:
|
||||
print("Generation limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
def update_next_generation(self, winner_part, loser_part, elite_mask) -> None:
|
||||
"""
|
||||
create next generation
|
||||
:param winner_part:
|
||||
:param loser_part:
|
||||
:param elite_mask:
|
||||
:return:
|
||||
"""
|
||||
|
||||
assert self.pop_nodes.shape[0] == self.pop_size
|
||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
||||
|
||||
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||
|
||||
# batch crossover
|
||||
wpn = self.pop_nodes[winner_part] # winner pop nodes
|
||||
wpc = self.pop_connections[winner_part] # winner pop connections
|
||||
lpn = self.pop_nodes[loser_part] # loser pop nodes
|
||||
lpc = self.pop_connections[loser_part] # loser pop connections
|
||||
|
||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
||||
lpc) # new pop nodes, new pop connections
|
||||
|
||||
# for i in range(self.pop_size):
|
||||
# n, c = np.array(npn[i]), np.array(npc[i])
|
||||
# check_array_valid(n, c, self.input_idx, self.output_idx)
|
||||
|
||||
# mutate
|
||||
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
|
||||
|
||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||
|
||||
# for i in range(self.pop_size):
|
||||
# n, c = np.array(m_npn[i]), np.array(m_npc[i])
|
||||
# check_array_valid(n, c, self.input_idx, self.output_idx)
|
||||
|
||||
# elitism don't mutate
|
||||
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
||||
|
||||
self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn)
|
||||
self.pop_connections = np.where(elite_mask[:, None, None], npc, m_npc)
|
||||
|
||||
def expand(self):
|
||||
"""
|
||||
Expand the population if needed.
|
||||
@@ -142,37 +105,28 @@ class Pipeline:
|
||||
if max_node_size >= self.N:
|
||||
self.N = int(self.N * self.expand_coe)
|
||||
print(f"node expand to {self.N}!")
|
||||
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C)
|
||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
|
||||
|
||||
# don't forget to expand representation genome in species
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.N, self.C)
|
||||
|
||||
# update functions
|
||||
self.compile_functions(debug=True)
|
||||
|
||||
|
||||
pop_con_keys = self.pop_connections[:, :, 0]
|
||||
pop_con_keys = self.pop_cons[:, :, 0]
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||
max_con_size = np.max(pop_node_sizes)
|
||||
if max_con_size >= self.C:
|
||||
self.C = int(self.C * self.expand_coe)
|
||||
print(f"connections expand to {self.C}!")
|
||||
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C)
|
||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
|
||||
|
||||
# don't forget to expand representation genome in species
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.N, self.C)
|
||||
|
||||
# update functions
|
||||
self.compile_functions(debug=True)
|
||||
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
|
||||
|
||||
|
||||
|
||||
def compile_functions(self, debug=False):
|
||||
self.mutate_func = self.function_factory.create_mutate(self.N, self.C)
|
||||
self.crossover_func = self.function_factory.create_crossover(self.N, self.C)
|
||||
self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N, self.C)
|
||||
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
@@ -185,7 +139,7 @@ class Pipeline:
|
||||
max_idx = np.argmax(fitnesses)
|
||||
if fitnesses[max_idx] > self.best_fitness:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_connections[max_idx])
|
||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||
|
||||
print(f"Generation: {self.generation}",
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||
|
||||
168
algorithms/neat/population.py
Normal file
168
algorithms/neat/population.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit, vmap
|
||||
|
||||
from jax import Array
|
||||
|
||||
from .genome import distance, mutate, crossover
|
||||
from .genome.utils import I_INT, fetch_first, argmin_with_mask
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
|
||||
new_node_keys,
|
||||
pre_spe_center_nodes, pre_spe_center_cons, species_keys, new_species_key_start,
|
||||
species_kwargs, mutate_kwargs):
|
||||
# create next generation
|
||||
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
|
||||
new_node_keys, **mutate_kwargs)
|
||||
|
||||
# speciate
|
||||
idx2specie, spe_center_nodes, spe_center_cons, species_keys = speciate(pop_nodes, pop_cons, pre_spe_center_nodes,
|
||||
pre_spe_center_cons, species_keys,
|
||||
new_species_key_start, **species_kwargs)
|
||||
|
||||
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
|
||||
|
||||
|
||||
@jit
|
||||
def speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
|
||||
species_keys, new_species_key_start,
|
||||
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
|
||||
):
|
||||
"""
|
||||
args:
|
||||
pop_nodes: (pop_size, N, 5)
|
||||
pop_cons: (pop_size, C, 4)
|
||||
spe_center_nodes: (species_size, N, 5)
|
||||
spe_center_cons: (species_size, C, 4)
|
||||
"""
|
||||
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
|
||||
|
||||
# prepare distance functions
|
||||
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
|
||||
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||
s2p_distance_func = vmap(
|
||||
o2p_distance_func, in_axes=(0, 0, None, None)
|
||||
)
|
||||
|
||||
# idx to specie key
|
||||
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
|
||||
|
||||
# part 1: find new centers
|
||||
# the distance between each species' center and each genome in population
|
||||
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
|
||||
|
||||
def find_new_centers(i, carry):
|
||||
i2s, scn, scc = carry
|
||||
# find new center
|
||||
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
|
||||
|
||||
# check species[i] exist or not
|
||||
# if not exist, set idx and i to I_INT, jax will not do array value assignment
|
||||
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
|
||||
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
|
||||
|
||||
i2s = i2s.at[idx].set(species_keys[i])
|
||||
scn = scn.at[i].set(pop_nodes[idx])
|
||||
scc = scc.at[i].set(pop_cons[idx])
|
||||
return i2s, scn, scc
|
||||
|
||||
idx2specie, spe_center_nodes, spe_center_cons = jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, spe_center_nodes, spe_center_cons))
|
||||
|
||||
def continue_execute_while(carry):
|
||||
i, i2s, scn, scc, sk, ck = carry # sk is short for species_keys, ck is short for current key
|
||||
not_all_assigned = ~jnp.all(i2s != I_INT)
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_all_assigned & not_reach_species_upper_bounds
|
||||
|
||||
def deal_with_each_center_genome(carry):
|
||||
i, i2s, scn, scc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
|
||||
|
||||
i2s, scn, scc, sk, ck = jax.lax.cond(
|
||||
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
|
||||
create_new_specie, # if not valid, create a new specie
|
||||
update_exist_specie, # if valid, update the specie
|
||||
(i, i2s, scn, scc, sk, ck)
|
||||
)
|
||||
|
||||
return i + 1, i2s, scn, scc, sk, ck
|
||||
|
||||
def create_new_specie(carry):
|
||||
i, i2s, scn, scc, sk, ck = carry
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(i2s == I_INT)
|
||||
|
||||
# assign it to new specie
|
||||
sk = sk.at[i].set(ck)
|
||||
i2s = i2s.at[idx].set(ck)
|
||||
|
||||
# update center genomes
|
||||
scn = scn.at[i].set(pop_nodes[idx])
|
||||
scc = scc.at[i].set(pop_cons[idx])
|
||||
|
||||
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
|
||||
return i2s, scn, scc, sk, ck + 1 # change to next new speciate key
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, scn, scc, sk, ck = carry
|
||||
|
||||
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
|
||||
return i2s, scn, scc, sk, ck
|
||||
|
||||
def speciate_by_threshold(carry):
|
||||
i, i2s, scn, scc, sk = carry
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
|
||||
close_enough_mask = o2p_distance < compatibility_threshold
|
||||
|
||||
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
||||
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
|
||||
return i2s, scn, scc
|
||||
|
||||
current_new_key = new_species_key_start
|
||||
|
||||
# update idx2specie
|
||||
_, idx2specie, spe_center_nodes, spe_center_cons, species_keys, new_species_key_start = jax.lax.while_loop(
|
||||
continue_execute_while,
|
||||
deal_with_each_center_genome,
|
||||
(0, idx2specie, spe_center_nodes, spe_center_cons, species_keys, current_new_key)
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
|
||||
|
||||
return idx2specie, spe_center_nodes, spe_center_cons, species_keys
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, new_node_keys,
|
||||
**mutate_kwargs):
|
||||
# prepare functions
|
||||
batch_crossover = vmap(crossover)
|
||||
mutate_with_args = vmap(partial(mutate, **mutate_kwargs))
|
||||
|
||||
pop_size = pop_nodes.shape[0]
|
||||
k1, k2 = jax.random.split(rand_key, 2)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||
|
||||
# batch crossover
|
||||
wpn = pop_nodes[winner_part] # winner pop nodes
|
||||
wpc = pop_cons[winner_part] # winner pop connections
|
||||
lpn = pop_nodes[loser_part] # loser pop nodes
|
||||
lpc = pop_cons[loser_part] # loser pop connections
|
||||
|
||||
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
|
||||
m_npn, m_npc = mutate_with_args(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
||||
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
||||
|
||||
return pop_nodes, pop_cons
|
||||
@@ -5,6 +5,8 @@ import jax
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .genome.utils import I_INT
|
||||
|
||||
|
||||
class Species(object):
|
||||
|
||||
@@ -12,7 +14,7 @@ class Species(object):
|
||||
self.key = key
|
||||
self.created = generation
|
||||
self.last_improved = generation
|
||||
self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections)
|
||||
self.representative: Tuple[NDArray, NDArray] = (None, None) # (center_nodes, center_connections)
|
||||
self.members: NDArray = None # idx in pop_nodes, pop_connections,
|
||||
self.fitness = None
|
||||
self.member_fitnesses = None
|
||||
@@ -34,7 +36,7 @@ class SpeciesController:
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
|
||||
|
||||
self.species_elitism = self.config.neat.species.species_elitism
|
||||
self.pop_size = self.config.neat.population.pop_size
|
||||
self.max_stagnation = self.config.neat.species.max_stagnation
|
||||
@@ -59,97 +61,7 @@ class SpeciesController:
|
||||
s.update((pop_nodes[0], pop_connections[0]), members)
|
||||
self.species[species_id] = s
|
||||
|
||||
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int,
|
||||
o2o_distance: Callable, o2m_distance: Callable) -> None:
|
||||
"""
|
||||
:param pop_nodes:
|
||||
:param pop_connections:
|
||||
:param generation: use to flag the created time of new species
|
||||
:param o2o_distance: distance function for one-to-one comparison
|
||||
:param o2m_distance: distance function for one-to-many comparison
|
||||
:return:
|
||||
"""
|
||||
unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool)
|
||||
previous_species_list = list(self.species.keys())
|
||||
|
||||
# Find the best representatives for each existing species.
|
||||
new_representatives = {}
|
||||
new_members = {}
|
||||
|
||||
total_distances = jax.device_get([
|
||||
o2m_distance(*self.species[sid].representative, pop_nodes, pop_connections)
|
||||
for sid in previous_species_list
|
||||
])
|
||||
|
||||
# TODO: Use jit to wrapper function find_min_with_mask to accelerate this part
|
||||
for i, sid in enumerate(previous_species_list):
|
||||
distances = total_distances[i]
|
||||
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
|
||||
|
||||
new_representatives[sid] = min_idx
|
||||
new_members[sid] = [min_idx]
|
||||
unspeciated[min_idx] = False
|
||||
|
||||
# Partition population into species based on genetic similarity.
|
||||
|
||||
# First, fast match the population to previous species
|
||||
if previous_species_list: # exist previous species
|
||||
rid_list = [new_representatives[sid] for sid in previous_species_list]
|
||||
res_pop_distance = jax.device_get([
|
||||
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
||||
for rid in rid_list
|
||||
])
|
||||
|
||||
pop_res_distance = np.stack(res_pop_distance, axis=0).T
|
||||
for i in range(pop_res_distance.shape[0]):
|
||||
if not unspeciated[i]:
|
||||
continue
|
||||
min_idx = np.argmin(pop_res_distance[i])
|
||||
min_val = pop_res_distance[i, min_idx]
|
||||
if min_val <= self.compatibility_threshold:
|
||||
species_id = previous_species_list[min_idx]
|
||||
new_members[species_id].append(i)
|
||||
unspeciated[i] = False
|
||||
|
||||
# Second, slowly match the lonely population to new-created species.s
|
||||
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
|
||||
# the new representatives.
|
||||
for i in range(pop_nodes.shape[0]):
|
||||
if not unspeciated[i]:
|
||||
continue
|
||||
unspeciated[i] = False
|
||||
if len(new_representatives) != 0:
|
||||
# the representatives of new species
|
||||
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
||||
distances = jax.device_get([
|
||||
o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
for r in rid
|
||||
])
|
||||
distances = np.array(distances)
|
||||
min_idx = np.argmin(distances)
|
||||
min_val = distances[min_idx]
|
||||
if min_val <= self.compatibility_threshold:
|
||||
species_id = sid[min_idx]
|
||||
new_members[species_id].append(i)
|
||||
continue
|
||||
# create a new species
|
||||
species_id = next(self.species_idxer)
|
||||
new_representatives[species_id] = i
|
||||
new_members[species_id] = [i]
|
||||
|
||||
assert np.all(~unspeciated)
|
||||
|
||||
# Update species collection based on new speciation.
|
||||
for sid, rid in new_representatives.items():
|
||||
s = self.species.get(sid)
|
||||
if s is None:
|
||||
s = Species(sid, generation)
|
||||
self.species[sid] = s
|
||||
|
||||
members = np.array(new_members[sid])
|
||||
s.update((pop_nodes[rid], pop_connections[rid]), members)
|
||||
|
||||
def update_species_fitnesses(self, fitnesses):
|
||||
def __update_species_fitnesses(self, fitnesses):
|
||||
"""
|
||||
update the fitness of each species
|
||||
:param fitnesses:
|
||||
@@ -163,7 +75,7 @@ class SpeciesController:
|
||||
s.fitness_history.append(s.fitness)
|
||||
s.adjusted_fitness = None
|
||||
|
||||
def stagnation(self, generation):
|
||||
def __stagnation(self, generation):
|
||||
"""
|
||||
code modified from neat-python!
|
||||
:param generation:
|
||||
@@ -196,7 +108,7 @@ class SpeciesController:
|
||||
result.append((sid, s, is_stagnant))
|
||||
return result
|
||||
|
||||
def reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
|
||||
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
|
||||
"""
|
||||
code modified from neat-python!
|
||||
:param fitnesses:
|
||||
@@ -215,7 +127,7 @@ class SpeciesController:
|
||||
max_fitness = -np.inf
|
||||
|
||||
remaining_species = []
|
||||
for stag_sid, stag_s, stagnant in self.stagnation(generation):
|
||||
for stag_sid, stag_s, stagnant in self.__stagnation(generation):
|
||||
if not stagnant:
|
||||
min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses))
|
||||
max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses))
|
||||
@@ -285,6 +197,33 @@ class SpeciesController:
|
||||
|
||||
return winner_part, loser_part, np.array(elite_mask)
|
||||
|
||||
def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation):
|
||||
for idx, key in enumerate(species_keys):
|
||||
if key == I_INT:
|
||||
continue
|
||||
|
||||
members = np.where(idx2specie == key)[0]
|
||||
assert len(members) > 0
|
||||
if key not in self.species:
|
||||
s = Species(key, generation)
|
||||
self.species[key] = s
|
||||
|
||||
self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members)
|
||||
|
||||
def ask(self, fitnesses, generation, S, N, C):
|
||||
self.__update_species_fitnesses(fitnesses)
|
||||
winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation)
|
||||
pre_spe_center_nodes = np.full((S, N, 5), np.nan)
|
||||
pre_spe_center_cons = np.full((S, C, 4), np.nan)
|
||||
species_keys = np.full((S,), I_INT)
|
||||
for idx, (key, specie) in enumerate(self.species.items()):
|
||||
pre_spe_center_nodes[idx] = specie.representative[0]
|
||||
pre_spe_center_cons[idx] = specie.representative[1]
|
||||
species_keys[idx] = key
|
||||
next_new_specie_key = max(self.species.keys()) + 1
|
||||
return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \
|
||||
pre_spe_center_cons, species_keys, next_new_specie_key
|
||||
|
||||
|
||||
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
||||
"""
|
||||
@@ -326,13 +265,7 @@ def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
||||
return spawn_amounts
|
||||
|
||||
|
||||
def find_min_with_mask(arr: NDArray, mask: NDArray) -> int:
|
||||
masked_arr = np.where(mask, arr, np.inf)
|
||||
min_idx = np.argmin(masked_arr)
|
||||
return min_idx
|
||||
|
||||
|
||||
def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \
|
||||
-> Tuple[NDArray, NDArray]:
|
||||
sorted_idx = np.argsort(fitnesses)[::-1]
|
||||
return members[sorted_idx], fitnesses[sorted_idx]
|
||||
return members[sorted_idx], fitnesses[sorted_idx]
|
||||
|
||||
Reference in New Issue
Block a user