Next is to connect with Evox!
This commit is contained in:
wls2002
2023-06-25 02:57:45 +08:00
parent 0cb2f9473d
commit ba369db0b2
14 changed files with 392 additions and 268 deletions

View File

@@ -1,10 +1,10 @@
[basic]
num_inputs = 2
num_outputs = 1
init_maximum_nodes = 20
init_maximum_connections = 20
init_maximum_nodes = 50
init_maximum_connections = 50
init_maximum_species = 10
expands_coe = 2.0
expand_coe = 2.0
forward_way = "pop"
batch_size = 4
@@ -12,7 +12,7 @@ batch_size = 4
fitness_threshold = 100000
generation_limit = 100
fitness_criterion = "max"
pop_size = 150
pop_size = 15000
[genome]
compatibility_disjoint = 1.0
@@ -26,7 +26,7 @@ node_delete_prob = 0
[species]
compatibility_threshold = 3.0
species_elitism = 2
species_max_stagnation = 15
max_stagnation = 15
genome_elitism = 2
survival_threshold = 0.2
min_species_size = 1

View File

@@ -1,20 +1,16 @@
from functools import partial
import numpy as np
import jax
from jax import jit
from configs import Configer
from neat.pipeline import Pipeline
from neat.function_factory import FunctionFactory
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def main():
config = Configer.load_config("xor.ini")
function_factory = FunctionFactory(config)
pipeline = Pipeline(config, function_factory)
print(config)
pipeline = Pipeline(config)
forward_func = pipeline.ask()
# inputs = np.tile(xor_inputs, (150, 1, 1))
outputs = forward_func(xor_inputs)

View File

@@ -2,4 +2,4 @@
forward_way = "common"
[population]
fitness_threshold = -1e-2
fitness_threshold = 3.9999

View File

@@ -1,45 +1,26 @@
from typing import Callable, List
import time
import numpy as np
from configs import Configer
from neat import Pipeline
from neat.pipeline import Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
xor_outputs = np.array([[0], [1], [1], [0]])
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func: Callable) -> List[float]:
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
# print(fitnesses)
return fitnesses.tolist() # returns a list
return np.array(fitnesses) # returns a list
# @using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
def main():
tic = time.time()
config = Configer.load_config("xor.ini")
print(config)
function_factory = FunctionFactory(config)
pipeline = Pipeline(config, function_factory, seed=6)
pipeline = Pipeline(config, seed=6)
nodes, cons = pipeline.auto_run(evaluate)
print(nodes, cons)
total_time = time.time() - tic
compile_time = pipeline.function_factory.compile_time
total_it = pipeline.generation
mean_time_per_it = (total_time - compile_time) / total_it
evaluate_time = pipeline.evaluate_time
print(
f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,3 @@
"""
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""

View File

@@ -1,10 +1,8 @@
import numpy as np
from jax import jit, vmap
from .genome.forward import create_forward
from .genome.utils import unflatten_connections
from .genome.graph import topological_sort
from .genome import create_forward, topological_sort, unflatten_connections
from .operations import create_next_generation_then_speciate
def hash_symbols(symbols):
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
@@ -15,8 +13,10 @@ class FunctionFactory:
Creates and compiles functions used in the NEAT pipeline.
"""
def __init__(self, config):
def __init__(self, config, jit_config):
self.config = config
self.jit_config = jit_config
self.func_dict = {}
self.function_info = {}
@@ -78,6 +78,24 @@ class FunctionFactory:
{'shape': ('P', 'N', 5), 'type': np.float32},
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
]
},
'create_next_generation_then_speciate': {
'func': create_next_generation_then_speciate,
'lowers': [
{'shape': (2, ), 'type': np.uint32}, # rand_key
{'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes
{'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons
{'shape': ('P', ), 'type': np.int32}, # winner
{'shape': ('P', ), 'type': np.int32}, # loser
{'shape': ('P', ), 'type': bool}, # elite_mask
{'shape': ('P',), 'type': np.int32}, # new_node_keys
{'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes
{'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons
{'shape': ('S', ), 'type': np.int32}, # species_keys
{'shape': (), 'type': np.int32}, # new_species_key_start
"jit_config"
]
}
}
@@ -94,12 +112,19 @@ class FunctionFactory:
# prepare lower operands
lowers_operands = []
for lower in self.function_info[name]['lowers']:
shape = list(lower['shape'])
for i, s in enumerate(shape):
if s in symbols:
shape[i] = symbols[s]
assert isinstance(shape[i], int)
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
if isinstance(lower, dict):
shape = list(lower['shape'])
for i, s in enumerate(shape):
if s in symbols:
shape[i] = symbols[s]
assert isinstance(shape[i], int)
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
elif lower == "jit_config":
lowers_operands.append(self.jit_config)
else:
raise ValueError("Invalid lower operand")
# compile
compiled_func = jit(func).lower(*lowers_operands).compile()

View File

@@ -0,0 +1,7 @@
from .mutate import mutate
from .distance import distance
from .crossover import crossover
from .forward import create_forward
from .graph import topological_sort, check_cycles
from .utils import unflatten_connections
from .genome import initialize_genomes, expand, expand_single

View File

@@ -1,34 +1,27 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
@jit
def sum_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)
@jit
def product_agg(z):
z = jnp.where(jnp.isnan(z), 1, z)
return jnp.prod(z, axis=0)
@jit
def max_agg(z):
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
return jnp.max(z, axis=0)
@jit
def min_agg(z):
z = jnp.where(jnp.isnan(z), jnp.inf, z)
return jnp.min(z, axis=0)
@jit
def maxabs_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
abs_z = jnp.abs(z)
@@ -36,7 +29,6 @@ def maxabs_agg(z):
return z[max_abs_index]
@jit
def median_agg(z):
non_nan_mask = ~jnp.isnan(z)
n = jnp.sum(non_nan_mask, axis=0)
@@ -49,7 +41,6 @@ def median_agg(z):
return median
@jit
def mean_agg(z):
non_zero_mask = ~jnp.isnan(z)
valid_values_sum = sum_agg(z)

View File

@@ -10,7 +10,7 @@ import jax
from jax import numpy as jnp
from jax import jit, Array
from .utils import fetch_random, fetch_first, I_INT
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
from .graph import check_cycles
@@ -273,7 +273,8 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config
is_already_exist = con_idx != I_INT
is_cycle = check_cycles(nodes, cons, from_idx, to_idx)
u_cons = unflatten_connections(nodes, cons)
is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])

