perfect! fix bug about jax auto recompile

add task xor-3d
This commit is contained in:
wls2002
2023-07-02 22:15:26 +08:00
parent e711146f41
commit c4d34e877b
11 changed files with 234 additions and 104 deletions

View File

@@ -2,7 +2,7 @@
contains operations on a single genome. e.g. forward, mutate, crossover, etc. contains operations on a single genome. e.g. forward, mutate, crossover, etc.
""" """
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
from .population import update_species, create_next_generation, speciate, tell from .population import update_species, create_next_generation, speciate, tell, initialize
from .genome.activations import act_name2func from .genome.activations import act_name2func
from .genome.aggregations import agg_name2func from .genome.aggregations import agg_name2func

View File

@@ -36,8 +36,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \ assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!" f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"
pop_nodes = np.full((config['pop_size'], N, 5), np.nan) pop_nodes = np.full((config['pop_size'], N, 5), np.nan, dtype=np.float32)
pop_cons = np.full((config['pop_size'], C, 4), np.nan) pop_cons = np.full((config['pop_size'], C, 4), np.nan, dtype=np.float32)
input_idx = config['input_idx'] input_idx = config['input_idx']
output_idx = config['output_idx'] output_idx = config['output_idx']
@@ -59,7 +59,7 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
pop_cons[:, :p, 0] = grid_a pop_cons[:, :p, 0] = grid_a
pop_cons[:, :p, 1] = grid_b pop_cons[:, :p, 1] = grid_b
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'], pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
size=(config['pop_size'], p)) size=(config['pop_size'], p))
pop_cons[:, :p, 3] = 1 pop_cons[:, :p, 3] = 1
return pop_nodes, pop_cons return pop_nodes, pop_cons

View File

