This commit is contained in:
wls2002
2023-05-13 20:58:03 +08:00
parent 90a9cc322d
commit 72c9d4167a
10 changed files with 372 additions and 529 deletions

View File

@@ -2,30 +2,34 @@
Lowers, compiles, and creates functions used in the NEAT pipeline.
"""
from functools import partial
import time
import jax.random
import numpy as np
from jax import jit, vmap
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
from .genome import act_name2key, agg_name2key, initialize_genomes
from .genome import topological_sort, forward_single, unflatten_connections
from .population import create_next_generation_then_speciate
class FunctionFactory:
def __init__(self, config, debug=False):
def __init__(self, config):
self.config = config
self.debug = debug
self.init_N = config.basic.init_maximum_nodes
self.init_C = config.basic.init_maximum_connections
self.expand_coe = config.basic.expands_coe
self.precompile_times = config.basic.pre_compile_times
self.compiled_function = {}
self.time_cost = {}
self.load_config_vals(config)
self.precompile()
self.create_topological_sort_with_args()
self.create_single_forward_with_args()
self.create_update_speciate_with_args()
def load_config_vals(self, config):
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
self.problem_batch = config.basic.problem_batch
self.pop_size = config.neat.population.pop_size
@@ -79,12 +83,12 @@ class FunctionFactory:
self.delete_connection_rate = genome.conn_delete_prob
self.single_structure_mutate = genome.single_structural_mutation
def create_initialize(self):
def create_initialize(self, N, C):
func = partial(
initialize_genomes,
pop_size=self.pop_size,
N=self.init_N,
C=self.init_C,
N=N,
C=C,
num_inputs=self.num_inputs,
num_outputs=self.num_outputs,
default_bias=self.bias_mean,
@@ -93,166 +97,85 @@ class FunctionFactory:
default_agg=self.agg_default,
default_weight=self.weight_mean
)
if self.debug:
def debug_initialize(*args):
return func(*args)
return func
return debug_initialize
else:
return func
def create_update_speciate_with_args(self):
species_kwargs = {
"disjoint_coe": self.disjoint_coe,
"compatibility_coe": self.compatibility_coe,
"compatibility_threshold": self.compatibility_threshold
}
def precompile(self):
self.create_mutate_with_args()
self.create_distance_with_args()
self.create_crossover_with_args()
self.create_topological_sort_with_args()
self.create_single_forward_with_args()
#
# n, c = self.init_N, self.init_C
# print("start precompile")
# for _ in range(self.precompile_times):
# self.compile_mutate(n)
# self.compile_distance(n)
# self.compile_crossover(n)
# self.compile_topological_sort_batch(n)
# self.compile_pop_batch_forward(n)
# n = int(self.expand_coe * n)
#
# # precompile other functions used in jax
# key = jax.random.PRNGKey(0)
# _ = jax.random.split(key, 3)
# _ = jax.random.split(key, self.pop_size * 2)
# _ = jax.random.split(key, self.pop_size)
#
# print("end precompile")
mutate_kwargs = {
"input_idx": self.input_idx,
"output_idx": self.output_idx,
"bias_mean": self.bias_mean,
"bias_std": self.bias_std,
"bias_mutate_strength": self.bias_mutate_strength,
"bias_mutate_rate": self.bias_mutate_rate,
"bias_replace_rate": self.bias_replace_rate,
"response_mean": self.response_mean,
"response_std": self.response_std,
"response_mutate_strength": self.response_mutate_strength,
"response_mutate_rate": self.response_mutate_rate,
"response_replace_rate": self.response_replace_rate,
"weight_mean": self.weight_mean,
"weight_std": self.weight_std,
"weight_mutate_strength": self.weight_mutate_strength,
"weight_mutate_rate": self.weight_mutate_rate,
"weight_replace_rate": self.weight_replace_rate,
"act_default": self.act_default,
"act_list": self.act_list,
"act_replace_rate": self.act_replace_rate,
"agg_default": self.agg_default,
"agg_list": self.agg_list,
"agg_replace_rate": self.agg_replace_rate,
"enabled_reverse_rate": self.enabled_reverse_rate,
"add_node_rate": self.add_node_rate,
"delete_node_rate": self.delete_node_rate,
"add_connection_rate": self.add_connection_rate,
"delete_connection_rate": self.delete_connection_rate,
}
def create_mutate_with_args(self):
func = partial(
mutate,
input_idx=self.input_idx,
output_idx=self.output_idx,
bias_mean=self.bias_mean,
bias_std=self.bias_std,
bias_mutate_strength=self.bias_mutate_strength,
bias_mutate_rate=self.bias_mutate_rate,
bias_replace_rate=self.bias_replace_rate,
response_mean=self.response_mean,
response_std=self.response_std,
response_mutate_strength=self.response_mutate_strength,
response_mutate_rate=self.response_mutate_rate,
response_replace_rate=self.response_replace_rate,
weight_mean=self.weight_mean,
weight_std=self.weight_std,
weight_mutate_strength=self.weight_mutate_strength,
weight_mutate_rate=self.weight_mutate_rate,
weight_replace_rate=self.weight_replace_rate,
act_default=self.act_default,
act_list=self.act_list,
act_replace_rate=self.act_replace_rate,
agg_default=self.agg_default,
agg_list=self.agg_list,
agg_replace_rate=self.agg_replace_rate,
enabled_reverse_rate=self.enabled_reverse_rate,
add_node_rate=self.add_node_rate,
delete_node_rate=self.delete_node_rate,
add_connection_rate=self.add_connection_rate,
delete_connection_rate=self.delete_connection_rate,
single_structure_mutate=self.single_structure_mutate
self.update_speciate_with_args = partial(
create_next_generation_then_speciate,
species_kwargs=species_kwargs,
mutate_kwargs=mutate_kwargs
)
self.mutate_with_args = func
def compile_mutate(self, n, c):
func = self.mutate_with_args
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, c, 4))
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
self.compiled_function[('mutate', n, c)] = batched_mutate_func
def create_mutate(self, n, c):
key = ('mutate', n, c)
def create_update_speciate(self, N, C, S):
key = ("update_speciate", N, C, S)
if key not in self.compiled_function:
self.compile_mutate(n, c)
if self.debug:
def debug_mutate(*args):
res_nodes, res_connections = self.compiled_function[key](*args)
return res_nodes.block_until_ready(), res_connections.block_until_ready()
self.compile_update_speciate(N, C, S)
return self.compiled_function[key]
return debug_mutate
else:
return self.compiled_function[key]
def create_distance_with_args(self):
func = partial(
distance,
disjoint_coe=self.disjoint_coe,
compatibility_coe=self.compatibility_coe
)
self.distance_with_args = func
def compile_distance(self, n, c):
func = self.distance_with_args
o2o_nodes1_lower = np.zeros((n, 5))
o2o_connections1_lower = np.zeros((c, 4))
o2o_nodes2_lower = np.zeros((n, 5))
o2o_connections2_lower = np.zeros((c, 4))
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2o_nodes2_lower, o2o_connections2_lower).compile()
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
o2m_connections2_lower = np.zeros((self.pop_size, c, 4))
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2m_nodes2_lower,
o2m_connections2_lower).compile()
self.compiled_function[('o2o_distance', n, c)] = o2o_distance
self.compiled_function[('o2m_distance', n, c)] = o2m_distance
def create_distance(self, n, c):
key1, key2 = ('o2o_distance', n, c), ('o2m_distance', n, c)
if key1 not in self.compiled_function:
self.compile_distance(n, c)
if self.debug:
def debug_o2o_distance(*args):
return self.compiled_function[key1](*args).block_until_ready()
def debug_o2m_distance(*args):
return self.compiled_function[key2](*args).block_until_ready()
return debug_o2o_distance, debug_o2m_distance
else:
return self.compiled_function[key1], self.compiled_function[key2]
def create_crossover_with_args(self):
self.crossover_with_args = crossover
def compile_crossover(self, n, c):
func = self.crossover_with_args
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes1_lower = np.zeros((self.pop_size, n, 5))
connections1_lower = np.zeros((self.pop_size, c, 4))
nodes2_lower = np.zeros((self.pop_size, n, 5))
connections2_lower = np.zeros((self.pop_size, c, 4))
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
self.compiled_function[('crossover', n, c)] = func
def create_crossover(self, n, c):
key = ('crossover', n, c)
if key not in self.compiled_function:
self.compile_crossover(n, c)
if self.debug:
def debug_crossover(*args):
res_nodes, res_connections = self.compiled_function[key](*args)
return res_nodes.block_until_ready(), res_connections.block_until_ready()
return debug_crossover
else:
return self.compiled_function[key]
def compile_update_speciate(self, N, C, S):
func = self.update_speciate_with_args
randkey_lower = np.zeros((2,), dtype=np.uint32)
pop_nodes_lower = np.zeros((self.pop_size, N, 5))
pop_cons_lower = np.zeros((self.pop_size, C, 4))
winner_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
loser_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
elite_mask_lower = np.zeros((self.pop_size,), dtype=bool)
new_node_keys_start_lower = np.zeros((self.pop_size,), dtype=np.int32)
pre_spe_center_nodes_lower = np.zeros((S, N, 5))
pre_spe_center_cons_lower = np.zeros((S, C, 4))
species_keys = np.zeros((S,), dtype=np.int32)
new_species_keys_lower = 0
compiled_func = jit(func).lower(
randkey_lower,
pop_nodes_lower,
pop_cons_lower,
winner_part_lower,
loser_part_lower,
elite_mask_lower,
new_node_keys_start_lower,
pre_spe_center_nodes_lower,
pre_spe_center_cons_lower,
species_keys,
new_species_keys_lower,
).compile()
self.compiled_function[("update_speciate", N, C, S)] = compiled_func
def create_topological_sort_with_args(self):
self.topological_sort_with_args = topological_sort

View File

@@ -11,7 +11,8 @@ from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_
from .graph import check_cycles
@partial(jit, static_argnames=('single_structure_mutate',))
# TODO: Temporally delete single_structural_mutation, for i need to run it as soon as possible.
@jit
def mutate(rand_key: Array,
nodes: Array,
connections: Array,
@@ -44,7 +45,7 @@ def mutate(rand_key: Array,
delete_node_rate: float = 0.2,
add_connection_rate: float = 0.4,
delete_connection_rate: float = 0.4,
single_structure_mutate: bool = True):
):
"""
:param output_idx:
:param input_idx:
@@ -78,65 +79,26 @@ def mutate(rand_key: Array,
:param delete_node_rate:
:param add_connection_rate:
:param delete_connection_rate:
:param single_structure_mutate: a genome is structurally mutate at most once
:return:
"""
# mutate_structure
def nothing(rk, n, c):
return n, c
def m_add_node(rk, n, c):
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_idx, output_idx)
def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_idx, output_idx)
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
mutate_structure_li = [nothing, m_add_node, m_delete_node, m_add_connection, m_delete_connection]
# mutate add node
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
if single_structure_mutate:
r1, r2, rand_key = jax.random.split(rand_key, 3)
d = jnp.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate)
# shorten variable names for beauty
anr, dnr = add_node_rate / d, delete_node_rate / d
acr, dcr = add_connection_rate / d, delete_connection_rate / d
r = rand(r1)
branch = 0
branch = jnp.where(r <= anr, 1, branch)
branch = jnp.where((anr < r) & (r <= anr + dnr), 2, branch)
branch = jnp.where((anr + dnr < r) & (r <= anr + dnr + acr), 3, branch)
branch = jnp.where((anr + dnr + acr) < r & r <= (anr + dnr + acr + dcr), 4, branch)
nodes, connections = jax.lax.switch(branch, mutate_structure_li, (r2, nodes, connections))
else:
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# mutate add node
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
# mutate delete node
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
# mutate add connection
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
# mutate delete connection
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
# mutate add connection
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
@@ -379,9 +341,9 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
# randomly choose two nodes
k1, k2 = jax.random.split(rand_key, num=2)
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
allow_input_keys=True, allow_output_keys=True)
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=True)
allow_input_keys=False, allow_output_keys=True)
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))

View File

@@ -1,109 +0,0 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax import Array
from .genome import distance
from .genome.utils import I_INT, fetch_first, argmin_with_mask
@jit
def jitable_speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
):
"""
args:
pop_nodes: (pop_size, N, 5)
pop_cons: (pop_size, C, 4)
spe_center_nodes: (species_size, N, 5)
spe_center_cons: (species_size, C, 4)
"""
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
# prepare distance functions
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
s2p_distance_func = vmap(
o2p_distance_func, in_axes=(0, 0, None, None)
)
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
# the distance between each species' center and each genome in population
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
def continue_execute_while(carry):
i, i2s, scn, scc = carry
not_all_assigned = ~jnp.all(i2s != I_INT)
not_reach_species_upper_bounds = i < species_size
return not_all_assigned & not_reach_species_upper_bounds
def deal_with_each_center_genome(carry):
i, i2s, scn, scc = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
i2s, scn, scc = jax.lax.cond(
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
create_new_specie, # if not valid, create a new specie
update_exist_specie, # if valid, update the specie
(i, i2s, scn, scc)
)
return i + 1, i2s, scn, scc
def create_new_specie(carry):
i, i2s, scn, scc = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(i2s == I_INT)
# assign it to new specie
i2s = i2s.at[idx].set(i)
# update center genomes
scn = scn.at[i].set(pop_nodes[idx])
scc = scc.at[i].set(pop_cons[idx])
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc))
return i2s, scn, scc
def update_exist_specie(carry):
i, i2s, scn, scc = carry
# find new center
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
# update new center
i2s = i2s.at[idx].set(i)
# update center genomes
scn = scn.at[i].set(pop_nodes[idx])
scc = scc.at[i].set(pop_cons[idx])
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc))
return i2s, scn, scc
def speciate_by_threshold(carry):
i, i2s, scn, scc = carry
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
close_enough_mask = o2p_distance < compatibility_threshold
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
i2s = jnp.where(close_enough_mask & (i2s == I_INT), i, i2s)
return i2s, scn, scc
# update idx2specie
_, idx2specie, spe_center_nodes, spe_center_cons = jax.lax.while_loop(
continue_execute_while,
deal_with_each_center_genome,
(0, idx2specie, spe_center_nodes, spe_center_cons)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition seems to be only happened when the number of species is reached species upper bounds
idx2specie = jnp.where(idx2specie == I_INT, species_size - 1, idx2specie)
return idx2specie, spe_center_nodes, spe_center_cons

View File

@@ -7,8 +7,9 @@ import numpy as np
from .species import SpeciesController
from .genome import expand, expand_single
from .function_factory import FunctionFactory
from .genome.genome import count
from .genome.debug.tools import check_array_valid
from .population import *
class Pipeline:
"""
@@ -17,7 +18,7 @@ class Pipeline:
def __init__(self, config, seed=42):
self.time_dict = {}
self.function_factory = FunctionFactory(config, debug=True)
self.function_factory = FunctionFactory(config)
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
@@ -25,17 +26,18 @@ class Pipeline:
self.config = config
self.N = config.basic.init_maximum_nodes
self.C = config.basic.init_maximum_connections
self.S = config.basic.init_maximum_species
self.expand_coe = config.basic.expands_coe
self.pop_size = config.neat.population.pop_size
self.species_controller = SpeciesController(config)
self.initialize_func = self.function_factory.create_initialize()
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func()
self.compile_functions(debug=True)
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
self.generation = 0
self.species_controller.init_speciate(self.pop_nodes, self.pop_connections)
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
self.best_fitness = float('-inf')
self.best_genome = None
@@ -47,22 +49,26 @@ class Pipeline:
:return:
Algorithm gives the population a forward function, then environment gives back the fitnesses.
"""
return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_connections)
return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_cons)
def tell(self, fitnesses):
self.generation += 1
self.species_controller.update_species_fitnesses(fitnesses)
winner_part, loser_part, elite_mask, pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start = self.species_controller.ask(
fitnesses,
self.generation,
self.S, self.N, self.C)
winner_part, loser_part, elite_mask = self.species_controller.reproduce(fitnesses, self.generation)
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = self.create_and_speciate(
self.randkey, self.pop_nodes, self.pop_cons, winner_part, loser_part, elite_mask,
new_node_keys,
pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start)
self.update_next_generation(winner_part, loser_part, elite_mask)
idx2specie, new_center_nodes, new_center_cons, new_species_keys = jax.device_get([idx2specie, new_center_nodes, new_center_cons, new_species_keys])
# pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
self.o2o_distance, self.o2m_distance)
self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation)
self.expand()
@@ -86,49 +92,6 @@ class Pipeline:
print("Generation limit reached!")
return self.best_genome
def update_next_generation(self, winner_part, loser_part, elite_mask) -> None:
"""
create next generation
:param winner_part:
:param loser_part:
:param elite_mask:
:return:
"""
assert self.pop_nodes.shape[0] == self.pop_size
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
crossover_rand_keys = jax.random.split(k1, self.pop_size)
mutate_rand_keys = jax.random.split(k2, self.pop_size)
# batch crossover
wpn = self.pop_nodes[winner_part] # winner pop nodes
wpc = self.pop_connections[winner_part] # winner pop connections
lpn = self.pop_nodes[loser_part] # loser pop nodes
lpc = self.pop_connections[loser_part] # loser pop connections
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
lpc) # new pop nodes, new pop connections
# for i in range(self.pop_size):
# n, c = np.array(npn[i]), np.array(npc[i])
# check_array_valid(n, c, self.input_idx, self.output_idx)
# mutate
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# for i in range(self.pop_size):
# n, c = np.array(m_npn[i]), np.array(m_npc[i])
# check_array_valid(n, c, self.input_idx, self.output_idx)
# elitism don't mutate
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn)
self.pop_connections = np.where(elite_mask[:, None, None], npc, m_npc)
def expand(self):
"""
Expand the population if needed.
@@ -142,37 +105,28 @@ class Pipeline:
if max_node_size >= self.N:
self.N = int(self.N * self.expand_coe)
print(f"node expand to {self.N}!")
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C)
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
# don't forget to expand representation genome in species
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N, self.C)
# update functions
self.compile_functions(debug=True)
pop_con_keys = self.pop_connections[:, :, 0]
pop_con_keys = self.pop_cons[:, :, 0]
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
max_con_size = np.max(pop_node_sizes)
if max_con_size >= self.C:
self.C = int(self.C * self.expand_coe)
print(f"connections expand to {self.C}!")
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C)
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
# don't forget to expand representation genome in species
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N, self.C)
# update functions
self.compile_functions(debug=True)
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
def compile_functions(self, debug=False):
self.mutate_func = self.function_factory.create_mutate(self.N, self.C)
self.crossover_func = self.function_factory.create_crossover(self.N, self.C)
self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N, self.C)
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
@@ -185,7 +139,7 @@ class Pipeline:
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = (self.pop_nodes[max_idx], self.pop_connections[max_idx])
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")

