Perfect!
Next is to connect with Evox!
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
[basic]
|
[basic]
|
||||||
num_inputs = 2
|
num_inputs = 2
|
||||||
num_outputs = 1
|
num_outputs = 1
|
||||||
init_maximum_nodes = 20
|
init_maximum_nodes = 50
|
||||||
init_maximum_connections = 20
|
init_maximum_connections = 50
|
||||||
init_maximum_species = 10
|
init_maximum_species = 10
|
||||||
expands_coe = 2.0
|
expand_coe = 2.0
|
||||||
forward_way = "pop"
|
forward_way = "pop"
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
|
|
||||||
@@ -12,7 +12,7 @@ batch_size = 4
|
|||||||
fitness_threshold = 100000
|
fitness_threshold = 100000
|
||||||
generation_limit = 100
|
generation_limit = 100
|
||||||
fitness_criterion = "max"
|
fitness_criterion = "max"
|
||||||
pop_size = 150
|
pop_size = 15000
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
@@ -26,7 +26,7 @@ node_delete_prob = 0
|
|||||||
[species]
|
[species]
|
||||||
compatibility_threshold = 3.0
|
compatibility_threshold = 3.0
|
||||||
species_elitism = 2
|
species_elitism = 2
|
||||||
species_max_stagnation = 15
|
max_stagnation = 15
|
||||||
genome_elitism = 2
|
genome_elitism = 2
|
||||||
survival_threshold = 0.2
|
survival_threshold = 0.2
|
||||||
min_species_size = 1
|
min_species_size = 1
|
||||||
|
|||||||
@@ -1,20 +1,16 @@
|
|||||||
from functools import partial
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import jax
|
|
||||||
from jax import jit
|
from jax import jit
|
||||||
|
|
||||||
from configs import Configer
|
from configs import Configer
|
||||||
from neat.pipeline import Pipeline
|
from neat.pipeline import Pipeline
|
||||||
from neat.function_factory import FunctionFactory
|
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
config = Configer.load_config("xor.ini")
|
config = Configer.load_config("xor.ini")
|
||||||
function_factory = FunctionFactory(config)
|
print(config)
|
||||||
pipeline = Pipeline(config, function_factory)
|
pipeline = Pipeline(config)
|
||||||
forward_func = pipeline.ask()
|
forward_func = pipeline.ask()
|
||||||
# inputs = np.tile(xor_inputs, (150, 1, 1))
|
# inputs = np.tile(xor_inputs, (150, 1, 1))
|
||||||
outputs = forward_func(xor_inputs)
|
outputs = forward_func(xor_inputs)
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
forward_way = "common"
|
forward_way = "common"
|
||||||
|
|
||||||
[population]
|
[population]
|
||||||
fitness_threshold = -1e-2
|
fitness_threshold = 3.9999
|
||||||
@@ -1,45 +1,26 @@
|
|||||||
from typing import Callable, List
|
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from configs import Configer
|
from configs import Configer
|
||||||
from neat import Pipeline
|
from neat.pipeline import Pipeline
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
xor_outputs = np.array([[0], [1], [1], [0]])
|
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
def evaluate(forward_func: Callable) -> List[float]:
|
def evaluate(forward_func):
|
||||||
"""
|
"""
|
||||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
outs = forward_func(xor_inputs)
|
outs = forward_func(xor_inputs)
|
||||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||||
# print(fitnesses)
|
return np.array(fitnesses) # returns a list
|
||||||
return fitnesses.tolist() # returns a list
|
|
||||||
|
|
||||||
|
|
||||||
# @using_cprofile
|
|
||||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
|
|
||||||
def main():
|
def main():
|
||||||
tic = time.time()
|
|
||||||
config = Configer.load_config("xor.ini")
|
config = Configer.load_config("xor.ini")
|
||||||
print(config)
|
pipeline = Pipeline(config, seed=6)
|
||||||
function_factory = FunctionFactory(config)
|
|
||||||
pipeline = Pipeline(config, function_factory, seed=6)
|
|
||||||
nodes, cons = pipeline.auto_run(evaluate)
|
nodes, cons = pipeline.auto_run(evaluate)
|
||||||
print(nodes, cons)
|
|
||||||
total_time = time.time() - tic
|
|
||||||
compile_time = pipeline.function_factory.compile_time
|
|
||||||
total_it = pipeline.generation
|
|
||||||
mean_time_per_it = (total_time - compile_time) / total_it
|
|
||||||
evaluate_time = pipeline.evaluate_time
|
|
||||||
print(
|
|
||||||
f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
|
|
||||||
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||||
|
"""
|
||||||
@@ -1,10 +1,8 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from jax import jit, vmap
|
from jax import jit, vmap
|
||||||
|
|
||||||
from .genome.forward import create_forward
|
from .genome import create_forward, topological_sort, unflatten_connections
|
||||||
from .genome.utils import unflatten_connections
|
from .operations import create_next_generation_then_speciate
|
||||||
from .genome.graph import topological_sort
|
|
||||||
|
|
||||||
|
|
||||||
def hash_symbols(symbols):
|
def hash_symbols(symbols):
|
||||||
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
|
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
|
||||||
@@ -15,8 +13,10 @@ class FunctionFactory:
|
|||||||
Creates and compiles functions used in the NEAT pipeline.
|
Creates and compiles functions used in the NEAT pipeline.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config, jit_config):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.jit_config = jit_config
|
||||||
|
|
||||||
self.func_dict = {}
|
self.func_dict = {}
|
||||||
self.function_info = {}
|
self.function_info = {}
|
||||||
|
|
||||||
@@ -78,6 +78,24 @@ class FunctionFactory:
|
|||||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
|
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
|
||||||
|
'create_next_generation_then_speciate': {
|
||||||
|
'func': create_next_generation_then_speciate,
|
||||||
|
'lowers': [
|
||||||
|
{'shape': (2, ), 'type': np.uint32}, # rand_key
|
||||||
|
{'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes
|
||||||
|
{'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons
|
||||||
|
{'shape': ('P', ), 'type': np.int32}, # winner
|
||||||
|
{'shape': ('P', ), 'type': np.int32}, # loser
|
||||||
|
{'shape': ('P', ), 'type': bool}, # elite_mask
|
||||||
|
{'shape': ('P',), 'type': np.int32}, # new_node_keys
|
||||||
|
{'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes
|
||||||
|
{'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons
|
||||||
|
{'shape': ('S', ), 'type': np.int32}, # species_keys
|
||||||
|
{'shape': (), 'type': np.int32}, # new_species_key_start
|
||||||
|
"jit_config"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,12 +112,19 @@ class FunctionFactory:
|
|||||||
# prepare lower operands
|
# prepare lower operands
|
||||||
lowers_operands = []
|
lowers_operands = []
|
||||||
for lower in self.function_info[name]['lowers']:
|
for lower in self.function_info[name]['lowers']:
|
||||||
shape = list(lower['shape'])
|
if isinstance(lower, dict):
|
||||||
for i, s in enumerate(shape):
|
shape = list(lower['shape'])
|
||||||
if s in symbols:
|
for i, s in enumerate(shape):
|
||||||
shape[i] = symbols[s]
|
if s in symbols:
|
||||||
assert isinstance(shape[i], int)
|
shape[i] = symbols[s]
|
||||||
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
|
assert isinstance(shape[i], int)
|
||||||
|
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
|
||||||
|
|
||||||
|
elif lower == "jit_config":
|
||||||
|
lowers_operands.append(self.jit_config)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid lower operand")
|
||||||
|
|
||||||
# compile
|
# compile
|
||||||
compiled_func = jit(func).lower(*lowers_operands).compile()
|
compiled_func = jit(func).lower(*lowers_operands).compile()
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from .mutate import mutate
|
||||||
|
from .distance import distance
|
||||||
|
from .crossover import crossover
|
||||||
|
from .forward import create_forward
|
||||||
|
from .graph import topological_sort, check_cycles
|
||||||
|
from .utils import unflatten_connections
|
||||||
|
from .genome import initialize_genomes, expand, expand_single
|
||||||
@@ -1,34 +1,27 @@
|
|||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
|
||||||
from jax import jit
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def sum_agg(z):
|
def sum_agg(z):
|
||||||
z = jnp.where(jnp.isnan(z), 0, z)
|
z = jnp.where(jnp.isnan(z), 0, z)
|
||||||
return jnp.sum(z, axis=0)
|
return jnp.sum(z, axis=0)
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def product_agg(z):
|
def product_agg(z):
|
||||||
z = jnp.where(jnp.isnan(z), 1, z)
|
z = jnp.where(jnp.isnan(z), 1, z)
|
||||||
return jnp.prod(z, axis=0)
|
return jnp.prod(z, axis=0)
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def max_agg(z):
|
def max_agg(z):
|
||||||
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
||||||
return jnp.max(z, axis=0)
|
return jnp.max(z, axis=0)
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def min_agg(z):
|
def min_agg(z):
|
||||||
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||||
return jnp.min(z, axis=0)
|
return jnp.min(z, axis=0)
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def maxabs_agg(z):
|
def maxabs_agg(z):
|
||||||
z = jnp.where(jnp.isnan(z), 0, z)
|
z = jnp.where(jnp.isnan(z), 0, z)
|
||||||
abs_z = jnp.abs(z)
|
abs_z = jnp.abs(z)
|
||||||
@@ -36,7 +29,6 @@ def maxabs_agg(z):
|
|||||||
return z[max_abs_index]
|
return z[max_abs_index]
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def median_agg(z):
|
def median_agg(z):
|
||||||
non_nan_mask = ~jnp.isnan(z)
|
non_nan_mask = ~jnp.isnan(z)
|
||||||
n = jnp.sum(non_nan_mask, axis=0)
|
n = jnp.sum(non_nan_mask, axis=0)
|
||||||
@@ -49,7 +41,6 @@ def median_agg(z):
|
|||||||
return median
|
return median
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def mean_agg(z):
|
def mean_agg(z):
|
||||||
non_zero_mask = ~jnp.isnan(z)
|
non_zero_mask = ~jnp.isnan(z)
|
||||||
valid_values_sum = sum_agg(z)
|
valid_values_sum = sum_agg(z)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import jax
|
|||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
from jax import jit, Array
|
from jax import jit, Array
|
||||||
|
|
||||||
from .utils import fetch_random, fetch_first, I_INT
|
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||||
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
|
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
|
||||||
from .graph import check_cycles
|
from .graph import check_cycles
|
||||||
|
|
||||||
@@ -273,7 +273,8 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config
|
|||||||
|
|
||||||
is_already_exist = con_idx != I_INT
|
is_already_exist = con_idx != I_INT
|
||||||
|
|
||||||
is_cycle = check_cycles(nodes, cons, from_idx, to_idx)
|
u_cons = unflatten_connections(nodes, cons)
|
||||||
|
is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx)
|
||||||
|
|
||||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||||
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
|
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
from functools import partial
|
import numpy as np
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import numpy as jnp, Array
|
from jax import numpy as jnp, Array
|
||||||
from jax import jit, vmap
|
from jax import jit, vmap
|
||||||
|
|
||||||
I_INT = jnp.iinfo(jnp.int32).max # infinite int
|
I_INT = np.iinfo(jnp.int32).max # infinite int
|
||||||
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
|
EMPTY_NODE = np.full((1, 5), jnp.nan)
|
||||||
EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
EMPTY_CON = np.full((1, 4), jnp.nan)
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
@@ -58,8 +57,3 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
|||||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||||
return fetch_first(mask, default)
|
return fetch_first(mask, default)
|
||||||
|
|
||||||
@jit
|
|
||||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
|
||||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
|
||||||
min_idx = jnp.argmin(masked_arr)
|
|
||||||
return min_idx
|
|
||||||
171
neat/operations.py
Normal file
171
neat/operations.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""
|
||||||
|
contains operations on the population: creating the next generation and population speciation.
|
||||||
|
"""
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import jit, vmap
|
||||||
|
|
||||||
|
from jax import Array
|
||||||
|
|
||||||
|
from .genome import distance, mutate, crossover
|
||||||
|
from .genome.utils import I_INT, fetch_first
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys,
|
||||||
|
center_nodes, center_cons, species_keys, new_species_key_start,
|
||||||
|
jit_config):
|
||||||
|
# create next generation
|
||||||
|
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask,
|
||||||
|
new_node_keys, jit_config)
|
||||||
|
|
||||||
|
# speciate
|
||||||
|
idx2specie, spe_center_nodes, spe_center_cons, species_keys = \
|
||||||
|
speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config)
|
||||||
|
|
||||||
|
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config):
|
||||||
|
# prepare random keys
|
||||||
|
pop_size = pop_nodes.shape[0]
|
||||||
|
k1, k2 = jax.random.split(rand_key, 2)
|
||||||
|
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||||
|
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||||
|
|
||||||
|
# batch crossover
|
||||||
|
wpn, wpc = pop_nodes[winner], pop_cons[winner] # winner pop nodes, winner pop connections
|
||||||
|
lpn, lpc = pop_nodes[loser], pop_cons[loser] # loser pop nodes, loser pop connections
|
||||||
|
npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||||
|
|
||||||
|
# batch mutation
|
||||||
|
mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None))
|
||||||
|
m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config) # mutate_new_pop_nodes
|
||||||
|
|
||||||
|
# elitism don't mutate
|
||||||
|
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
||||||
|
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
||||||
|
|
||||||
|
return pop_nodes, pop_cons
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config):
|
||||||
|
"""
|
||||||
|
args:
|
||||||
|
pop_nodes: (pop_size, N, 5)
|
||||||
|
pop_cons: (pop_size, C, 4)
|
||||||
|
spe_center_nodes: (species_size, N, 5)
|
||||||
|
spe_center_cons: (species_size, C, 4)
|
||||||
|
"""
|
||||||
|
pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0]
|
||||||
|
|
||||||
|
# prepare distance functions
|
||||||
|
o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population
|
||||||
|
s2p_distance_func = vmap(
|
||||||
|
o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population
|
||||||
|
)
|
||||||
|
|
||||||
|
# idx to specie key
|
||||||
|
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
|
||||||
|
|
||||||
|
# part 1: find new centers
|
||||||
|
# the distance between each species' center and each genome in population
|
||||||
|
s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)
|
||||||
|
|
||||||
|
def find_new_centers(i, carry):
|
||||||
|
i2s, cn, cc = carry
|
||||||
|
# find new center
|
||||||
|
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
|
||||||
|
|
||||||
|
# check species[i] exist or not
|
||||||
|
# if not exist, set idx and i to I_INT, jax will not do array value assignment
|
||||||
|
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
|
||||||
|
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
|
||||||
|
|
||||||
|
i2s = i2s.at[idx].set(species_keys[i])
|
||||||
|
cn = cn.at[i].set(pop_nodes[idx])
|
||||||
|
cc = cc.at[i].set(pop_cons[idx])
|
||||||
|
return i2s, cn, cc
|
||||||
|
|
||||||
|
idx2specie, center_nodes, center_cons = \
|
||||||
|
jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons))
|
||||||
|
|
||||||
|
# part 2: assign members to each species
|
||||||
|
def cond_func(carry):
|
||||||
|
i, i2s, cn, cc, sk, ck = carry # sk is short for species_keys, ck is short for current key
|
||||||
|
not_all_assigned = ~jnp.all(i2s != I_INT)
|
||||||
|
not_reach_species_upper_bounds = i < species_size
|
||||||
|
return not_all_assigned & not_reach_species_upper_bounds
|
||||||
|
|
||||||
|
def body_func(carry):
|
||||||
|
i, i2s, cn, cc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||||
|
|
||||||
|
i2s, scn, scc, sk, ck = jax.lax.cond(
|
||||||
|
sk[i] == I_INT, # whether the current species is existing or not
|
||||||
|
create_new_specie, # if not existing, create a new specie
|
||||||
|
update_exist_specie, # if existing, update the specie
|
||||||
|
(i, i2s, cn, cc, sk, ck)
|
||||||
|
)
|
||||||
|
|
||||||
|
return i + 1, i2s, scn, scc, sk, ck
|
||||||
|
|
||||||
|
def create_new_specie(carry):
|
||||||
|
i, i2s, cn, cc, sk, ck = carry
|
||||||
|
|
||||||
|
# pick the first one who has not been assigned to any species
|
||||||
|
idx = fetch_first(i2s == I_INT)
|
||||||
|
|
||||||
|
# assign it to the new species
|
||||||
|
sk = sk.at[i].set(ck)
|
||||||
|
i2s = i2s.at[idx].set(ck)
|
||||||
|
|
||||||
|
# update center genomes
|
||||||
|
cn = cn.at[i].set(pop_nodes[idx])
|
||||||
|
cc = cc.at[i].set(pop_cons[idx])
|
||||||
|
|
||||||
|
i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
|
||||||
|
return i2s, cn, cc, sk, ck + 1 # change to next new speciate key
|
||||||
|
|
||||||
|
def update_exist_specie(carry):
|
||||||
|
i, i2s, cn, cc, sk, ck = carry
|
||||||
|
|
||||||
|
i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
|
||||||
|
|
||||||
|
return i2s, cn, cc, sk, ck
|
||||||
|
|
||||||
|
def speciate_by_threshold(carry):
|
||||||
|
i, i2s, cn, cc, sk = carry
|
||||||
|
|
||||||
|
# distance between such center genome and ppo genomes
|
||||||
|
o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
|
||||||
|
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
|
||||||
|
|
||||||
|
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
||||||
|
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
|
||||||
|
return i2s
|
||||||
|
|
||||||
|
current_new_key = new_species_key_start
|
||||||
|
|
||||||
|
# update idx2specie
|
||||||
|
_, idx2specie, center_nodes, center_cons, species_keys, _ = jax.lax.while_loop(
|
||||||
|
cond_func,
|
||||||
|
body_func,
|
||||||
|
(0, idx2specie, center_nodes, center_cons, species_keys, current_new_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||||
|
# this condition seems to be only happened when the number of species is reached species upper bounds
|
||||||
|
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
|
||||||
|
|
||||||
|
return idx2specie, center_nodes, center_cons, species_keys
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||||
|
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||||
|
min_idx = jnp.argmin(masked_arr)
|
||||||
|
return min_idx
|
||||||
118
neat/pipeline.py
118
neat/pipeline.py
@@ -1,11 +1,13 @@
|
|||||||
from functools import partial
|
import time
|
||||||
|
from typing import Union, Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import jax
|
import jax
|
||||||
|
|
||||||
from configs.configer import Configer
|
from configs import Configer
|
||||||
from .genome.genome import initialize_genomes
|
from .genome import initialize_genomes, expand, expand_single
|
||||||
from .function_factory import FunctionFactory
|
from .function_factory import FunctionFactory
|
||||||
|
from .species import SpeciesController
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
@@ -19,7 +21,7 @@ class Pipeline:
|
|||||||
|
|
||||||
self.config = config # global config
|
self.config = config # global config
|
||||||
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||||
self.function_factory = function_factory or FunctionFactory(self.config)
|
self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config)
|
||||||
|
|
||||||
self.symbols = {
|
self.symbols = {
|
||||||
'P': self.config['pop_size'],
|
'P': self.config['pop_size'],
|
||||||
@@ -31,8 +33,16 @@ class Pipeline:
|
|||||||
self.generation = 0
|
self.generation = 0
|
||||||
self.best_genome = None
|
self.best_genome = None
|
||||||
|
|
||||||
|
self.species_controller = SpeciesController(self.config)
|
||||||
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
|
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
|
||||||
|
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
|
||||||
|
|
||||||
|
self.best_fitness = float('-inf')
|
||||||
|
self.best_genome = None
|
||||||
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
|
self.evaluate_time = 0
|
||||||
|
print(self.config)
|
||||||
|
|
||||||
def ask(self):
|
def ask(self):
|
||||||
"""
|
"""
|
||||||
@@ -74,5 +84,105 @@ class Pipeline:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def tell(self, fitnesses):
|
||||||
|
self.generation += 1
|
||||||
|
|
||||||
|
winner, loser, elite_mask, center_nodes, center_cons, species_keys, species_key_start = \
|
||||||
|
self.species_controller.ask(fitnesses, self.generation, self.symbols)
|
||||||
|
|
||||||
|
# node keys to be used in the mutation process
|
||||||
|
new_node_keys = np.arange(self.generation * self.config['pop_size'],
|
||||||
|
self.generation * self.config['pop_size'] + self.config['pop_size'])
|
||||||
|
|
||||||
|
# create the next generation and then speciate the population
|
||||||
|
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||||
|
self.get_func('create_next_generation_then_speciate') \
|
||||||
|
(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes,
|
||||||
|
center_cons, species_keys, species_key_start, self.jit_config)
|
||||||
|
|
||||||
|
# carry data to cpu
|
||||||
|
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||||
|
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys])
|
||||||
|
|
||||||
|
self.species_controller.tell(idx2specie, center_nodes, center_cons, species_keys, self.generation)
|
||||||
|
|
||||||
|
# expand the population if needed
|
||||||
|
self.expand()
|
||||||
|
|
||||||
|
# update randkey
|
||||||
|
self.randkey = jax.random.split(self.randkey)[0]
|
||||||
|
|
||||||
|
def expand(self):
|
||||||
|
"""
|
||||||
|
Expand the population if needed.
|
||||||
|
when the maximum node number >= N or the maximum connection number of >= C
|
||||||
|
the population will expand
|
||||||
|
"""
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
pop_node_keys = self.pop_nodes[:, :, 0]
|
||||||
|
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
|
||||||
|
max_node_size = np.max(pop_node_sizes)
|
||||||
|
if max_node_size >= self.symbols['N']:
|
||||||
|
self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
|
||||||
|
print(f"node expand to {self.symbols['N']}!")
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
pop_con_keys = self.pop_cons[:, :, 0]
|
||||||
|
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||||
|
max_con_size = np.max(pop_node_sizes)
|
||||||
|
if max_con_size >= self.symbols['C']:
|
||||||
|
self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
|
||||||
|
print(f"connection expand to {self.symbols['C']}!")
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
if changed:
|
||||||
|
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C'])
|
||||||
|
# don't forget to expand representation genome in species
|
||||||
|
for s in self.species_controller.species.values():
|
||||||
|
s.representative = expand_single(*s.representative, self.symbols['N'], self.symbols['C'])
|
||||||
|
|
||||||
def get_func(self, name):
|
def get_func(self, name):
|
||||||
return self.function_factory.get(name, self.symbols)
|
return self.function_factory.get(name, self.symbols)
|
||||||
|
|
||||||
|
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||||
|
for _ in range(self.config['generation_limit']):
|
||||||
|
forward_func = self.ask()
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
fitnesses = fitness_func(forward_func)
|
||||||
|
self.evaluate_time += time.time() - tic
|
||||||
|
|
||||||
|
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||||
|
|
||||||
|
if analysis is not None:
|
||||||
|
if analysis == "default":
|
||||||
|
self.default_analysis(fitnesses)
|
||||||
|
else:
|
||||||
|
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
||||||
|
analysis(fitnesses)
|
||||||
|
|
||||||
|
if max(fitnesses) >= self.config['fitness_threshold']:
|
||||||
|
print("Fitness limit reached!")
|
||||||
|
return self.best_genome
|
||||||
|
|
||||||
|
self.tell(fitnesses)
|
||||||
|
print("Generation limit reached!")
|
||||||
|
return self.best_genome
|
||||||
|
|
||||||
|
def default_analysis(self, fitnesses):
|
||||||
|
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||||
|
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
|
||||||
|
|
||||||
|
new_timestamp = time.time()
|
||||||
|
cost_time = new_timestamp - self.generation_timestamp
|
||||||
|
self.generation_timestamp = new_timestamp
|
||||||
|
|
||||||
|
max_idx = np.argmax(fitnesses)
|
||||||
|
if fitnesses[max_idx] > self.best_fitness:
|
||||||
|
self.best_fitness = fitnesses[max_idx]
|
||||||
|
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||||
|
|
||||||
|
print(f"Generation: {self.generation}",
|
||||||
|
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||||
|
|||||||
@@ -1,168 +0,0 @@
|
|||||||
from functools import partial
|
|
||||||
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
from jax import jit, vmap
|
|
||||||
|
|
||||||
from jax import Array
|
|
||||||
|
|
||||||
from .genome import distance, mutate, crossover
|
|
||||||
from .genome.utils import I_INT, fetch_first, argmin_with_mask
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
|
|
||||||
new_node_keys,
|
|
||||||
pre_spe_center_nodes, pre_spe_center_cons, species_keys, new_species_key_start,
|
|
||||||
species_kwargs, mutate_kwargs):
|
|
||||||
# create next generation
|
|
||||||
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask,
|
|
||||||
new_node_keys, **mutate_kwargs)
|
|
||||||
|
|
||||||
# speciate
|
|
||||||
idx2specie, spe_center_nodes, spe_center_cons, species_keys = speciate(pop_nodes, pop_cons, pre_spe_center_nodes,
|
|
||||||
pre_spe_center_cons, species_keys,
|
|
||||||
new_species_key_start, **species_kwargs)
|
|
||||||
|
|
||||||
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def speciate(pop_nodes: Array, pop_cons: Array, spe_center_nodes: Array, spe_center_cons: Array,
|
|
||||||
species_keys, new_species_key_start,
|
|
||||||
disjoint_coe: float = 1., compatibility_coe: float = 0.5, compatibility_threshold=3.0
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
args:
|
|
||||||
pop_nodes: (pop_size, N, 5)
|
|
||||||
pop_cons: (pop_size, C, 4)
|
|
||||||
spe_center_nodes: (species_size, N, 5)
|
|
||||||
spe_center_cons: (species_size, C, 4)
|
|
||||||
"""
|
|
||||||
pop_size, species_size = pop_nodes.shape[0], spe_center_nodes.shape[0]
|
|
||||||
|
|
||||||
# prepare distance functions
|
|
||||||
distance_with_args = partial(distance, disjoint_coe=disjoint_coe, compatibility_coe=compatibility_coe)
|
|
||||||
o2p_distance_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
|
||||||
s2p_distance_func = vmap(
|
|
||||||
o2p_distance_func, in_axes=(0, 0, None, None)
|
|
||||||
)
|
|
||||||
|
|
||||||
# idx to specie key
|
|
||||||
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
|
|
||||||
|
|
||||||
# part 1: find new centers
|
|
||||||
# the distance between each species' center and each genome in population
|
|
||||||
s2p_distance = s2p_distance_func(spe_center_nodes, spe_center_cons, pop_nodes, pop_cons)
|
|
||||||
|
|
||||||
def find_new_centers(i, carry):
|
|
||||||
i2s, scn, scc = carry
|
|
||||||
# find new center
|
|
||||||
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
|
|
||||||
|
|
||||||
# check species[i] exist or not
|
|
||||||
# if not exist, set idx and i to I_INT, jax will not do array value assignment
|
|
||||||
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
|
|
||||||
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
|
|
||||||
|
|
||||||
i2s = i2s.at[idx].set(species_keys[i])
|
|
||||||
scn = scn.at[i].set(pop_nodes[idx])
|
|
||||||
scc = scc.at[i].set(pop_cons[idx])
|
|
||||||
return i2s, scn, scc
|
|
||||||
|
|
||||||
idx2specie, spe_center_nodes, spe_center_cons = jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, spe_center_nodes, spe_center_cons))
|
|
||||||
|
|
||||||
def continue_execute_while(carry):
|
|
||||||
i, i2s, scn, scc, sk, ck = carry # sk is short for species_keys, ck is short for current key
|
|
||||||
not_all_assigned = ~jnp.all(i2s != I_INT)
|
|
||||||
not_reach_species_upper_bounds = i < species_size
|
|
||||||
return not_all_assigned & not_reach_species_upper_bounds
|
|
||||||
|
|
||||||
def deal_with_each_center_genome(carry):
|
|
||||||
i, i2s, scn, scc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
|
||||||
center_nodes, center_cons = spe_center_nodes[i], spe_center_cons[i]
|
|
||||||
|
|
||||||
i2s, scn, scc, sk, ck = jax.lax.cond(
|
|
||||||
jnp.all(jnp.isnan(center_nodes)), # whether the center genome is valid
|
|
||||||
create_new_specie, # if not valid, create a new specie
|
|
||||||
update_exist_specie, # if valid, update the specie
|
|
||||||
(i, i2s, scn, scc, sk, ck)
|
|
||||||
)
|
|
||||||
|
|
||||||
return i + 1, i2s, scn, scc, sk, ck
|
|
||||||
|
|
||||||
def create_new_specie(carry):
|
|
||||||
i, i2s, scn, scc, sk, ck = carry
|
|
||||||
# pick the first one who has not been assigned to any species
|
|
||||||
idx = fetch_first(i2s == I_INT)
|
|
||||||
|
|
||||||
# assign it to new specie
|
|
||||||
sk = sk.at[i].set(ck)
|
|
||||||
i2s = i2s.at[idx].set(ck)
|
|
||||||
|
|
||||||
# update center genomes
|
|
||||||
scn = scn.at[i].set(pop_nodes[idx])
|
|
||||||
scc = scc.at[i].set(pop_cons[idx])
|
|
||||||
|
|
||||||
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
|
|
||||||
return i2s, scn, scc, sk, ck + 1 # change to next new speciate key
|
|
||||||
|
|
||||||
def update_exist_specie(carry):
|
|
||||||
i, i2s, scn, scc, sk, ck = carry
|
|
||||||
|
|
||||||
i2s, scn, scc = speciate_by_threshold((i, i2s, scn, scc, sk))
|
|
||||||
return i2s, scn, scc, sk, ck
|
|
||||||
|
|
||||||
def speciate_by_threshold(carry):
|
|
||||||
i, i2s, scn, scc, sk = carry
|
|
||||||
# distance between such center genome and ppo genomes
|
|
||||||
o2p_distance = o2p_distance_func(scn[i], scc[i], pop_nodes, pop_cons)
|
|
||||||
close_enough_mask = o2p_distance < compatibility_threshold
|
|
||||||
|
|
||||||
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
|
|
||||||
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
|
|
||||||
return i2s, scn, scc
|
|
||||||
|
|
||||||
current_new_key = new_species_key_start
|
|
||||||
|
|
||||||
# update idx2specie
|
|
||||||
_, idx2specie, spe_center_nodes, spe_center_cons, species_keys, new_species_key_start = jax.lax.while_loop(
|
|
||||||
continue_execute_while,
|
|
||||||
deal_with_each_center_genome,
|
|
||||||
(0, idx2specie, spe_center_nodes, spe_center_cons, species_keys, current_new_key)
|
|
||||||
)
|
|
||||||
|
|
||||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
|
||||||
# this condition seems to be only happened when the number of species is reached species upper bounds
|
|
||||||
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
|
|
||||||
|
|
||||||
return idx2specie, spe_center_nodes, spe_center_cons, species_keys
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner_part, loser_part, elite_mask, new_node_keys,
|
|
||||||
**mutate_kwargs):
|
|
||||||
# prepare functions
|
|
||||||
batch_crossover = vmap(crossover)
|
|
||||||
mutate_with_args = vmap(partial(mutate, **mutate_kwargs))
|
|
||||||
|
|
||||||
pop_size = pop_nodes.shape[0]
|
|
||||||
k1, k2 = jax.random.split(rand_key, 2)
|
|
||||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
|
||||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
|
||||||
|
|
||||||
# batch crossover
|
|
||||||
wpn = pop_nodes[winner_part] # winner pop nodes
|
|
||||||
wpc = pop_cons[winner_part] # winner pop connections
|
|
||||||
lpn = pop_nodes[loser_part] # loser pop nodes
|
|
||||||
lpc = pop_cons[loser_part] # loser pop connections
|
|
||||||
|
|
||||||
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
|
||||||
|
|
||||||
m_npn, m_npc = mutate_with_args(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
|
||||||
|
|
||||||
# elitism don't mutate
|
|
||||||
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
|
||||||
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
|
||||||
|
|
||||||
return pop_nodes, pop_cons
|
|
||||||
@@ -1,7 +1,15 @@
|
|||||||
from typing import List, Tuple, Dict, Union, Callable
|
"""
|
||||||
|
Species Controller in NEAT.
|
||||||
|
The code are modified from neat-python.
|
||||||
|
See
|
||||||
|
https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation
|
||||||
|
https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction
|
||||||
|
https://neat-python.readthedocs.io/en/latest/module_summaries.html#species
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Tuple, Dict
|
||||||
from itertools import count
|
from itertools import count
|
||||||
|
|
||||||
import jax
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
@@ -37,14 +45,13 @@ class SpeciesController:
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.species_elitism = self.config.neat.species.species_elitism
|
self.species_elitism = self.config['species_elitism']
|
||||||
self.pop_size = self.config.neat.population.pop_size
|
self.pop_size = self.config['pop_size']
|
||||||
self.max_stagnation = self.config.neat.species.max_stagnation
|
self.max_stagnation = self.config['max_stagnation']
|
||||||
self.min_species_size = self.config.neat.species.min_species_size
|
self.min_species_size = self.config['min_species_size']
|
||||||
self.genome_elitism = self.config.neat.species.genome_elitism
|
self.genome_elitism = self.config['genome_elitism']
|
||||||
self.survival_threshold = self.config.neat.species.survival_threshold
|
self.survival_threshold = self.config['survival_threshold']
|
||||||
|
|
||||||
self.species_idxer = count(0)
|
|
||||||
self.species: Dict[int, Species] = {} # species_id -> species
|
self.species: Dict[int, Species] = {} # species_id -> species
|
||||||
|
|
||||||
def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray):
|
def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray):
|
||||||
@@ -55,9 +62,10 @@ class SpeciesController:
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
pop_size = pop_nodes.shape[0]
|
pop_size = pop_nodes.shape[0]
|
||||||
species_id = next(self.species_idxer)
|
species_id = 0 # the first species
|
||||||
s = Species(species_id, 0)
|
s = Species(species_id, 0)
|
||||||
members = np.array(list(range(pop_size)))
|
members = np.array(list(range(pop_size)))
|
||||||
|
|
||||||
s.update((pop_nodes[0], pop_connections[0]), members)
|
s.update((pop_nodes[0], pop_connections[0]), members)
|
||||||
self.species[species_id] = s
|
self.species[species_id] = s
|
||||||
|
|
||||||
@@ -68,16 +76,14 @@ class SpeciesController:
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
for sid, s in self.species.items():
|
for sid, s in self.species.items():
|
||||||
# TODO: here use mean to measure the fitness of a species, but it may be other functions
|
|
||||||
s.member_fitnesses = s.get_fitnesses(fitnesses)
|
s.member_fitnesses = s.get_fitnesses(fitnesses)
|
||||||
# s.fitness = np.mean(s.member_fitnesses)
|
# use the max score to represent the fitness of the species
|
||||||
s.fitness = np.max(s.member_fitnesses)
|
s.fitness = np.max(s.member_fitnesses)
|
||||||
s.fitness_history.append(s.fitness)
|
s.fitness_history.append(s.fitness)
|
||||||
s.adjusted_fitness = None
|
s.adjusted_fitness = None
|
||||||
|
|
||||||
def __stagnation(self, generation):
|
def __stagnation(self, generation):
|
||||||
"""
|
"""
|
||||||
code modified from neat-python!
|
|
||||||
:param generation:
|
:param generation:
|
||||||
:return: whether the species is stagnated
|
:return: whether the species is stagnated
|
||||||
"""
|
"""
|
||||||
@@ -88,7 +94,7 @@ class SpeciesController:
|
|||||||
else:
|
else:
|
||||||
prev_fitness = float('-inf')
|
prev_fitness = float('-inf')
|
||||||
|
|
||||||
if prev_fitness is None or s.fitness > prev_fitness:
|
if s.fitness > prev_fitness:
|
||||||
s.last_improved = generation
|
s.last_improved = generation
|
||||||
|
|
||||||
species_data.append((sid, s))
|
species_data.append((sid, s))
|
||||||
@@ -110,7 +116,6 @@ class SpeciesController:
|
|||||||
|
|
||||||
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
|
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
|
||||||
"""
|
"""
|
||||||
code modified from neat-python!
|
|
||||||
:param fitnesses:
|
:param fitnesses:
|
||||||
:param generation:
|
:param generation:
|
||||||
:return: crossover_pair for next generation.
|
:return: crossover_pair for next generation.
|
||||||
@@ -136,6 +141,8 @@ class SpeciesController:
|
|||||||
# No species left.
|
# No species left.
|
||||||
assert remaining_species
|
assert remaining_species
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Too complex!
|
||||||
# Compute each species' member size in the next generation.
|
# Compute each species' member size in the next generation.
|
||||||
|
|
||||||
# Do not allow the fitness range to be zero, as we divide by it below.
|
# Do not allow the fitness range to be zero, as we divide by it below.
|
||||||
@@ -185,6 +192,7 @@ class SpeciesController:
|
|||||||
# only use good genomes to crossover
|
# only use good genomes to crossover
|
||||||
sorted_members = sorted_members[:repro_cutoff]
|
sorted_members = sorted_members[:repro_cutoff]
|
||||||
|
|
||||||
|
# TODO: Genome with higher fitness should be more likely to be selected?
|
||||||
list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
|
list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
|
||||||
part1.extend(sorted_members[list_idx1])
|
part1.extend(sorted_members[list_idx1])
|
||||||
part2.extend(sorted_members[list_idx2])
|
part2.extend(sorted_members[list_idx2])
|
||||||
@@ -197,32 +205,37 @@ class SpeciesController:
|
|||||||
|
|
||||||
return winner_part, loser_part, np.array(elite_mask)
|
return winner_part, loser_part, np.array(elite_mask)
|
||||||
|
|
||||||
def tell(self, idx2specie, spe_center_nodes, spe_center_cons, species_keys, generation):
|
def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation):
|
||||||
for idx, key in enumerate(species_keys):
|
for idx, key in enumerate(species_keys):
|
||||||
if key == I_INT:
|
if key == I_INT:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
members = np.where(idx2specie == key)[0]
|
members = np.where(idx2specie == key)[0]
|
||||||
assert len(members) > 0
|
assert len(members) > 0
|
||||||
|
|
||||||
if key not in self.species:
|
if key not in self.species:
|
||||||
|
# the new specie created in this generation
|
||||||
s = Species(key, generation)
|
s = Species(key, generation)
|
||||||
self.species[key] = s
|
self.species[key] = s
|
||||||
|
|
||||||
self.species[key].update((spe_center_nodes[idx], spe_center_cons[idx]), members)
|
self.species[key].update((center_nodes[idx], center_cons[idx]), members)
|
||||||
|
|
||||||
def ask(self, fitnesses, generation, S, N, C):
|
def ask(self, fitnesses, generation, symbols):
|
||||||
self.__update_species_fitnesses(fitnesses)
|
self.__update_species_fitnesses(fitnesses)
|
||||||
winner_part, loser_part, elite_mask = self.__reproduce(fitnesses, generation)
|
|
||||||
pre_spe_center_nodes = np.full((S, N, 5), np.nan)
|
winner, loser, elite_mask = self.__reproduce(fitnesses, generation)
|
||||||
pre_spe_center_cons = np.full((S, C, 4), np.nan)
|
|
||||||
species_keys = np.full((S,), I_INT)
|
center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan)
|
||||||
|
center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan)
|
||||||
|
species_keys = np.full((symbols['S'], ), I_INT)
|
||||||
|
|
||||||
for idx, (key, specie) in enumerate(self.species.items()):
|
for idx, (key, specie) in enumerate(self.species.items()):
|
||||||
pre_spe_center_nodes[idx] = specie.representative[0]
|
center_nodes[idx], center_cons[idx] = specie.representative
|
||||||
pre_spe_center_cons[idx] = specie.representative[1]
|
|
||||||
species_keys[idx] = key
|
species_keys[idx] = key
|
||||||
|
|
||||||
next_new_specie_key = max(self.species.keys()) + 1
|
next_new_specie_key = max(self.species.keys()) + 1
|
||||||
return winner_part, loser_part, elite_mask, pre_spe_center_nodes, \
|
|
||||||
pre_spe_center_cons, species_keys, next_new_specie_key
|
return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key
|
||||||
|
|
||||||
|
|
||||||
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
|
||||||
|
|||||||
Reference in New Issue
Block a user