@@ -1,20 +1,88 @@
""" """
Contains operations on the population: creating the next generation and population speciation. Contains operations on the population: creating the next generation and population speciation.
These im..... The value tuple (P, N, C, S) is determined when the algorithm is initialized.
P: population size
N: maximum number of nodes in any genome
C: maximum number of connections in any genome
S: maximum number of species in NEAT
These arrays are used in the algorithm:
fitness: Array[(P,), float], the fitness of each individual
randkey: Array[2, uint], the random key
pop_nodes: Array[(P, N, 5), float], nodes part of the population. [key, bias, response, act, agg]
pop_cons: Array[(P, C, 4), float], connections part of the population. [in_node, out_node, weight, enabled]
species_info: Array[(S, 4), float], the information of each species. [key, best_score, last_update, members_count]
idx2species: Array[(P,), float], map the individual to its species keys
center_nodes: Array[(S, N, 5), float], the center nodes of each species
center_cons: Array[(S, C, 4), float], the center connections of each species
generation: int, the current generation
next_node_key: float, the next of the next node
next_species_key: float, the next of the next species
jit_config: Configer, the config used in jit-able functions
""" """
# TODO: Complete python doc # TODO: Complete python doc
import numpy as np
import jax import jax
from jax import jit, vmap, Array, numpy as jnp from jax import jit, vmap, Array, numpy as jnp
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements from .genome import initialize_genomes, distance, mutate, crossover, fetch_first, rank_elements
def initialize(config):
"""
initialize the states of NEAT.
"""
P = config['pop_size']
N = config['maximum_nodes']
C = config['maximum_connections']
S = config['maximum_species']
randkey = jax.random.PRNGKey(config['random_seed'])
np.random.seed(config['random_seed'])
pop_nodes, pop_cons = initialize_genomes(N, C, config)
species_info = np.full((S, 4), np.nan, dtype=np.float32)
species_info[0, :] = 0, -np.inf, 0, P
idx2species = np.zeros(P, dtype=np.float32)
center_nodes = np.full((S, N, 5), np.nan, dtype=np.float32)
center_cons = np.full((S, C, 4), np.nan, dtype=np.float32)
center_nodes[0, :, :] = pop_nodes[0, :, :]
center_cons[0, :, :] = pop_cons[0, :, :]
generation = np.asarray(0, dtype=np.int32)
next_node_key = np.asarray(config['num_inputs'] + config['num_outputs'], dtype=np.float32)
next_species_key = np.asarray(1, dtype=np.float32)
return jax.device_put([
randkey,
pop_nodes,
pop_cons,
species_info,
idx2species,
center_nodes,
center_cons,
generation,
next_node_key,
next_species_key,
])
@jit @jit
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, def tell(fitness,
randkey,
pop_nodes,
pop_cons,
species_info,
idx2species,
center_nodes,
center_cons,
generation,
next_node_key,
next_species_key,
jit_config): jit_config):
"""
Main update function in NEAT.
"""
generation += 1 generation += 1
k1, k2, randkey = jax.random.split(randkey, 3) k1, k2, randkey = jax.random.split(randkey, 3)
@@ -23,19 +91,15 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
update_species(k1, fitness, species_info, idx2species, center_nodes, update_species(k1, fitness, species_info, idx2species, center_nodes,
center_cons, generation, jit_config) center_cons, generation, jit_config)
pop_nodes, pop_cons, next_node_key = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
elite_mask, next_node_key, jit_config)
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser, idx2species, center_nodes, center_cons, species_info, next_species_key = speciate(
elite_mask, generation, jit_config) pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config)
idx2species, center_nodes, center_cons, species_info = speciate( return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, next_node_key, next_species_key
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
jit_config)
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
@jit
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config): def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
""" """
args: args:
@@ -199,11 +263,10 @@ def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitn
return winner, loser, elite_mask return winner, loser, elite_mask
@jit def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, next_node_key, jit_config):
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, generation, jit_config):
# prepare random keys # prepare random keys
pop_size = pop_nodes.shape[0] pop_size = pop_nodes.shape[0]
new_node_keys = jnp.arange(pop_size) + generation * pop_size new_node_keys = jnp.arange(pop_size) + next_node_key
k1, k2 = jax.random.split(rand_key, 2) k1, k2 = jax.random.split(rand_key, 2)
crossover_rand_keys = jax.random.split(k1, pop_size) crossover_rand_keys = jax.random.split(k1, pop_size)
@@ -222,11 +285,15 @@ def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_m
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn) pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc) pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
return pop_nodes, pop_cons # update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return pop_nodes, pop_cons, next_node_key
@jit def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config):
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, jit_config):
""" """
args: args:
pop_nodes: (pop_size, N, 5) pop_nodes: (pop_size, N, 5)
@@ -243,7 +310,7 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes # the distance between genomes to its center genomes
o2c_distances = jnp.full((pop_size, ), jnp.inf) o2c_distances = jnp.full((pop_size,), jnp.inf)
# step 1: find new centers # step 1: find new centers
def cond_func(carry): def cond_func(carry):
@@ -277,35 +344,35 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
# part 2: assign members to each species # part 2: assign members to each species
def cond_func(carry): def cond_func(carry):
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key
# jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si) # jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
current_species_existed = ~jnp.isnan(si[i, 0]) current_species_existed = ~jnp.isnan(si[i, 0])
not_all_assigned = jnp.any(jnp.isnan(i2s)) not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size not_reach_species_upper_bounds = i < species_size
return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds) return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry): def body_func(carry):
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
_, i2s, scn, scc, si, o2c, ck = jax.lax.cond( _, i2s, scn, scc, si, o2c, nsk = jax.lax.cond(
jnp.isnan(si[i, 0]), # whether the current species is existing or not jnp.isnan(si[i, 0]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie update_exist_specie, # if existing, update the specie
(i, i2s, cn, cc, si, o2c, ck) (i, i2s, cn, cc, si, o2c, nsk)
) )
return i + 1, i2s, scn, scc, si, o2c, ck return i + 1, i2s, scn, scc, si, o2c, nsk
def create_new_species(carry): def create_new_species(carry):
i, i2s, cn, cc, si, o2c, ck = carry i, i2s, cn, cc, si, o2c, nsk = carry
# pick the first one who has not been assigned to any species # pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s)) idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species # assign it to the new species
# [key, best score, last update generation, members_count] # [key, best score, last update generation, members_count]
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0])) si = si.at[i].set(jnp.array([nsk, -jnp.inf, generation, 0]))
i2s = i2s.at[idx].set(ck) i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0) o2c = o2c.at[idx].set(0)
# update center genomes # update center genomes
@@ -315,14 +382,14 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
# when a new species is created, it needs to be updated, thus do not change i # when a new species is created, it needs to be updated, thus do not change i
return i + 1, i2s, cn, cc, si, o2c, ck + 1 # change to next new speciate key return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry): def update_exist_specie(carry):
i, i2s, cn, cc, si, o2c, ck = carry i, i2s, cn, cc, si, o2c, nsk = carry
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c)) i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
# turn to next species # turn to next species
return i + 1, i2s, cn, cc, si, o2c, ck return i + 1, i2s, cn, cc, si, o2c, nsk
def speciate_by_threshold(carry): def speciate_by_threshold(carry):
i, i2s, cn, cc, si, o2c = carry i, i2s, cn, cc, si, o2c = carry
@@ -344,15 +411,11 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
return i2s, o2c return i2s, o2c
species_keys = species_info[:, 0]
current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1
# update idx2specie # update idx2specie
_, idx2specie, center_nodes, center_cons, species_info, _, _ = jax.lax.while_loop( _, idx2specie, center_nodes, center_cons, species_info, _, next_species_key = jax.lax.while_loop(
cond_func, cond_func,
body_func, body_func,
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, current_new_key) (0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, next_species_key)
) )
# if there are still some pop genomes not assigned to any species, add them to the last genome # if there are still some pop genomes not assigned to any species, add them to the last genome
@@ -369,10 +432,9 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
species_member_counts = vmap(count_members)(jnp.arange(species_size)) species_member_counts = vmap(count_members)(jnp.arange(species_size))
species_info = species_info.at[:, 3].set(species_member_counts) species_info = species_info.at[:, 3].set(species_member_counts)
return idx2specie, center_nodes, center_cons, species_info return idx2specie, center_nodes, center_cons, species_info, next_species_key
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array: def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf) masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr) min_idx = jnp.argmin(masked_arr)