View File

@@ -0,0 +1,168 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax import Array
from .genome import distance, mutate, crossover
from .genome.utils import I_INT, fetch_first, argmin_with_mask
@jit
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
new_node_keys,
pre_spe_center_nodes, pre_spe_center_cons, species_keys, new_species_key_start,
species_kwargs, mutate_kwargs):
# create next generation
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
new_node_keys, **mutate_kwargs)
# speciate
idx2specie, spe_center_nodes, spe_center_cons, species_keys = speciate(pop_nodes, pop_cons, pre_spe_center_nodes,
pre_spe_center_cons, species_keys,
new_species_key_start, **species_kwargs)
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
@jit
def speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
species_keys, new_species_key_start,
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
):
"""
args:
pop_nodes: (pop_size, N, 5)
pop_cons: (pop_size, C, 4)
spe_center_nodes: (species_size, N, 5)
spe_center_cons: (species_size, C, 4)
"""
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
# prepare distance functions
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
s2p_distance_func = vmap(
o2p_distance_func, in_axes=(0, 0, None, None)
)
# idx to specie key
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
# part 1: find new centers
# the distance between each species' center and each genome in population
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
def find_new_centers(i, carry):
i2s, scn, scc = carry
# find new center
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
# check species[i] exist or not
# if not exist, set idx and i to I_INT, jax will not do array value assignment
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
i2s = i2s.at[idx].set(species_keys[i])
scn = scn.at[i].set(pop_nodes[idx])
scc = scc.at[i].set(pop_cons[idx])
return i2s, scn, scc
idx2specie, spe_center_nodes, spe_center_cons = jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, spe_center_nodes, spe_center_cons))
def continue_execute_while(carry):
i, i2s, scn, scc, sk, ck = carry # sk is short for species_keys, ck is short for current key
not_all_assigned = ~jnp.all(i2s != I_INT)
not_reach_species_upper_bounds = i < species_size
return not_all_assigned & not_reach_species_upper_bounds
def deal_with_each_center_genome(carry):
i, i2s, scn, scc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
i2s, scn, scc, sk, ck = jax.lax.cond(
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
create_new_specie, # if not valid, create a new specie
update_exist_specie, # if valid, update the specie
(i, i2s, scn, scc, sk, ck)
)
return i + 1, i2s, scn, scc, sk, ck
def create_new_specie(carry):
i, i2s, scn, scc, sk, ck = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(i2s == I_INT)
# assign it to new specie
sk = sk.at[i].set(ck)
i2s = i2s.at[idx].set(ck)
# update center genomes
scn = scn.at[i].set(pop_nodes[idx])
scc = scc.at[i].set(pop_cons[idx])
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
return i2s, scn, scc, sk, ck + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, scn, scc, sk, ck = carry
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
return i2s, scn, scc, sk, ck
def speciate_by_threshold(carry):
i, i2s, scn, scc, sk = carry
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
close_enough_mask = o2p_distance < compatibility_threshold
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
return i2s, scn, scc
current_new_key = new_species_key_start
# update idx2specie
_, idx2specie, spe_center_nodes, spe_center_cons, species_keys, new_species_key_start = jax.lax.while_loop(
continue_execute_while,
deal_with_each_center_genome,
(0, idx2specie, spe_center_nodes, spe_center_cons, species_keys, current_new_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition seems to be only happened when the number of species is reached species upper bounds
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
return idx2specie, spe_center_nodes, spe_center_cons, species_keys
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, new_node_keys,
**mutate_kwargs):
# prepare functions
batch_crossover = vmap(crossover)
mutate_with_args = vmap(partial(mutate, **mutate_kwargs))
pop_size = pop_nodes.shape[0]
k1, k2 = jax.random.split(rand_key, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn = pop_nodes[winner_part] # winner pop nodes
wpc = pop_cons[winner_part] # winner pop connections
lpn = pop_nodes[loser_part] # loser pop nodes
lpc = pop_cons[loser_part] # loser pop connections
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
m_npn, m_npc = mutate_with_args(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
return pop_nodes, pop_cons

View File

@@ -5,6 +5,8 @@ import jax
import numpy as np
from numpy.typing import NDArray
from .genome.utils import I_INT
class Species(object):
@@ -12,7 +14,7 @@ class Species(object):
self.key = key
self.created = generation
self.last_improved = generation
self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections)
self.representative: Tuple[NDArray, NDArray] = (None, None) # (center_nodes, center_connections)
self.members: NDArray = None # idx in pop_nodes, pop_connections,
self.fitness = None
self.member_fitnesses = None
@@ -34,7 +36,7 @@ class SpeciesController:
def __init__(self, config):
self.config = config
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
self.species_elitism = self.config.neat.species.species_elitism
self.pop_size = self.config.neat.population.pop_size
self.max_stagnation = self.config.neat.species.max_stagnation
@@ -59,97 +61,7 @@ class SpeciesController:
s.update((pop_nodes[0], pop_connections[0]), members)
self.species[species_id] = s
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int,
o2o_distance: Callable, o2m_distance: Callable) -> None:
"""
:param pop_nodes:
:param pop_connections:
:param generation: use to flag the created time of new species
:param o2o_distance: distance function for one-to-one comparison
:param o2m_distance: distance function for one-to-many comparison
:return:
"""
unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool)
previous_species_list = list(self.species.keys())
# Find the best representatives for each existing species.
new_representatives = {}
new_members = {}
total_distances = jax.device_get([
o2m_distance(*self.species[sid].representative, pop_nodes, pop_connections)
for sid in previous_species_list
])
# TODO: Use jit to wrapper function find_min_with_mask to accelerate this part
for i, sid in enumerate(previous_species_list):
distances = total_distances[i]
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx
new_members[sid] = [min_idx]
unspeciated[min_idx] = False
# Partition population into species based on genetic similarity.
# First, fast match the population to previous species
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = jax.device_get([
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list
])
pop_res_distance = np.stack(res_pop_distance, axis=0).T
for i in range(pop_res_distance.shape[0]):
if not unspeciated[i]:
continue
min_idx = np.argmin(pop_res_distance[i])
min_val = pop_res_distance[i, min_idx]
if min_val <= self.compatibility_threshold:
species_id = previous_species_list[min_idx]
new_members[species_id].append(i)
unspeciated[i] = False
# Second, slowly match the lonely population to new-created species.s
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
# the new representatives.
for i in range(pop_nodes.shape[0]):
if not unspeciated[i]:
continue
unspeciated[i] = False
if len(new_representatives) != 0:
# the representatives of new species
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = jax.device_get([
o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
])
distances = np.array(distances)
min_idx = np.argmin(distances)
min_val = distances[min_idx]
if min_val <= self.compatibility_threshold:
species_id = sid[min_idx]
new_members[species_id].append(i)
continue
# create a new species
species_id = next(self.species_idxer)
new_representatives[species_id] = i
new_members[species_id] = [i]
assert np.all(~unspeciated)
# Update species collection based on new speciation.
for sid, rid in new_representatives.items():
s = self.species.get(sid)
if s is None:
s = Species(sid, generation)
self.species[sid] = s
members = np.array(new_members[sid])
s.update((pop_nodes[rid], pop_connections[rid]), members)
def update_species_fitnesses(self, fitnesses):
def __update_species_fitnesses(self, fitnesses):
"""
update the fitness of each species
:param fitnesses:
@@ -163,7 +75,7 @@ class SpeciesController:
s.fitness_history.append(s.fitness)
s.adjusted_fitness = None
def stagnation(self, generation):
def __stagnation(self, generation):
"""
code modified from neat-python!
:param generation:
@@ -196,7 +108,7 @@ class SpeciesController:
result.append((sid, s, is_stagnant))
return result
def reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
"""
code modified from neat-python!
:param fitnesses:
@@ -215,7 +127,7 @@ class SpeciesController:
max_fitness = -np.inf
remaining_species = []
for stag_sid, stag_s, stagnant in self.stagnation(generation):
for stag_sid, stag_s, stagnant in self.__stagnation(generation):
if not stagnant:
min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses))
max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses))
@@ -285,6 +197,33 @@ class SpeciesController:
return winner_part, loser_part, np.array(elite_mask)
def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation):
for idx, key in enumerate(species_keys):
if key == I_INT:
continue
members = np.where(idx2specie == key)[0]
assert len(members) > 0
if key not in self.species:
s = Species(key, generation)
self.species[key] = s
self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members)
def ask(self, fitnesses, generation, S, N, C):
self.__update_species_fitnesses(fitnesses)
winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation)
pre_spe_center_nodes = np.full((S, N, 5), np.nan)
pre_spe_center_cons = np.full((S, C, 4), np.nan)
species_keys = np.full((S,), I_INT)
for idx, (key, specie) in enumerate(self.species.items()):
pre_spe_center_nodes[idx] = specie.representative[0]
pre_spe_center_cons[idx] = specie.representative[1]
species_keys[idx] = key
next_new_specie_key = max(self.species.keys()) + 1
return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \
pre_spe_center_cons, species_keys, next_new_specie_key
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
"""
@@ -326,13 +265,7 @@ def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
return spawn_amounts
def find_min_with_mask(arr: NDArray, mask: NDArray) -> int:
masked_arr = np.where(mask, arr, np.inf)
min_idx = np.argmin(masked_arr)
return min_idx
def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \
-> Tuple[NDArray, NDArray]:
sorted_idx = np.argsort(fitnesses)[::-1]
return members[sorted_idx], fitnesses[sorted_idx]
return members[sorted_idx], fitnesses[sorted_idx]