Perfect!
Next is to connect with Evox!
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
[basic]
|
||||
num_inputs = 2
|
||||
num_outputs = 1
|
||||
init_maximum_nodes = 20
|
||||
init_maximum_connections = 20
|
||||
init_maximum_nodes = 50
|
||||
init_maximum_connections = 50
|
||||
init_maximum_species = 10
|
||||
expands_coe = 2.0
|
||||
expand_coe = 2.0
|
||||
forward_way = "pop"
|
||||
batch_size = 4
|
||||
|
||||
@@ -12,7 +12,7 @@ batch_size = 4
|
||||
fitness_threshold = 100000
|
||||
generation_limit = 100
|
||||
fitness_criterion = "max"
|
||||
pop_size = 150
|
||||
pop_size = 15000
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
@@ -26,7 +26,7 @@ node_delete_prob = 0
|
||||
[species]
|
||||
compatibility_threshold = 3.0
|
||||
species_elitism = 2
|
||||
species_max_stagnation = 15
|
||||
max_stagnation = 15
|
||||
genome_elitism = 2
|
||||
survival_threshold = 0.2
|
||||
min_species_size = 1
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import jit
|
||||
|
||||
from configs import Configer
|
||||
from neat.pipeline import Pipeline
|
||||
from neat.function_factory import FunctionFactory
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
function_factory = FunctionFactory(config)
|
||||
pipeline = Pipeline(config, function_factory)
|
||||
print(config)
|
||||
pipeline = Pipeline(config)
|
||||
forward_func = pipeline.ask()
|
||||
# inputs = np.tile(xor_inputs, (150, 1, 1))
|
||||
outputs = forward_func(xor_inputs)
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
forward_way = "common"
|
||||
|
||||
[population]
|
||||
fitness_threshold = -1e-2
|
||||
fitness_threshold = 3.9999
|
||||
@@ -1,45 +1,26 @@
|
||||
from typing import Callable, List
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
from neat import Pipeline
|
||||
from neat.pipeline import Pipeline
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
xor_outputs = np.array([[0], [1], [1], [0]])
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
|
||||
def evaluate(forward_func: Callable) -> List[float]:
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
# print(fitnesses)
|
||||
return fitnesses.tolist() # returns a list
|
||||
return np.array(fitnesses) # returns a list
|
||||
|
||||
|
||||
# @using_cprofile
|
||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
def main():
|
||||
tic = time.time()
|
||||
config = Configer.load_config("xor.ini")
|
||||
print(config)
|
||||
function_factory = FunctionFactory(config)
|
||||
pipeline = Pipeline(config, function_factory, seed=6)
|
||||
pipeline = Pipeline(config, seed=6)
|
||||
nodes, cons = pipeline.auto_run(evaluate)
|
||||
print(nodes, cons)
|
||||
total_time = time.time() - tic
|
||||
compile_time = pipeline.function_factory.compile_time
|
||||
total_it = pipeline.generation
|
||||
mean_time_per_it = (total_time - compile_time) / total_it
|
||||
evaluate_time = pipeline.evaluate_time
|
||||
print(
|
||||
f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
|
||||
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||
"""
|
||||
@@ -1,10 +1,8 @@
|
||||
import numpy as np
|
||||
from jax import jit, vmap
|
||||
|
||||
from .genome.forward import create_forward
|
||||
from .genome.utils import unflatten_connections
|
||||
from .genome.graph import topological_sort
|
||||
|
||||
from .genome import create_forward, topological_sort, unflatten_connections
|
||||
from .operations import create_next_generation_then_speciate
|
||||
|
||||
def hash_symbols(symbols):
|
||||
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
|
||||
@@ -15,8 +13,10 @@ class FunctionFactory:
|
||||
Creates and compiles functions used in the NEAT pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, jit_config):
|
||||
self.config = config
|
||||
self.jit_config = jit_config
|
||||
|
||||
self.func_dict = {}
|
||||
self.function_info = {}
|
||||
|
||||
@@ -78,6 +78,24 @@ class FunctionFactory:
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
'create_next_generation_then_speciate': {
|
||||
'func': create_next_generation_then_speciate,
|
||||
'lowers': [
|
||||
{'shape': (2, ), 'type': np.uint32}, # rand_key
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes
|
||||
{'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons
|
||||
{'shape': ('P', ), 'type': np.int32}, # winner
|
||||
{'shape': ('P', ), 'type': np.int32}, # loser
|
||||
{'shape': ('P', ), 'type': bool}, # elite_mask
|
||||
{'shape': ('P',), 'type': np.int32}, # new_node_keys
|
||||
{'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes
|
||||
{'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons
|
||||
{'shape': ('S', ), 'type': np.int32}, # species_keys
|
||||
{'shape': (), 'type': np.int32}, # new_species_key_start
|
||||
"jit_config"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,6 +112,7 @@ class FunctionFactory:
|
||||
# prepare lower operands
|
||||
lowers_operands = []
|
||||
for lower in self.function_info[name]['lowers']:
|
||||
if isinstance(lower, dict):
|
||||
shape = list(lower['shape'])
|
||||
for i, s in enumerate(shape):
|
||||
if s in symbols:
|
||||
@@ -101,6 +120,12 @@ class FunctionFactory:
|
||||
assert isinstance(shape[i], int)
|
||||
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
|
||||
|
||||
elif lower == "jit_config":
|
||||
lowers_operands.append(self.jit_config)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid lower operand")
|
||||
|
||||
# compile
|
||||
compiled_func = jit(func).lower(*lowers_operands).compile()
|
||||
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
from .mutate import mutate
|
||||
from .distance import distance
|
||||
from .crossover import crossover
|
||||
from .forward import create_forward
|
||||
from .graph import topological_sort, check_cycles
|
||||
from .utils import unflatten_connections
|
||||
from .genome import initialize_genomes, expand, expand_single
|
||||
@@ -1,34 +1,27 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax import jit
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
def sum_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
return jnp.sum(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def product_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 1, z)
|
||||
return jnp.prod(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def max_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
||||
return jnp.max(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def min_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||
return jnp.min(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def maxabs_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
abs_z = jnp.abs(z)
|
||||
@@ -36,7 +29,6 @@ def maxabs_agg(z):
|
||||
return z[max_abs_index]
|
||||
|
||||
|
||||
@jit
|
||||
def median_agg(z):
|
||||
non_nan_mask = ~jnp.isnan(z)
|
||||
n = jnp.sum(non_nan_mask, axis=0)
|
||||
@@ -49,7 +41,6 @@ def median_agg(z):
|
||||
return median
|
||||
|
||||
|
||||
@jit
|
||||
def mean_agg(z):
|
||||
non_zero_mask = ~jnp.isnan(z)
|
||||
valid_values_sum = sum_agg(z)
|
||||
|
||||
@@ -10,7 +10,7 @@ import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import jit, Array
|
||||
|
||||
from .utils import fetch_random, fetch_first, I_INT
|
||||
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
|
||||
from .graph import check_cycles
|
||||
|
||||
@@ -273,7 +273,8 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config
|
||||
|
||||
is_already_exist = con_idx != I_INT
|
||||
|
||||
is_cycle = check_cycles(nodes, cons, from_idx, to_idx)
|
||||
u_cons = unflatten_connections(nodes, cons)
|
||||
is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx)
|
||||
|
||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array
|
||||
from jax import jit, vmap
|
||||
|
||||
I_INT = jnp.iinfo(jnp.int32).max # infinite int
|
||||
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
|
||||
EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
||||
I_INT = np.iinfo(jnp.int32).max # infinite int
|
||||
EMPTY_NODE = np.full((1, 5), jnp.nan)
|
||||
EMPTY_CON = np.full((1, 4), jnp.nan)
|
||||
|
||||
|
||||
@jit
|
||||
@@ -58,8 +57,3 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
171
neat/operations.py
Normal file
171
neat/operations.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
contains operations on the population: creating the next generation and population speciation.
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit, vmap
|
||||
|
||||
from jax import Array
|
||||
|
||||
from .genome import distance, mutate, crossover
|
||||
from .genome.utils import I_INT, fetch_first
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys,
|
||||
center_nodes, center_cons, species_keys, new_species_key_start,
|
||||
jit_config):
|
||||
# create next generation
|
||||
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask,
|
||||
new_node_keys, jit_config)
|
||||
|
||||
# speciate
|
||||
idx2specie, spe_center_nodes, spe_center_cons, species_keys = \
|
||||
speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config)
|
||||
|
||||
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config):
|
||||
# prepare random keys
|
||||
pop_size = pop_nodes.shape[0]
|
||||
k1, k2 = jax.random.split(rand_key, 2)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||
|
||||
# batch crossover
|
||||
wpn, wpc = pop_nodes[winner], pop_cons[winner] # winner pop nodes, winner pop connections
|
||||
lpn, lpc = pop_nodes[loser], pop_cons[loser] # loser pop nodes, loser pop connections
|
||||
npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
|
||||
# batch mutation
|
||||
mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None))
|
||||
m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config) # mutate_new_pop_nodes
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
||||
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
||||
|
||||
return pop_nodes, pop_cons
|
||||
|
||||
|
||||
@jit
|
||||
def speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config):
|
||||
"""
|
||||
args:
|
||||
pop_nodes: (pop_size, N, 5)
|
||||
pop_cons: (pop_size, C, 4)
|
||||
spe_center_nodes: (species_size, N, 5)
|
||||
spe_center_cons: (species_size, C, 4)
|
||||
"""
|
||||
pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0]
|
||||
|
||||
# prepare distance functions
|
||||
o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population
|
||||
s2p_distance_func = vmap(
|
||||
o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population
|
||||
)
|
||||
|
||||
# idx to specie key
|
||||
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
|
||||
|
||||
# part 1: find new centers
|
||||
# the distance between each species' center and each genome in population
|
||||
s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)
|
||||
|
||||
def find_new_centers(i, carry):
|
||||
i2s, cn, cc = carry
|
||||
# find new center
|
||||
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
|
||||
|
||||
# check species[i] exist or not
|
||||
# if not exist, set idx and i to I_INT, jax will not do array value assignment
|
||||
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
|
||||
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
|
||||
|
||||
i2s = i2s.at[idx].set(species_keys[i])
|
||||
cn = cn.at[i].set(pop_nodes[idx])
|
||||
cc = cc.at[i].set(pop_cons[idx])
|
||||
return i2s, cn, cc
|
||||
|
||||
idx2specie, center_nodes, center_cons = \
|
||||
jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons))
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cn, cc, sk, ck = carry # sk is short for species_keys, ck is short for current key
|
||||
not_all_assigned = ~jnp.all(i2s != I_INT)
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_all_assigned & not_reach_species_upper_bounds
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cn, cc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
|
||||
i2s, scn, scc, sk, ck = jax.lax.cond(
|
||||
sk[i] == I_INT, # whether the current species is existing or not
|
||||
create_new_specie, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cn, cc, sk, ck)
|
||||
)
|
||||
|
||||
return i + 1, i2s, scn, scc, sk, ck
|
||||
|
||||
def create_new_specie(carry):
|
||||
i, i2s, cn, cc, sk, ck = carry
|
||||
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(i2s == I_INT)
|
||||
|
||||
# assign it to the new species
|
||||
sk = sk.at[i].set(ck)
|
||||
i2s = i2s.at[idx].set(ck)
|
||||
|
||||
# update center genomes
|
||||
cn = cn.at[i].set(pop_nodes[idx])
|
||||
cc = cc.at[i].set(pop_cons[idx])
|
||||
|
||||
i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
|
||||
return i2s, cn, cc, sk, ck + 1 # change to next new speciate key
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cn, cc, sk, ck = carry
|
||||
|
||||
i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
|
||||
|
||||
return i2s, cn, cc, sk, ck
|
||||
|
||||
def speciate_by_threshold(carry):
|
||||
i, i2s, cn, cc, sk = carry
|
||||
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
|
||||
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
|
||||
|
||||
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
||||
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
|
||||
return i2s
|
||||
|
||||
current_new_key = new_species_key_start
|
||||
|
||||
# update idx2specie
|
||||
_, idx2specie, center_nodes, center_cons, species_keys, _ = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, idx2specie, center_nodes, center_cons, species_keys, current_new_key)
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
|
||||
|
||||
return idx2specie, center_nodes, center_cons, species_keys
|
||||
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
118
neat/pipeline.py
118
neat/pipeline.py
@@ -1,11 +1,13 @@
|
||||
from functools import partial
|
||||
import time
|
||||
from typing import Union, Callable
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
|
||||
from configs.configer import Configer
|
||||
from .genome.genome import initialize_genomes
|
||||
from configs import Configer
|
||||
from .genome import initialize_genomes, expand, expand_single
|
||||
from .function_factory import FunctionFactory
|
||||
from .species import SpeciesController
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -19,7 +21,7 @@ class Pipeline:
|
||||
|
||||
self.config = config # global config
|
||||
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||
self.function_factory = function_factory or FunctionFactory(self.config)
|
||||
self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config)
|
||||
|
||||
self.symbols = {
|
||||
'P': self.config['pop_size'],
|
||||
@@ -31,8 +33,16 @@ class Pipeline:
|
||||
self.generation = 0
|
||||
self.best_genome = None
|
||||
|
||||
self.species_controller = SpeciesController(self.config)
|
||||
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
|
||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
self.best_genome = None
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
self.evaluate_time = 0
|
||||
print(self.config)
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
@@ -74,5 +84,105 @@ class Pipeline:
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def tell(self, fitnesses):
|
||||
self.generation += 1
|
||||
|
||||
winner, loser, elite_mask, center_nodes, center_cons, species_keys, species_key_start = \
|
||||
self.species_controller.ask(fitnesses, self.generation, self.symbols)
|
||||
|
||||
# node keys to be used in the mutation process
|
||||
new_node_keys = np.arange(self.generation * self.config['pop_size'],
|
||||
self.generation * self.config['pop_size'] + self.config['pop_size'])
|
||||
|
||||
# create the next generation and then speciate the population
|
||||
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||
self.get_func('create_next_generation_then_speciate') \
|
||||
(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes,
|
||||
center_cons, species_keys, species_key_start, self.jit_config)
|
||||
|
||||
# carry data to cpu
|
||||
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys])
|
||||
|
||||
self.species_controller.tell(idx2specie, center_nodes, center_cons, species_keys, self.generation)
|
||||
|
||||
# expand the population if needed
|
||||
self.expand()
|
||||
|
||||
# update randkey
|
||||
self.randkey = jax.random.split(self.randkey)[0]
|
||||
|
||||
def expand(self):
|
||||
"""
|
||||
Expand the population if needed.
|
||||
when the maximum node number >= N or the maximum connection number of >= C
|
||||
the population will expand
|
||||
"""
|
||||
changed = False
|
||||
|
||||
pop_node_keys = self.pop_nodes[:, :, 0]
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
|
||||
max_node_size = np.max(pop_node_sizes)
|
||||
if max_node_size >= self.symbols['N']:
|
||||
self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
|
||||
print(f"node expand to {self.symbols['N']}!")
|
||||
changed = True
|
||||
|
||||
pop_con_keys = self.pop_cons[:, :, 0]
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||
max_con_size = np.max(pop_node_sizes)
|
||||
if max_con_size >= self.symbols['C']:
|
||||
self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
|
||||
print(f"connection expand to {self.symbols['C']}!")
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C'])
|
||||
# don't forget to expand representation genome in species
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.symbols['N'], self.symbols['C'])
|
||||
|
||||
def get_func(self, name):
|
||||
return self.function_factory.get(name, self.symbols)
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config['generation_limit']):
|
||||
forward_func = self.ask()
|
||||
|
||||
tic = time.time()
|
||||
fitnesses = fitness_func(forward_func)
|
||||
self.evaluate_time += time.time() - tic
|
||||
|
||||
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||
|
||||
if analysis is not None:
|
||||
if analysis == "default":
|
||||
self.default_analysis(fitnesses)
|
||||
else:
|
||||
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
||||
analysis(fitnesses)
|
||||
|
||||
if max(fitnesses) >= self.config['fitness_threshold']:
|
||||
print("Fitness limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
self.tell(fitnesses)
|
||||
print("Generation limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
|
||||
|
||||
new_timestamp = time.time()
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
self.generation_timestamp = new_timestamp
|
||||
|
||||
max_idx = np.argmax(fitnesses)
|
||||
if fitnesses[max_idx] > self.best_fitness:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||
|
||||
print(f"Generation: {self.generation}",
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit, vmap
|
||||
|
||||
from jax import Array
|
||||
|
||||
from .genome import distance, mutate, crossover
|
||||
from .genome.utils import I_INT, fetch_first, argmin_with_mask
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
|
||||
new_node_keys,
|
||||
pre_spe_center_nodes, pre_spe_center_cons, species_keys, new_species_key_start,
|
||||
species_kwargs, mutate_kwargs):
|
||||
# create next generation
|
||||
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
|
||||
new_node_keys, **mutate_kwargs)
|
||||
|
||||
# speciate
|
||||
idx2specie, spe_center_nodes, spe_center_cons, species_keys = speciate(pop_nodes, pop_cons, pre_spe_center_nodes,
|
||||
pre_spe_center_cons, species_keys,
|
||||
new_species_key_start, **species_kwargs)
|
||||
|
||||
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
|
||||
|
||||
|
||||
@jit
|
||||
def speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
|
||||
species_keys, new_species_key_start,
|
||||
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
|
||||
):
|
||||
"""
|
||||
args:
|
||||
pop_nodes: (pop_size, N, 5)
|
||||
pop_cons: (pop_size, C, 4)
|
||||
spe_center_nodes: (species_size, N, 5)
|
||||
spe_center_cons: (species_size, C, 4)
|
||||
"""
|
||||
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
|
||||
|
||||
# prepare distance functions
|
||||
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
|
||||
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||
s2p_distance_func = vmap(
|
||||
o2p_distance_func, in_axes=(0, 0, None, None)
|
||||
)
|
||||
|
||||
# idx to specie key
|
||||
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
|
||||
|
||||
# part 1: find new centers
|
||||
# the distance between each species' center and each genome in population
|
||||
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
|
||||
|
||||
def find_new_centers(i, carry):
|
||||
i2s, scn, scc = carry
|
||||
# find new center
|
||||
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
|
||||
|
||||
# check species[i] exist or not
|
||||
# if not exist, set idx and i to I_INT, jax will not do array value assignment
|
||||
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
|
||||
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
|
||||
|
||||
i2s = i2s.at[idx].set(species_keys[i])
|
||||
scn = scn.at[i].set(pop_nodes[idx])
|
||||
scc = scc.at[i].set(pop_cons[idx])
|
||||
return i2s, scn, scc
|
||||
|
||||
idx2specie, spe_center_nodes, spe_center_cons = jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, spe_center_nodes, spe_center_cons))
|
||||
|
||||
def continue_execute_while(carry):
|
||||
i, i2s, scn, scc, sk, ck = carry # sk is short for species_keys, ck is short for current key
|
||||
not_all_assigned = ~jnp.all(i2s != I_INT)
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return not_all_assigned & not_reach_species_upper_bounds
|
||||
|
||||
def deal_with_each_center_genome(carry):
|
||||
i, i2s, scn, scc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
|
||||
|
||||
i2s, scn, scc, sk, ck = jax.lax.cond(
|
||||
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
|
||||
create_new_specie, # if not valid, create a new specie
|
||||
update_exist_specie, # if valid, update the specie
|
||||
(i, i2s, scn, scc, sk, ck)
|
||||
)
|
||||
|
||||
return i + 1, i2s, scn, scc, sk, ck
|
||||
|
||||
def create_new_specie(carry):
|
||||
i, i2s, scn, scc, sk, ck = carry
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(i2s == I_INT)
|
||||
|
||||
# assign it to new specie
|
||||
sk = sk.at[i].set(ck)
|
||||
i2s = i2s.at[idx].set(ck)
|
||||
|
||||
# update center genomes
|
||||
scn = scn.at[i].set(pop_nodes[idx])
|
||||
scc = scc.at[i].set(pop_cons[idx])
|
||||
|
||||
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
|
||||
return i2s, scn, scc, sk, ck + 1 # change to next new speciate key
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, scn, scc, sk, ck = carry
|
||||
|
||||
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
|
||||
return i2s, scn, scc, sk, ck
|
||||
|
||||
def speciate_by_threshold(carry):
|
||||
i, i2s, scn, scc, sk = carry
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
|
||||
close_enough_mask = o2p_distance < compatibility_threshold
|
||||
|
||||
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
||||
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
|
||||
return i2s, scn, scc
|
||||
|
||||
current_new_key = new_species_key_start
|
||||
|
||||
# update idx2specie
|
||||
_, idx2specie, spe_center_nodes, spe_center_cons, species_keys, new_species_key_start = jax.lax.while_loop(
|
||||
continue_execute_while,
|
||||
deal_with_each_center_genome,
|
||||
(0, idx2specie, spe_center_nodes, spe_center_cons, species_keys, current_new_key)
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
|
||||
|
||||
return idx2specie, spe_center_nodes, spe_center_cons, species_keys
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, new_node_keys,
|
||||
**mutate_kwargs):
|
||||
# prepare functions
|
||||
batch_crossover = vmap(crossover)
|
||||
mutate_with_args = vmap(partial(mutate, **mutate_kwargs))
|
||||
|
||||
pop_size = pop_nodes.shape[0]
|
||||
k1, k2 = jax.random.split(rand_key, 2)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||
|
||||
# batch crossover
|
||||
wpn = pop_nodes[winner_part] # winner pop nodes
|
||||
wpc = pop_cons[winner_part] # winner pop connections
|
||||
lpn = pop_nodes[loser_part] # loser pop nodes
|
||||
lpc = pop_cons[loser_part] # loser pop connections
|
||||
|
||||
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
|
||||
m_npn, m_npc = mutate_with_args(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
||||
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
||||
|
||||
return pop_nodes, pop_cons
|
||||
@@ -1,7 +1,15 @@
|
||||
from typing import List, Tuple, Dict, Union, Callable
|
||||
"""
|
||||
Species Controller in NEAT.
|
||||
The code are modified from neat-python.
|
||||
See
|
||||
https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation
|
||||
https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction
|
||||
https://neat-python.readthedocs.io/en/latest/module_summaries.html#species
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Dict
|
||||
from itertools import count
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
@@ -37,14 +45,13 @@ class SpeciesController:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
self.species_elitism = self.config.neat.species.species_elitism
|
||||
self.pop_size = self.config.neat.population.pop_size
|
||||
self.max_stagnation = self.config.neat.species.max_stagnation
|
||||
self.min_species_size = self.config.neat.species.min_species_size
|
||||
self.genome_elitism = self.config.neat.species.genome_elitism
|
||||
self.survival_threshold = self.config.neat.species.survival_threshold
|
||||
self.species_elitism = self.config['species_elitism']
|
||||
self.pop_size = self.config['pop_size']
|
||||
self.max_stagnation = self.config['max_stagnation']
|
||||
self.min_species_size = self.config['min_species_size']
|
||||
self.genome_elitism = self.config['genome_elitism']
|
||||
self.survival_threshold = self.config['survival_threshold']
|
||||
|
||||
self.species_idxer = count(0)
|
||||
self.species: Dict[int, Species] = {} # species_id -> species
|
||||
|
||||
def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray):
|
||||
@@ -55,9 +62,10 @@ class SpeciesController:
|
||||
:return:
|
||||
"""
|
||||
pop_size = pop_nodes.shape[0]
|
||||
species_id = next(self.species_idxer)
|
||||
species_id = 0 # the first species
|
||||
s = Species(species_id, 0)
|
||||
members = np.array(list(range(pop_size)))
|
||||
|
||||
s.update((pop_nodes[0], pop_connections[0]), members)
|
||||
self.species[species_id] = s
|
||||
|
||||
@@ -68,16 +76,14 @@ class SpeciesController:
|
||||
:return:
|
||||
"""
|
||||
for sid, s in self.species.items():
|
||||
# TODO: here use mean to measure the fitness of a species, but it may be other functions
|
||||
s.member_fitnesses = s.get_fitnesses(fitnesses)
|
||||
# s.fitness = np.mean(s.member_fitnesses)
|
||||
# use the max score to represent the fitness of the species
|
||||
s.fitness = np.max(s.member_fitnesses)
|
||||
s.fitness_history.append(s.fitness)
|
||||
s.adjusted_fitness = None
|
||||
|
||||
def __stagnation(self, generation):
|
||||
"""
|
||||
code modified from neat-python!
|
||||
:param generation:
|
||||
:return: whether the species is stagnated
|
||||
"""
|
||||
@@ -88,7 +94,7 @@ class SpeciesController:
|
||||
else:
|
||||
prev_fitness = float('-inf')
|
||||
|
||||
if prev_fitness is None or s.fitness > prev_fitness:
|
||||
if s.fitness > prev_fitness:
|
||||
s.last_improved = generation
|
||||
|
||||
species_data.append((sid, s))
|
||||
@@ -110,7 +116,6 @@ class SpeciesController:
|
||||
|
||||
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
|
||||
"""
|
||||
code modified from neat-python!
|
||||
:param fitnesses:
|
||||
:param generation:
|
||||
:return: crossover_pair for next generation.
|
||||
@@ -136,6 +141,8 @@ class SpeciesController:
|
||||
# No species left.
|
||||
assert remaining_species
|
||||
|
||||
|
||||
# TODO: Too complex!
|
||||
# Compute each species' member size in the next generation.
|
||||
|
||||
# Do not allow the fitness range to be zero, as we divide by it below.
|
||||
@@ -185,6 +192,7 @@ class SpeciesController:
|
||||
# only use good genomes to crossover
|
||||
sorted_members = sorted_members[:repro_cutoff]
|
||||
|
||||
# TODO: Genome with higher fitness should be more likely to be selected?
|
||||
list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
|
||||
part1.extend(sorted_members[list_idx1])
|
||||
part2.extend(sorted_members[list_idx2])
|
||||
@@ -197,32 +205,37 @@ class SpeciesController:
|
||||
|
||||
return winner_part, loser_part, np.array(elite_mask)
|
||||
|
||||
def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation):
|
||||
def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation):
|
||||
for idx, key in enumerate(species_keys):
|
||||
if key == I_INT:
|
||||
continue
|
||||
|
||||
members = np.where(idx2specie == key)[0]
|
||||
assert len(members) > 0
|
||||
|
||||
if key not in self.species:
|
||||
# the new specie created in this generation
|
||||
s = Species(key, generation)
|
||||
self.species[key] = s
|
||||
|
||||
self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members)
|
||||
self.species[key].update((center_nodes[idx], center_cons[idx]), members)
|
||||
|
||||
def ask(self, fitnesses, generation, S, N, C):
|
||||
def ask(self, fitnesses, generation, symbols):
|
||||
self.__update_species_fitnesses(fitnesses)
|
||||
winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation)
|
||||
pre_spe_center_nodes = np.full((S, N, 5), np.nan)
|
||||
pre_spe_center_cons = np.full((S, C, 4), np.nan)
|
||||
species_keys = np.full((S,), I_INT)
|
||||
|
||||
winner, loser, elite_mask = self.__reproduce(fitnesses, generation)
|
||||
|
||||
center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan)
|
||||
center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan)
|
||||
species_keys = np.full((symbols['S'], ), I_INT)
|
||||
|
||||
for idx, (key, specie) in enumerate(self.species.items()):
|
||||
pre_spe_center_nodes[idx] = specie.representative[0]
|
||||
pre_spe_center_cons[idx] = specie.representative[1]
|
||||
center_nodes[idx], center_cons[idx] = specie.representative
|
||||
species_keys[idx] = key
|
||||
|
||||
next_new_specie_key = max(self.species.keys()) + 1
|
||||
return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \
|
||||
pre_spe_center_cons, species_keys, next_new_specie_key
|
||||
|
||||
return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key
|
||||
|
||||
|
||||
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
||||
|
||||
Reference in New Issue
Block a user