View File

@@ -4,7 +4,8 @@ import configparser
import numpy as np import numpy as np
from algorithms.neat import act_name2func, agg_name2func from algorithms.neat.genome.activations import act_name2func
from algorithms.neat.genome.aggregations import agg_name2func
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. # Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
jit_config_keys = [ jit_config_keys = [

View File

@@ -1,19 +1,18 @@
[basic] [basic]
num_inputs = 2 num_inputs = 2
num_outputs = 1 num_outputs = 1
init_maximum_nodes = 50 maximum_nodes = 50
init_maximum_connections = 50 maximum_connections = 50
init_maximum_species = 10 maximum_species = 10
expand_coe = 1.5
pre_expand_threshold = 0.75
forward_way = "pop" forward_way = "pop"
batch_size = 4 batch_size = 4
random_seed = 0
[population] [population]
fitness_threshold = 100000 fitness_threshold = 3.99999
generation_limit = 1000 generation_limit = 1000
fitness_criterion = "max" fitness_criterion = "max"
pop_size = 50 pop_size = 100000
[genome] [genome]
compatibility_disjoint = 1.0 compatibility_disjoint = 1.0

View File

@@ -34,8 +34,6 @@ def get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topologica
return evaluate(func) return evaluate(func)
def equal(ar1, ar2): def equal(ar1, ar2):
if ar1.shape != ar2.shape: if ar1.shape != ar2.shape:
return False return False

View File

@@ -2,4 +2,4 @@
forward_way = "common" forward_way = "common"
[population] [population]
fitness_threshold = 3.9999 fitness_threshold = 4

View File

@@ -2,7 +2,6 @@ import jax
import numpy as np import numpy as np
from configs import Configer from configs import Configer
from algorithms.neat import Genome
from pipeline import Pipeline from pipeline import Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
@@ -22,10 +21,10 @@ def evaluate(forward_func):
def main(): def main():
config = Configer.load_config("xor.ini") config = Configer.load_config("xor.ini")
pipeline = Pipeline(config, seed=6) pipeline = Pipeline(config)
nodes, cons = pipeline.auto_run(evaluate) nodes, cons = pipeline.auto_run(evaluate)
g = Genome(nodes, cons, config) # g = Genome(nodes, cons, config)
print(g) # print(g)
if __name__ == '__main__': if __name__ == '__main__':

47
examples/xor3d.ini Normal file
View File

@@ -0,0 +1,47 @@
[basic]
num_inputs = 3
num_outputs = 1
maximum_nodes = 50
maximum_connections = 50
maximum_species = 10
forward_way = "common"
batch_size = 4
random_seed = 42
[population]
fitness_threshold = 8
generation_limit = 1000
fitness_criterion = "max"
pop_size = 100000
[genome]
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.4
conn_add_trials = 1
conn_delete_prob = 0
node_add_prob = 0.2
node_delete_prob = 0
[species]
compatibility_threshold = 3
species_elitism = 1
max_stagnation = 15
genome_elitism = 2
survival_threshold = 0.2
min_species_size = 1
spawn_number_move_rate = 0.5
[gene-bias]
bias_init_mean = 0.0
bias_init_std = 1.0
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
[gene-weight]
weight_init_mean = 0.0
weight_init_std = 1.0
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1

31
examples/xor3d.py Normal file
View File

@@ -0,0 +1,31 @@
import jax
import numpy as np
from configs import Configer
from pipeline import Pipeline
xor_inputs = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0], [1], [0], [0], [1]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 8 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
def main():
config = Configer.load_config("xor3d.ini")
pipeline = Pipeline(config)
nodes, cons = pipeline.auto_run(evaluate)
# g = Genome(nodes, cons, config)
# print(g)
if __name__ == '__main__':
main()

View File

@@ -5,8 +5,8 @@ import numpy as np
import jax import jax
from jax import jit, vmap from jax import jit, vmap
from configs import Configer
from algorithms import neat from algorithms import neat
from configs.configer import Configer
class Pipeline: class Pipeline:
@@ -14,58 +14,40 @@ class Pipeline:
Neat algorithm pipeline. Neat algorithm pipeline.
""" """
def __init__(self, config, seed=42): def __init__(self, config):
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
self.config = config # global config self.config = config # global config
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions self.jit_config = Configer.create_jit_config(config)
self.P = config['pop_size']
self.N = config['init_maximum_nodes']
self.C = config['init_maximum_connections']
self.S = config['init_maximum_species']
self.generation = 0
self.best_genome = None self.best_genome = None
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config) self.neat_states = neat.initialize(config)
self.species_info = np.full((self.S, 4), np.nan)
self.species_info[0, :] = 0, -np.inf, 0, self.P
self.idx2species = np.zeros(self.P, dtype=np.float32)
self.center_nodes = np.full((self.S, self.N, 5), np.nan)
self.center_cons = np.full((self.S, self.C, 4), np.nan)
self.center_nodes[0, :, :] = self.pop_nodes[0, :, :]
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
self.best_fitness = float('-inf') self.best_fitness = float('-inf')
self.generation_timestamp = time.time() self.generation_timestamp = time.time()
self.evaluate_time = 0 self.evaluate_time = 0
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.initialize(config)
self.forward = neat.create_forward_function(config)
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections)) self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
self.pop_topological_sort = jit(vmap(neat.topological_sort)) self.pop_topological_sort = jit(vmap(neat.topological_sort))
self.forward = neat.create_forward_function(config)
# fitness_lower = np.zeros(self.P, dtype=np.float32) # self.tell_func = neat.tell.lower(np.zeros(config['pop_size'], dtype=np.float32),
# randkey_lower = np.zeros(2, dtype=np.uint32) # self.randkey,
# pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32) # self.pop_nodes,
# pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32) # self.pop_cons,
# species_info_lower = np.zeros((self.S, 4), dtype=np.float32) # self.species_info,
# idx2species_lower = np.zeros(self.P, dtype=np.float32) # self.idx2species,
# center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32) # self.center_nodes,
# center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32) # self.center_cons,
# # self.generation,
# self.tell_func = jit(neat.tell).lower(fitness_lower, # self.next_node_key,
# randkey_lower, # self.next_species_key,
# pop_nodes_lower, # self.jit_config).compile()
# pop_cons_lower,
# species_info_lower,
# idx2species_lower,
# center_nodes_lower,
# center_cons_lower,
# 0,
# self.jit_config).compile()
def ask(self): def ask(self):
""" """
@@ -97,9 +79,19 @@ class Pipeline:
def tell(self, fitness): def tell(self, fitness):
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \ self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons, self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.tell(fitness,
self.species_info, self.idx2species, self.center_nodes, self.randkey,
self.center_cons, self.generation, self.jit_config) self.pop_nodes,
self.pop_cons,
self.species_info,
self.idx2species,
self.center_nodes,
self.center_cons,
self.generation,
self.next_node_key,
self.next_species_key,
self.jit_config)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']): for _ in range(self.config['generation_limit']):
@@ -109,7 +101,7 @@ class Pipeline:
fitnesses = fitness_func(forward_func) fitnesses = fitness_func(forward_func)
self.evaluate_time += time.time() - tic self.evaluate_time += time.time() - tic
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!" # assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
if analysis is not None: if analysis is not None:
if analysis == "default": if analysis == "default":
@@ -138,7 +130,8 @@ class Pipeline:
self.best_fitness = fitnesses[max_idx] self.best_fitness = fitnesses[max_idx]
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx]) self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
species_sizes = [int(i) for i in self.species_info[:, 3] if i > 0] member_count = jax.device_get(self.species_info[:, 3])
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.generation}", print(f"Generation: {self.generation}",
f"species: {len(species_sizes)}, {species_sizes}", f"species: {len(species_sizes)}, {species_sizes}",