perfect! fix bug about jax auto recompile
add task xor-3d
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||
"""
|
||||
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
|
||||
from .population import update_species, create_next_generation, speciate, tell
|
||||
from .population import update_species, create_next_generation, speciate, tell, initialize
|
||||
|
||||
from .genome.activations import act_name2func
|
||||
from .genome.aggregations import agg_name2func
|
||||
|
||||
@@ -36,8 +36,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
|
||||
assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
|
||||
f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"
|
||||
|
||||
pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
|
||||
pop_cons = np.full((config['pop_size'], C, 4), np.nan)
|
||||
pop_nodes = np.full((config['pop_size'], N, 5), np.nan, dtype=np.float32)
|
||||
pop_cons = np.full((config['pop_size'], C, 4), np.nan, dtype=np.float32)
|
||||
input_idx = config['input_idx']
|
||||
output_idx = config['output_idx']
|
||||
|
||||
@@ -59,7 +59,7 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
|
||||
pop_cons[:, :p, 0] = grid_a
|
||||
pop_cons[:, :p, 1] = grid_b
|
||||
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
|
||||
size=(config['pop_size'], p))
|
||||
size=(config['pop_size'], p))
|
||||
pop_cons[:, :p, 3] = 1
|
||||
|
||||
return pop_nodes, pop_cons
|
||||
|
||||
@@ -1,20 +1,88 @@
|
||||
"""
|
||||
Contains operations on the population: creating the next generation and population speciation.
|
||||
These im.....
|
||||
The value tuple (P, N, C, S) is determined when the algorithm is initialized.
|
||||
P: population size
|
||||
N: maximum number of nodes in any genome
|
||||
C: maximum number of connections in any genome
|
||||
S: maximum number of species in NEAT
|
||||
|
||||
These arrays are used in the algorithm:
|
||||
fitness: Array[(P,), float], the fitness of each individual
|
||||
randkey: Array[2, uint], the random key
|
||||
pop_nodes: Array[(P, N, 5), float], nodes part of the population. [key, bias, response, act, agg]
|
||||
pop_cons: Array[(P, C, 4), float], connections part of the population. [in_node, out_node, weight, enabled]
|
||||
species_info: Array[(S, 4), float], the information of each species. [key, best_score, last_update, members_count]
|
||||
idx2species: Array[(P,), float], map the individual to its species keys
|
||||
center_nodes: Array[(S, N, 5), float], the center nodes of each species
|
||||
center_cons: Array[(S, C, 4), float], the center connections of each species
|
||||
generation: int, the current generation
|
||||
next_node_key: float, the next of the next node
|
||||
next_species_key: float, the next of the next species
|
||||
jit_config: Configer, the config used in jit-able functions
|
||||
"""
|
||||
|
||||
# TODO: Complete python doc
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import jit, vmap, Array, numpy as jnp
|
||||
|
||||
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
|
||||
from .genome import initialize_genomes, distance, mutate, crossover, fetch_first, rank_elements
|
||||
|
||||
|
||||
def initialize(config):
|
||||
"""
|
||||
initialize the states of NEAT.
|
||||
"""
|
||||
|
||||
P = config['pop_size']
|
||||
N = config['maximum_nodes']
|
||||
C = config['maximum_connections']
|
||||
S = config['maximum_species']
|
||||
|
||||
randkey = jax.random.PRNGKey(config['random_seed'])
|
||||
np.random.seed(config['random_seed'])
|
||||
pop_nodes, pop_cons = initialize_genomes(N, C, config)
|
||||
species_info = np.full((S, 4), np.nan, dtype=np.float32)
|
||||
species_info[0, :] = 0, -np.inf, 0, P
|
||||
idx2species = np.zeros(P, dtype=np.float32)
|
||||
center_nodes = np.full((S, N, 5), np.nan, dtype=np.float32)
|
||||
center_cons = np.full((S, C, 4), np.nan, dtype=np.float32)
|
||||
center_nodes[0, :, :] = pop_nodes[0, :, :]
|
||||
center_cons[0, :, :] = pop_cons[0, :, :]
|
||||
generation = np.asarray(0, dtype=np.int32)
|
||||
next_node_key = np.asarray(config['num_inputs'] + config['num_outputs'], dtype=np.float32)
|
||||
next_species_key = np.asarray(1, dtype=np.float32)
|
||||
|
||||
return jax.device_put([
|
||||
randkey,
|
||||
pop_nodes,
|
||||
pop_cons,
|
||||
species_info,
|
||||
idx2species,
|
||||
center_nodes,
|
||||
center_cons,
|
||||
generation,
|
||||
next_node_key,
|
||||
next_species_key,
|
||||
])
|
||||
|
||||
@jit
|
||||
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
|
||||
def tell(fitness,
|
||||
randkey,
|
||||
pop_nodes,
|
||||
pop_cons,
|
||||
species_info,
|
||||
idx2species,
|
||||
center_nodes,
|
||||
center_cons,
|
||||
generation,
|
||||
next_node_key,
|
||||
next_species_key,
|
||||
jit_config):
|
||||
|
||||
"""
|
||||
Main update function in NEAT.
|
||||
"""
|
||||
generation += 1
|
||||
|
||||
k1, k2, randkey = jax.random.split(randkey, 3)
|
||||
@@ -23,19 +91,15 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
|
||||
update_species(k1, fitness, species_info, idx2species, center_nodes,
|
||||
center_cons, generation, jit_config)
|
||||
|
||||
pop_nodes, pop_cons, next_node_key = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
||||
elite_mask, next_node_key, jit_config)
|
||||
|
||||
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
||||
elite_mask, generation, jit_config)
|
||||
idx2species, center_nodes, center_cons, species_info, next_species_key = speciate(
|
||||
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config)
|
||||
|
||||
idx2species, center_nodes, center_cons, species_info = speciate(
|
||||
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
|
||||
jit_config)
|
||||
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, next_node_key, next_species_key
|
||||
|
||||
|
||||
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
|
||||
|
||||
|
||||
@jit
|
||||
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
args:
|
||||
@@ -199,11 +263,10 @@ def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitn
|
||||
return winner, loser, elite_mask
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, generation, jit_config):
|
||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, next_node_key, jit_config):
|
||||
# prepare random keys
|
||||
pop_size = pop_nodes.shape[0]
|
||||
new_node_keys = jnp.arange(pop_size) + generation * pop_size
|
||||
new_node_keys = jnp.arange(pop_size) + next_node_key
|
||||
|
||||
k1, k2 = jax.random.split(rand_key, 2)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
@@ -222,11 +285,15 @@ def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_m
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
||||
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
||||
|
||||
return pop_nodes, pop_cons
|
||||
# update next node key
|
||||
all_nodes_keys = pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
|
||||
next_node_key = max_node_key + 1
|
||||
|
||||
return pop_nodes, pop_cons, next_node_key
|
||||
|
||||
|
||||
@jit
|
||||
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, jit_config):
|
||||
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config):
|
||||
"""
|
||||
args:
|
||||
pop_nodes: (pop_size, N, 5)
|
||||
@@ -243,7 +310,7 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
|
||||
|
||||
# the distance between genomes to its center genomes
|
||||
o2c_distances = jnp.full((pop_size, ), jnp.inf)
|
||||
o2c_distances = jnp.full((pop_size,), jnp.inf)
|
||||
|
||||
# step 1: find new centers
|
||||
def cond_func(carry):
|
||||
@@ -277,35 +344,35 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
|
||||
i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key
|
||||
# jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
|
||||
current_species_existed = ~jnp.isnan(si[i, 0])
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds)
|
||||
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
|
||||
_, i2s, scn, scc, si, o2c, ck = jax.lax.cond(
|
||||
_, i2s, scn, scc, si, o2c, nsk = jax.lax.cond(
|
||||
jnp.isnan(si[i, 0]), # whether the current species is existing or not
|
||||
create_new_species, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cn, cc, si, o2c, ck)
|
||||
(i, i2s, cn, cc, si, o2c, nsk)
|
||||
)
|
||||
|
||||
return i + 1, i2s, scn, scc, si, o2c, ck
|
||||
return i + 1, i2s, scn, scc, si, o2c, nsk
|
||||
|
||||
def create_new_species(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry
|
||||
i, i2s, cn, cc, si, o2c, nsk = carry
|
||||
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(jnp.isnan(i2s))
|
||||
|
||||
# assign it to the new species
|
||||
# [key, best score, last update generation, members_count]
|
||||
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0]))
|
||||
i2s = i2s.at[idx].set(ck)
|
||||
si = si.at[i].set(jnp.array([nsk, -jnp.inf, generation, 0]))
|
||||
i2s = i2s.at[idx].set(nsk)
|
||||
o2c = o2c.at[idx].set(0)
|
||||
|
||||
# update center genomes
|
||||
@@ -315,14 +382,14 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
|
||||
|
||||
# when a new species is created, it needs to be updated, thus do not change i
|
||||
return i + 1, i2s, cn, cc, si, o2c, ck + 1 # change to next new speciate key
|
||||
return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry
|
||||
i, i2s, cn, cc, si, o2c, nsk = carry
|
||||
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
|
||||
|
||||
# turn to next species
|
||||
return i + 1, i2s, cn, cc, si, o2c, ck
|
||||
return i + 1, i2s, cn, cc, si, o2c, nsk
|
||||
|
||||
def speciate_by_threshold(carry):
|
||||
i, i2s, cn, cc, si, o2c = carry
|
||||
@@ -344,15 +411,11 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
|
||||
return i2s, o2c
|
||||
|
||||
species_keys = species_info[:, 0]
|
||||
current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1
|
||||
|
||||
|
||||
# update idx2specie
|
||||
_, idx2specie, center_nodes, center_cons, species_info, _, _ = jax.lax.while_loop(
|
||||
_, idx2specie, center_nodes, center_cons, species_info, _, next_species_key = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, current_new_key)
|
||||
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, next_species_key)
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
@@ -369,10 +432,9 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
species_member_counts = vmap(count_members)(jnp.arange(species_size))
|
||||
species_info = species_info.at[:, 3].set(species_member_counts)
|
||||
|
||||
return idx2specie, center_nodes, center_cons, species_info
|
||||
return idx2specie, center_nodes, center_cons, species_info, next_species_key
|
||||
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
|
||||
@@ -4,7 +4,8 @@ import configparser
|
||||
|
||||
import numpy as np
|
||||
|
||||
from algorithms.neat import act_name2func, agg_name2func
|
||||
from algorithms.neat.genome.activations import act_name2func
|
||||
from algorithms.neat.genome.aggregations import agg_name2func
|
||||
|
||||
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||
jit_config_keys = [
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
[basic]
|
||||
num_inputs = 2
|
||||
num_outputs = 1
|
||||
init_maximum_nodes = 50
|
||||
init_maximum_connections = 50
|
||||
init_maximum_species = 10
|
||||
expand_coe = 1.5
|
||||
pre_expand_threshold = 0.75
|
||||
maximum_nodes = 50
|
||||
maximum_connections = 50
|
||||
maximum_species = 10
|
||||
forward_way = "pop"
|
||||
batch_size = 4
|
||||
random_seed = 0
|
||||
|
||||
[population]
|
||||
fitness_threshold = 100000
|
||||
fitness_threshold = 3.99999
|
||||
generation_limit = 1000
|
||||
fitness_criterion = "max"
|
||||
pop_size = 50
|
||||
pop_size = 100000
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
|
||||
@@ -34,8 +34,6 @@ def get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topologica
|
||||
return evaluate(func)
|
||||
|
||||
|
||||
|
||||
|
||||
def equal(ar1, ar2):
|
||||
if ar1.shape != ar2.shape:
|
||||
return False
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
forward_way = "common"
|
||||
|
||||
[population]
|
||||
fitness_threshold = 3.9999
|
||||
fitness_threshold = 4
|
||||
@@ -2,7 +2,6 @@ import jax
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
from algorithms.neat import Genome
|
||||
from pipeline import Pipeline
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
@@ -22,10 +21,10 @@ def evaluate(forward_func):
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
pipeline = Pipeline(config, seed=6)
|
||||
pipeline = Pipeline(config)
|
||||
nodes, cons = pipeline.auto_run(evaluate)
|
||||
g = Genome(nodes, cons, config)
|
||||
print(g)
|
||||
# g = Genome(nodes, cons, config)
|
||||
# print(g)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
47
examples/xor3d.ini
Normal file
47
examples/xor3d.ini
Normal file
@@ -0,0 +1,47 @@
|
||||
[basic]
|
||||
num_inputs = 3
|
||||
num_outputs = 1
|
||||
maximum_nodes = 50
|
||||
maximum_connections = 50
|
||||
maximum_species = 10
|
||||
forward_way = "common"
|
||||
batch_size = 4
|
||||
random_seed = 42
|
||||
|
||||
[population]
|
||||
fitness_threshold = 8
|
||||
generation_limit = 1000
|
||||
fitness_criterion = "max"
|
||||
pop_size = 100000
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
compatibility_weight = 0.5
|
||||
conn_add_prob = 0.4
|
||||
conn_add_trials = 1
|
||||
conn_delete_prob = 0
|
||||
node_add_prob = 0.2
|
||||
node_delete_prob = 0
|
||||
|
||||
[species]
|
||||
compatibility_threshold = 3
|
||||
species_elitism = 1
|
||||
max_stagnation = 15
|
||||
genome_elitism = 2
|
||||
survival_threshold = 0.2
|
||||
min_species_size = 1
|
||||
spawn_number_move_rate = 0.5
|
||||
|
||||
[gene-bias]
|
||||
bias_init_mean = 0.0
|
||||
bias_init_std = 1.0
|
||||
bias_mutate_power = 0.5
|
||||
bias_mutate_rate = 0.7
|
||||
bias_replace_rate = 0.1
|
||||
|
||||
[gene-weight]
|
||||
weight_init_mean = 0.0
|
||||
weight_init_std = 1.0
|
||||
weight_mutate_power = 0.5
|
||||
weight_mutate_rate = 0.8
|
||||
weight_replace_rate = 0.1
|
||||
31
examples/xor3d.py
Normal file
31
examples/xor3d.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
from pipeline import Pipeline
|
||||
|
||||
xor_inputs = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0], [1], [0], [0], [1]], dtype=np.float32)
|
||||
|
||||
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 8 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses
|
||||
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor3d.ini")
|
||||
pipeline = Pipeline(config)
|
||||
nodes, cons = pipeline.auto_run(evaluate)
|
||||
# g = Genome(nodes, cons, config)
|
||||
# print(g)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
83
pipeline.py
83
pipeline.py
@@ -5,8 +5,8 @@ import numpy as np
|
||||
import jax
|
||||
from jax import jit, vmap
|
||||
|
||||
from configs import Configer
|
||||
from algorithms import neat
|
||||
from configs.configer import Configer
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -14,58 +14,40 @@ class Pipeline:
|
||||
Neat algorithm pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config, seed=42):
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config # global config
|
||||
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||
self.jit_config = Configer.create_jit_config(config)
|
||||
|
||||
self.P = config['pop_size']
|
||||
self.N = config['init_maximum_nodes']
|
||||
self.C = config['init_maximum_connections']
|
||||
self.S = config['init_maximum_species']
|
||||
|
||||
self.generation = 0
|
||||
self.best_genome = None
|
||||
|
||||
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
|
||||
self.species_info = np.full((self.S, 4), np.nan)
|
||||
self.species_info[0, :] = 0, -np.inf, 0, self.P
|
||||
self.idx2species = np.zeros(self.P, dtype=np.float32)
|
||||
self.center_nodes = np.full((self.S, self.N, 5), np.nan)
|
||||
self.center_cons = np.full((self.S, self.C, 4), np.nan)
|
||||
self.center_nodes[0, :, :] = self.pop_nodes[0, :, :]
|
||||
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
|
||||
self.neat_states = neat.initialize(config)
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
self.evaluate_time = 0
|
||||
|
||||
|
||||
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
|
||||
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.initialize(config)
|
||||
|
||||
|
||||
self.forward = neat.create_forward_function(config)
|
||||
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
|
||||
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
||||
self.forward = neat.create_forward_function(config)
|
||||
|
||||
# fitness_lower = np.zeros(self.P, dtype=np.float32)
|
||||
# randkey_lower = np.zeros(2, dtype=np.uint32)
|
||||
# pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32)
|
||||
# pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32)
|
||||
# species_info_lower = np.zeros((self.S, 4), dtype=np.float32)
|
||||
# idx2species_lower = np.zeros(self.P, dtype=np.float32)
|
||||
# center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32)
|
||||
# center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32)
|
||||
#
|
||||
# self.tell_func = jit(neat.tell).lower(fitness_lower,
|
||||
# randkey_lower,
|
||||
# pop_nodes_lower,
|
||||
# pop_cons_lower,
|
||||
# species_info_lower,
|
||||
# idx2species_lower,
|
||||
# center_nodes_lower,
|
||||
# center_cons_lower,
|
||||
# 0,
|
||||
# self.jit_config).compile()
|
||||
# self.tell_func = neat.tell.lower(np.zeros(config['pop_size'], dtype=np.float32),
|
||||
# self.randkey,
|
||||
# self.pop_nodes,
|
||||
# self.pop_cons,
|
||||
# self.species_info,
|
||||
# self.idx2species,
|
||||
# self.center_nodes,
|
||||
# self.center_cons,
|
||||
# self.generation,
|
||||
# self.next_node_key,
|
||||
# self.next_species_key,
|
||||
# self.jit_config).compile()
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
@@ -97,9 +79,19 @@ class Pipeline:
|
||||
def tell(self, fitness):
|
||||
|
||||
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
|
||||
self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons,
|
||||
self.species_info, self.idx2species, self.center_nodes,
|
||||
self.center_cons, self.generation, self.jit_config)
|
||||
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.tell(fitness,
|
||||
self.randkey,
|
||||
self.pop_nodes,
|
||||
self.pop_cons,
|
||||
self.species_info,
|
||||
self.idx2species,
|
||||
self.center_nodes,
|
||||
self.center_cons,
|
||||
self.generation,
|
||||
self.next_node_key,
|
||||
self.next_species_key,
|
||||
self.jit_config)
|
||||
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config['generation_limit']):
|
||||
@@ -109,7 +101,7 @@ class Pipeline:
|
||||
fitnesses = fitness_func(forward_func)
|
||||
self.evaluate_time += time.time() - tic
|
||||
|
||||
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||
# assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||
|
||||
if analysis is not None:
|
||||
if analysis == "default":
|
||||
@@ -138,7 +130,8 @@ class Pipeline:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||
|
||||
species_sizes = [int(i) for i in self.species_info[:, 3] if i > 0]
|
||||
member_count = jax.device_get(self.species_info[:, 3])
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
print(f"Generation: {self.generation}",
|
||||
f"species: {len(species_sizes)}, {species_sizes}",
|
||||
|
||||
Reference in New Issue
Block a user