Perfect! Fix the bug that caused JAX to recompile automatically

add task xor-3d
This commit is contained in:
wls2002
2023-07-02 22:15:26 +08:00
parent e711146f41
commit c4d34e877b
11 changed files with 234 additions and 104 deletions

View File

@@ -2,7 +2,7 @@
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
from .population import update_species, create_next_generation, speciate, tell
from .population import update_species, create_next_generation, speciate, tell, initialize
from .genome.activations import act_name2func
from .genome.aggregations import agg_name2func

View File

@@ -36,8 +36,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"
pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
pop_cons = np.full((config['pop_size'], C, 4), np.nan)
pop_nodes = np.full((config['pop_size'], N, 5), np.nan, dtype=np.float32)
pop_cons = np.full((config['pop_size'], C, 4), np.nan, dtype=np.float32)
input_idx = config['input_idx']
output_idx = config['output_idx']
@@ -59,7 +59,7 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
pop_cons[:, :p, 0] = grid_a
pop_cons[:, :p, 1] = grid_b
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
size=(config['pop_size'], p))
size=(config['pop_size'], p))
pop_cons[:, :p, 3] = 1
return pop_nodes, pop_cons

View File

@@ -1,20 +1,88 @@
"""
Contains operations on the population: creating the next generation and population speciation.
These im.....
The value tuple (P, N, C, S) is determined when the algorithm is initialized.
P: population size
N: maximum number of nodes in any genome
C: maximum number of connections in any genome
S: maximum number of species in NEAT
These arrays are used in the algorithm:
fitness: Array[(P,), float], the fitness of each individual
randkey: Array[2, uint], the random key
pop_nodes: Array[(P, N, 5), float], nodes part of the population. [key, bias, response, act, agg]
pop_cons: Array[(P, C, 4), float], connections part of the population. [in_node, out_node, weight, enabled]
species_info: Array[(S, 4), float], the information of each species. [key, best_score, last_update, members_count]
idx2species: Array[(P,), float], map the individual to its species keys
center_nodes: Array[(S, N, 5), float], the center nodes of each species
center_cons: Array[(S, C, 4), float], the center connections of each species
generation: int, the current generation
next_node_key: float, the key to assign to the next newly created node
next_species_key: float, the key to assign to the next newly created species
jit_config: Configer, the config used in jit-able functions
"""
# TODO: Complete python doc
import numpy as np
import jax
from jax import jit, vmap, Array, numpy as jnp
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
from .genome import initialize_genomes, distance, mutate, crossover, fetch_first, rank_elements
def initialize(config):
    """
    Build the initial NEAT state and place it on the default JAX device.

    Args:
        config: mapping that provides 'pop_size', 'maximum_nodes',
            'maximum_connections', 'maximum_species', 'random_seed',
            'num_inputs' and 'num_outputs'. It is also forwarded unchanged to
            ``initialize_genomes``.

    Returns:
        The result of ``jax.device_put`` on a list holding, in order:
        randkey, pop_nodes, pop_cons, species_info, idx2species,
        center_nodes, center_cons, generation, next_node_key,
        next_species_key.
    """
    pop_size = config['pop_size']
    max_nodes = config['maximum_nodes']
    max_cons = config['maximum_connections']
    max_species = config['maximum_species']

    # Seed both generators from the same value: JAX for the returned key and
    # NumPy's global generator before the genomes are created.
    seed_key = jax.random.PRNGKey(config['random_seed'])
    np.random.seed(config['random_seed'])
    nodes, cons = initialize_genomes(max_nodes, max_cons, config)

    # Every species slot starts as NaN; only slot 0 is filled in.
    # Row layout: [key, best_score, last_update, members_count].
    info = np.full((max_species, 4), np.nan, dtype=np.float32)
    info[0] = (0, -np.inf, 0, pop_size)

    # All individuals begin in the species whose key is 0.
    membership = np.zeros((pop_size,), dtype=np.float32)

    # The first genome of the population is the center of species 0.
    species_nodes = np.full((max_species, max_nodes, 5), np.nan, dtype=np.float32)
    species_cons = np.full((max_species, max_cons, 4), np.nan, dtype=np.float32)
    species_nodes[0] = nodes[0]
    species_cons[0] = cons[0]

    gen_counter = np.asarray(0, dtype=np.int32)
    node_key = np.asarray(config['num_inputs'] + config['num_outputs'], dtype=np.float32)
    species_key = np.asarray(1, dtype=np.float32)

    state = [
        seed_key,
        nodes,
        cons,
        info,
        membership,
        species_nodes,
        species_cons,
        gen_counter,
        node_key,
        species_key,
    ]
    return jax.device_put(state)