View File

@@ -1,12 +1,11 @@
from functools import partial
import numpy as np
import jax
from jax import numpy as jnp, Array
from jax import jit, vmap
I_INT = jnp.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
EMPTY_CON = jnp.full((1, 4), jnp.nan)
I_INT = np.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = np.full((1, 5), jnp.nan)
EMPTY_CON = np.full((1, 4), jnp.nan)
@jit
@@ -58,8 +57,3 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx

171
neat/operations.py Normal file
View File

@@ -0,0 +1,171 @@
"""
contains operations on the population: creating the next generation and population speciation.
"""
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax import Array
from .genome import distance, mutate, crossover
from .genome.utils import I_INT, fetch_first
@jit
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys,
center_nodes, center_cons, species_keys, new_species_key_start,
jit_config):
# create next generation
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask,
new_node_keys, jit_config)
# speciate
idx2specie, spe_center_nodes, spe_center_cons, species_keys = \
speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config)
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config):
# prepare random keys
pop_size = pop_nodes.shape[0]
k1, k2 = jax.random.split(rand_key, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn, wpc = pop_nodes[winner], pop_cons[winner] # winner pop nodes, winner pop connections
lpn, lpc = pop_nodes[loser], pop_cons[loser] # loser pop nodes, loser pop connections
npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
# batch mutation
mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None))
m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
return pop_nodes, pop_cons
@jit
def speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config):
"""
args:
pop_nodes: (pop_size, N, 5)
pop_cons: (pop_size, C, 4)
spe_center_nodes: (species_size, N, 5)
spe_center_cons: (species_size, C, 4)
"""
pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0]
# prepare distance functions
o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population
s2p_distance_func = vmap(
o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population
)
# idx to specie key
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
# part 1: find new centers
# the distance between each species' center and each genome in population
s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)
def find_new_centers(i, carry):
i2s, cn, cc = carry
# find new center
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
# check species[i] exist or not
# if not exist, set idx and i to I_INT, jax will not do array value assignment
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
i2s = i2s.at[idx].set(species_keys[i])
cn = cn.at[i].set(pop_nodes[idx])
cc = cc.at[i].set(pop_cons[idx])
return i2s, cn, cc
idx2specie, center_nodes, center_cons = \
jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons))
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cn, cc, sk, ck = carry # sk is short for species_keys, ck is short for current key
not_all_assigned = ~jnp.all(i2s != I_INT)
not_reach_species_upper_bounds = i < species_size
return not_all_assigned & not_reach_species_upper_bounds
def body_func(carry):
i, i2s, cn, cc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
i2s, scn, scc, sk, ck = jax.lax.cond(
sk[i] == I_INT, # whether the current species is existing or not
create_new_specie, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cn, cc, sk, ck)
)
return i + 1, i2s, scn, scc, sk, ck
def create_new_specie(carry):
i, i2s, cn, cc, sk, ck = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(i2s == I_INT)
# assign it to the new species
sk = sk.at[i].set(ck)
i2s = i2s.at[idx].set(ck)
# update center genomes
cn = cn.at[i].set(pop_nodes[idx])
cc = cc.at[i].set(pop_cons[idx])
i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
return i2s, cn, cc, sk, ck + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cn, cc, sk, ck = carry
i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
return i2s, cn, cc, sk, ck
def speciate_by_threshold(carry):
i, i2s, cn, cc, sk = carry
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
return i2s
current_new_key = new_species_key_start
# update idx2specie
_, idx2specie, center_nodes, center_cons, species_keys, _ = jax.lax.while_loop(
cond_func,
body_func,
(0, idx2specie, center_nodes, center_cons, species_keys, current_new_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition seems to be only happened when the number of species is reached species upper bounds
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
return idx2specie, center_nodes, center_cons, species_keys
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx

View File

@@ -1,11 +1,13 @@
from functools import partial
import time
from typing import Union, Callable
import numpy as np
import jax
from configs.configer import Configer
from .genome.genome import initialize_genomes
from configs import Configer
from .genome import initialize_genomes, expand, expand_single
from .function_factory import FunctionFactory
from .species import SpeciesController
class Pipeline:
@@ -19,7 +21,7 @@ class Pipeline:
self.config = config # global config
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
self.function_factory = function_factory or FunctionFactory(self.config)
self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config)
self.symbols = {
'P': self.config['pop_size'],
@@ -31,8 +33,16 @@ class Pipeline:
self.generation = 0
self.best_genome = None
self.species_controller = SpeciesController(self.config)
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
self.best_fitness = float('-inf')
self.best_genome = None
self.generation_timestamp = time.time()
self.evaluate_time = 0
print(self.config)
def ask(self):
"""
@@ -74,5 +84,105 @@ class Pipeline:
else:
raise NotImplementedError
def tell(self, fitnesses):
self.generation += 1
winner, loser, elite_mask, center_nodes, center_cons, species_keys, species_key_start = \
self.species_controller.ask(fitnesses, self.generation, self.symbols)
# node keys to be used in the mutation process
new_node_keys = np.arange(self.generation * self.config['pop_size'],
self.generation * self.config['pop_size'] + self.config['pop_size'])
# create the next generation and then speciate the population
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
self.get_func('create_next_generation_then_speciate') \
(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes,
center_cons, species_keys, species_key_start, self.jit_config)
# carry data to cpu
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys])
self.species_controller.tell(idx2specie, center_nodes, center_cons, species_keys, self.generation)
# expand the population if needed
self.expand()
# update randkey
self.randkey = jax.random.split(self.randkey)[0]
def expand(self):
"""
Expand the population if needed.
when the maximum node number >= N or the maximum connection number of >= C
the population will expand
"""
changed = False
pop_node_keys = self.pop_nodes[:, :, 0]
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
max_node_size = np.max(pop_node_sizes)
if max_node_size >= self.symbols['N']:
self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
print(f"node expand to {self.symbols['N']}!")
changed = True
pop_con_keys = self.pop_cons[:, :, 0]
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
max_con_size = np.max(pop_node_sizes)
if max_con_size >= self.symbols['C']:
self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
print(f"connection expand to {self.symbols['C']}!")
changed = True
if changed:
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C'])
# don't forget to expand representation genome in species
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.symbols['N'], self.symbols['C'])
def get_func(self, name):
return self.function_factory.get(name, self.symbols)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']):
forward_func = self.ask()
tic = time.time()
fitnesses = fitness_func(forward_func)
self.evaluate_time += time.time() - tic
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
if analysis is not None:
if analysis == "default":
self.default_analysis(fitnesses)
else:
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
analysis(fitnesses)
if max(fitnesses) >= self.config['fitness_threshold']:
print("Fitness limit reached!")
return self.best_genome
self.tell(fitnesses)
print("Generation limit reached!")
return self.best_genome
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
self.generation_timestamp = new_timestamp
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")

View File

@@ -1,168 +0,0 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax import Array
from .genome import distance, mutate, crossover
from .genome.utils import I_INT, fetch_first, argmin_with_mask
@jit
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
new_node_keys,
pre_spe_center_nodes, pre_spe_center_cons, species_keys, new_species_key_start,
species_kwargs, mutate_kwargs):
# create next generation
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
new_node_keys, **mutate_kwargs)
# speciate
idx2specie, spe_center_nodes, spe_center_cons, species_keys = speciate(pop_nodes, pop_cons, pre_spe_center_nodes,
pre_spe_center_cons, species_keys,
new_species_key_start, **species_kwargs)
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
@jit
def speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
species_keys, new_species_key_start,
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
):
"""
args:
pop_nodes: (pop_size, N, 5)
pop_cons: (pop_size, C, 4)
spe_center_nodes: (species_size, N, 5)
spe_center_cons: (species_size, C, 4)
"""
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
# prepare distance functions
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
s2p_distance_func = vmap(
o2p_distance_func, in_axes=(0, 0, None, None)
)
# idx to specie key
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
# part 1: find new centers
# the distance between each species' center and each genome in population
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
def find_new_centers(i, carry):
i2s, scn, scc = carry
# find new center
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
# check species[i] exist or not
# if not exist, set idx and i to I_INT, jax will not do array value assignment
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
i2s = i2s.at[idx].set(species_keys[i])
scn = scn.at[i].set(pop_nodes[idx])
scc = scc.at[i].set(pop_cons[idx])
return i2s, scn, scc
idx2specie, spe_center_nodes, spe_center_cons = jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, spe_center_nodes, spe_center_cons))
def continue_execute_while(carry):
i, i2s, scn, scc, sk, ck = carry # sk is short for species_keys, ck is short for current key
not_all_assigned = ~jnp.all(i2s != I_INT)
not_reach_species_upper_bounds = i < species_size
return not_all_assigned & not_reach_species_upper_bounds
def deal_with_each_center_genome(carry):
i, i2s, scn, scc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
i2s, scn, scc, sk, ck = jax.lax.cond(
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
create_new_specie, # if not valid, create a new specie
update_exist_specie, # if valid, update the specie
(i, i2s, scn, scc, sk, ck)
)
return i + 1, i2s, scn, scc, sk, ck
def create_new_specie(carry):
i, i2s, scn, scc, sk, ck = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(i2s == I_INT)
# assign it to new specie
sk = sk.at[i].set(ck)
i2s = i2s.at[idx].set(ck)
# update center genomes
scn = scn.at[i].set(pop_nodes[idx])
scc = scc.at[i].set(pop_cons[idx])
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
return i2s, scn, scc, sk, ck + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, scn, scc, sk, ck = carry
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
return i2s, scn, scc, sk, ck
def speciate_by_threshold(carry):
i, i2s, scn, scc, sk = carry
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
close_enough_mask = o2p_distance < compatibility_threshold
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
return i2s, scn, scc
current_new_key = new_species_key_start
# update idx2specie
_, idx2specie, spe_center_nodes, spe_center_cons, species_keys, new_species_key_start = jax.lax.while_loop(
continue_execute_while,
deal_with_each_center_genome,
(0, idx2specie, spe_center_nodes, spe_center_cons, species_keys, current_new_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition seems to be only happened when the number of species is reached species upper bounds
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
return idx2specie, spe_center_nodes, spe_center_cons, species_keys
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, new_node_keys,
**mutate_kwargs):
# prepare functions
batch_crossover = vmap(crossover)
mutate_with_args = vmap(partial(mutate, **mutate_kwargs))
pop_size = pop_nodes.shape[0]
k1, k2 = jax.random.split(rand_key, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn = pop_nodes[winner_part] # winner pop nodes
wpc = pop_cons[winner_part] # winner pop connections
lpn = pop_nodes[loser_part] # loser pop nodes
lpc = pop_cons[loser_part] # loser pop connections
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
m_npn, m_npc = mutate_with_args(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
return pop_nodes, pop_cons

View File

@@ -1,7 +1,15 @@
from typing import List, Tuple, Dict, Union, Callable
"""
Species Controller in NEAT.
The code are modified from neat-python.
See
https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation
https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction
https://neat-python.readthedocs.io/en/latest/module_summaries.html#species
"""
from typing import List, Tuple, Dict
from itertools import count
import jax
import numpy as np
from numpy.typing import NDArray
@@ -37,14 +45,13 @@ class SpeciesController:
def __init__(self, config):
self.config = config
self.species_elitism = self.config.neat.species.species_elitism
self.pop_size = self.config.neat.population.pop_size
self.max_stagnation = self.config.neat.species.max_stagnation
self.min_species_size = self.config.neat.species.min_species_size
self.genome_elitism = self.config.neat.species.genome_elitism
self.survival_threshold = self.config.neat.species.survival_threshold
self.species_elitism = self.config['species_elitism']
self.pop_size = self.config['pop_size']
self.max_stagnation = self.config['max_stagnation']
self.min_species_size = self.config['min_species_size']
self.genome_elitism = self.config['genome_elitism']
self.survival_threshold = self.config['survival_threshold']
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray):
@@ -55,9 +62,10 @@ class SpeciesController:
:return:
"""
pop_size = pop_nodes.shape[0]
species_id = next(self.species_idxer)
species_id = 0 # the first species
s = Species(species_id, 0)
members = np.array(list(range(pop_size)))
s.update((pop_nodes[0], pop_connections[0]), members)
self.species[species_id] = s
@@ -68,16 +76,14 @@ class SpeciesController:
:return:
"""
for sid, s in self.species.items():
# TODO: here use mean to measure the fitness of a species, but it may be other functions
s.member_fitnesses = s.get_fitnesses(fitnesses)
# s.fitness = np.mean(s.member_fitnesses)
# use the max score to represent the fitness of the species
s.fitness = np.max(s.member_fitnesses)
s.fitness_history.append(s.fitness)
s.adjusted_fitness = None
def __stagnation(self, generation):
"""
code modified from neat-python!
:param generation:
:return: whether the species is stagnated
"""
@@ -88,7 +94,7 @@ class SpeciesController:
else:
prev_fitness = float('-inf')
if prev_fitness is None or s.fitness > prev_fitness:
if s.fitness > prev_fitness:
s.last_improved = generation
species_data.append((sid, s))
@@ -110,7 +116,6 @@ class SpeciesController:
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
"""
code modified from neat-python!
:param fitnesses:
:param generation:
:return: crossover_pair for next generation.
@@ -136,6 +141,8 @@ class SpeciesController:
# No species left.
assert remaining_species
# TODO: Too complex!
# Compute each species' member size in the next generation.
# Do not allow the fitness range to be zero, as we divide by it below.
@@ -185,6 +192,7 @@ class SpeciesController:
# only use good genomes to crossover
sorted_members = sorted_members[:repro_cutoff]
# TODO: Genome with higher fitness should be more likely to be selected?
list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
part1.extend(sorted_members[list_idx1])
part2.extend(sorted_members[list_idx2])
@@ -197,32 +205,37 @@ class SpeciesController:
return winner_part, loser_part, np.array(elite_mask)
def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation):
def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation):
for idx, key in enumerate(species_keys):
if key == I_INT:
continue
members = np.where(idx2specie == key)[0]
assert len(members) > 0
if key not in self.species:
# the new specie created in this generation
s = Species(key, generation)
self.species[key] = s
self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members)
self.species[key].update((center_nodes[idx], center_cons[idx]), members)
def ask(self, fitnesses, generation, S, N, C):
def ask(self, fitnesses, generation, symbols):
self.__update_species_fitnesses(fitnesses)
winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation)
pre_spe_center_nodes = np.full((S, N, 5), np.nan)
pre_spe_center_cons = np.full((S, C, 4), np.nan)
species_keys = np.full((S,), I_INT)
winner, loser, elite_mask = self.__reproduce(fitnesses, generation)
center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan)
center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan)
species_keys = np.full((symbols['S'], ), I_INT)
for idx, (key, specie) in enumerate(self.species.items()):
pre_spe_center_nodes[idx] = specie.representative[0]
pre_spe_center_cons[idx] = specie.representative[1]
center_nodes[idx], center_cons[idx] = specie.representative
species_keys[idx] = key
next_new_specie_key = max(self.species.keys()) + 1
return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \
pre_spe_center_cons, species_keys, next_new_specie_key
return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):