@jit
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
def tell(fitness,
randkey,
pop_nodes,
pop_cons,
species_info,
idx2species,
center_nodes,
center_cons,
generation,
next_node_key,
next_species_key,
jit_config):
"""
Main update function in NEAT.
"""
generation += 1
k1, k2, randkey = jax.random.split(randkey, 3)
@@ -23,19 +91,15 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
update_species(k1, fitness, species_info, idx2species, center_nodes,
center_cons, generation, jit_config)
pop_nodes, pop_cons, next_node_key = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
elite_mask, next_node_key, jit_config)
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
elite_mask, generation, jit_config)
idx2species, center_nodes, center_cons, species_info, next_species_key = speciate(
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config)
idx2species, center_nodes, center_cons, species_info = speciate(
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
jit_config)
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, next_node_key, next_species_key
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
@jit
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
"""
args:
@@ -199,11 +263,10 @@ def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitn
return winner, loser, elite_mask
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, generation, jit_config):
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, next_node_key, jit_config):
# prepare random keys
pop_size = pop_nodes.shape[0]
new_node_keys = jnp.arange(pop_size) + generation * pop_size
new_node_keys = jnp.arange(pop_size) + next_node_key
k1, k2 = jax.random.split(rand_key, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
@@ -222,11 +285,15 @@ def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_m
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
return pop_nodes, pop_cons
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return pop_nodes, pop_cons, next_node_key
@jit
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, jit_config):
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config):
"""
args:
pop_nodes: (pop_size, N, 5)
@@ -243,7 +310,7 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((pop_size, ), jnp.inf)
o2c_distances = jnp.full((pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
@@ -277,35 +344,35 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key
# jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
current_species_existed = ~jnp.isnan(si[i, 0])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds)
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
_, i2s, scn, scc, si, o2c, ck = jax.lax.cond(
_, i2s, scn, scc, si, o2c, nsk = jax.lax.cond(
jnp.isnan(si[i, 0]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cn, cc, si, o2c, ck)
(i, i2s, cn, cc, si, o2c, nsk)
)
return i + 1, i2s, scn, scc, si, o2c, ck
return i + 1, i2s, scn, scc, si, o2c, nsk
def create_new_species(carry):
i, i2s, cn, cc, si, o2c, ck = carry
i, i2s, cn, cc, si, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, members_count]
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0]))
i2s = i2s.at[idx].set(ck)
si = si.at[i].set(jnp.array([nsk, -jnp.inf, generation, 0]))
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
@@ -315,14 +382,14 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
# when a new species is created, it needs to be updated, thus do not change i
return i + 1, i2s, cn, cc, si, o2c, ck + 1 # change to next new speciate key
return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cn, cc, si, o2c, ck = carry
i, i2s, cn, cc, si, o2c, nsk = carry
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
# turn to next species
return i + 1, i2s, cn, cc, si, o2c, ck
return i + 1, i2s, cn, cc, si, o2c, nsk
def speciate_by_threshold(carry):
i, i2s, cn, cc, si, o2c = carry
@@ -344,15 +411,11 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
return i2s, o2c
species_keys = species_info[:, 0]
current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1
# update idx2specie
_, idx2specie, center_nodes, center_cons, species_info, _, _ = jax.lax.while_loop(
_, idx2specie, center_nodes, center_cons, species_info, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, current_new_key)
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
@@ -369,10 +432,9 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
species_member_counts = vmap(count_members)(jnp.arange(species_size))
species_info = species_info.at[:, 3].set(species_member_counts)
return idx2specie, center_nodes, center_cons, species_info
return idx2specie, center_nodes, center_cons, species_info, next_species_key
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)