perfect! fix bug about jax auto recompile
add task xor-3d
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||
"""
|
||||
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
|
||||
from .population import update_species, create_next_generation, speciate, tell
|
||||
from .population import update_species, create_next_generation, speciate, tell, initialize
|
||||
|
||||
from .genome.activations import act_name2func
|
||||
from .genome.aggregations import agg_name2func
|
||||
|
||||
@@ -36,8 +36,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
|
||||
assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
|
||||
f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"
|
||||
|
||||
pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
|
||||
pop_cons = np.full((config['pop_size'], C, 4), np.nan)
|
||||
pop_nodes = np.full((config['pop_size'], N, 5), np.nan, dtype=np.float32)
|
||||
pop_cons = np.full((config['pop_size'], C, 4), np.nan, dtype=np.float32)
|
||||
input_idx = config['input_idx']
|
||||
output_idx = config['output_idx']
|
||||
|
||||
@@ -59,7 +59,7 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
|
||||
pop_cons[:, :p, 0] = grid_a
|
||||
pop_cons[:, :p, 1] = grid_b
|
||||
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
|
||||
size=(config['pop_size'], p))
|
||||
size=(config['pop_size'], p))
|
||||
pop_cons[:, :p, 3] = 1
|
||||
|
||||
return pop_nodes, pop_cons
|
||||
|
||||
@@ -1,20 +1,88 @@
|
||||
"""
|
||||
Contains operations on the population: creating the next generation and population speciation.
|
||||
These im.....
|
||||
The value tuple (P, N, C, S) is determined when the algorithm is initialized.
|
||||
P: population size
|
||||
N: maximum number of nodes in any genome
|
||||
C: maximum number of connections in any genome
|
||||
S: maximum number of species in NEAT
|
||||
|
||||
These arrays are used in the algorithm:
|
||||
fitness: Array[(P,), float], the fitness of each individual
|
||||
randkey: Array[2, uint], the random key
|
||||
pop_nodes: Array[(P, N, 5), float], nodes part of the population. [key, bias, response, act, agg]
|
||||
pop_cons: Array[(P, C, 4), float], connections part of the population. [in_node, out_node, weight, enabled]
|
||||
species_info: Array[(S, 4), float], the information of each species. [key, best_score, last_update, members_count]
|
||||
idx2species: Array[(P,), float], map the individual to its species keys
|
||||
center_nodes: Array[(S, N, 5), float], the center nodes of each species
|
||||
center_cons: Array[(S, C, 4), float], the center connections of each species
|
||||
generation: int, the current generation
|
||||
next_node_key: float, the next of the next node
|
||||
next_species_key: float, the next of the next species
|
||||
jit_config: Configer, the config used in jit-able functions
|
||||
"""
|
||||
|
||||
# TODO: Complete python doc
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import jit, vmap, Array, numpy as jnp
|
||||
|
||||
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
|
||||
from .genome import initialize_genomes, distance, mutate, crossover, fetch_first, rank_elements
|
||||
|
||||
|
||||
def initialize(config):
|
||||
"""
|
||||
initialize the states of NEAT.
|
||||
"""
|
||||
|
||||
P = config['pop_size']
|
||||
N = config['maximum_nodes']
|
||||
C = config['maximum_connections']
|
||||
S = config['maximum_species']
|
||||
|
||||
randkey = jax.random.PRNGKey(config['random_seed'])
|
||||
np.random.seed(config['random_seed'])
|
||||
pop_nodes, pop_cons = initialize_genomes(N, C, config)
|
||||
species_info = np.full((S, 4), np.nan, dtype=np.float32)
|
||||
species_info[0, :] = 0, -np.inf, 0, P
|
||||
idx2species = np.zeros(P, dtype=np.float32)
|
||||
center_nodes = np.full((S, N, 5), np.nan, dtype=np.float32)
|
||||
center_cons = np.full((S, C, 4), np.nan, dtype=np.float32)
|
||||
center_nodes[0, :, :] = pop_nodes[0, :, :]
|
||||
center_cons[0, :, :] = pop_cons[0, :, :]
|
||||
generation = np.asarray(0, dtype=np.int32)
|
||||
next_node_key = np.asarray(config['num_inputs'] + config['num_outputs'], dtype=np.float32)
|
||||
next_species_key = np.asarray(1, dtype=np.float32)
|
||||
|
||||
return jax.device_put([
|
||||
randkey,
|
||||
pop_nodes,
|
||||
pop_cons,
|
||||
species_info,
|
||||
idx2species,
|
||||
center_nodes,
|
||||
center_cons,
|
||||
generation,
|
||||
next_node_key,
|
||||
next_species_key,
|
||||
])
|
||||
|
||||
@jit
|
||||
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
|
||||
def tell(fitness,
|
||||
randkey,
|
||||
pop_nodes,
|
||||
pop_cons,
|
||||
species_info,
|
||||
idx2species,
|
||||
center_nodes,
|
||||
center_cons,
|
||||
generation,
|
||||
next_node_key,
|
||||
next_species_key,
|
||||
jit_config):
|
||||
|
||||
"""
|
||||
Main update function in NEAT.
|
||||
"""
|
||||
generation += 1
|
||||
|
||||
k1, k2, randkey = jax.random.split(randkey, 3)
|
||||
@@ -23,19 +91,15 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
|
||||
update_species(k1, fitness, species_info, idx2species, center_nodes,
|
||||
center_cons, generation, jit_config)
|
||||
|
||||
pop_nodes, pop_cons, next_node_key = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
||||
elite_mask, next_node_key, jit_config)
|
||||
|
||||
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
||||
elite_mask, generation, jit_config)
|
||||
idx2species, center_nodes, center_cons, species_info, next_species_key = speciate(
|
||||
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config)
|
||||
|
||||
idx2species, center_nodes, center_cons, species_info = speciate(
|
||||
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
|
||||
jit_config)
|
||||
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, next_node_key, next_species_key
|
||||
|
||||
|
||||
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
|
||||
|
||||
|
||||
@jit
|
||||
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
args:
|
||||
@@ -199,11 +263,10 @@ def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitn
|
||||
return winner, loser, elite_mask
|
||||
|
||||
|
||||
@jit
|
||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, generation, jit_config):
|
||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, next_node_key, jit_config):
|
||||
# prepare random keys
|
||||
pop_size = pop_nodes.shape[0]
|
||||
new_node_keys = jnp.arange(pop_size) + generation * pop_size
|
||||
new_node_keys = jnp.arange(pop_size) + next_node_key
|
||||
|
||||
k1, k2 = jax.random.split(rand_key, 2)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
@@ -222,11 +285,15 @@ def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_m
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
||||
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
||||
|
||||
return pop_nodes, pop_cons
|
||||
# update next node key
|
||||
all_nodes_keys = pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
|
||||
next_node_key = max_node_key + 1
|
||||
|
||||
return pop_nodes, pop_cons, next_node_key
|
||||
|
||||
|
||||
@jit
|
||||
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, jit_config):
|
||||
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config):
|
||||
"""
|
||||
args:
|
||||
pop_nodes: (pop_size, N, 5)
|
||||
@@ -243,7 +310,7 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
|
||||
|
||||
# the distance between genomes to its center genomes
|
||||
o2c_distances = jnp.full((pop_size, ), jnp.inf)
|
||||
o2c_distances = jnp.full((pop_size,), jnp.inf)
|
||||
|
||||
# step 1: find new centers
|
||||
def cond_func(carry):
|
||||
@@ -277,35 +344,35 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
|
||||
i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key
|
||||
# jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
|
||||
current_species_existed = ~jnp.isnan(si[i, 0])
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < species_size
|
||||
return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds)
|
||||
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||
|
||||
_, i2s, scn, scc, si, o2c, ck = jax.lax.cond(
|
||||
_, i2s, scn, scc, si, o2c, nsk = jax.lax.cond(
|
||||
jnp.isnan(si[i, 0]), # whether the current species is existing or not
|
||||
create_new_species, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cn, cc, si, o2c, ck)
|
||||
(i, i2s, cn, cc, si, o2c, nsk)
|
||||
)
|
||||
|
||||
return i + 1, i2s, scn, scc, si, o2c, ck
|
||||
return i + 1, i2s, scn, scc, si, o2c, nsk
|
||||
|
||||
def create_new_species(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry
|
||||
i, i2s, cn, cc, si, o2c, nsk = carry
|
||||
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(jnp.isnan(i2s))
|
||||
|
||||
# assign it to the new species
|
||||
# [key, best score, last update generation, members_count]
|
||||
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0]))
|
||||
i2s = i2s.at[idx].set(ck)
|
||||
si = si.at[i].set(jnp.array([nsk, -jnp.inf, generation, 0]))
|
||||
i2s = i2s.at[idx].set(nsk)
|
||||
o2c = o2c.at[idx].set(0)
|
||||
|
||||
# update center genomes
|
||||
@@ -315,14 +382,14 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
|
||||
|
||||
# when a new species is created, it needs to be updated, thus do not change i
|
||||
return i + 1, i2s, cn, cc, si, o2c, ck + 1 # change to next new speciate key
|
||||
return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cn, cc, si, o2c, ck = carry
|
||||
i, i2s, cn, cc, si, o2c, nsk = carry
|
||||
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
|
||||
|
||||
# turn to next species
|
||||
return i + 1, i2s, cn, cc, si, o2c, ck
|
||||
return i + 1, i2s, cn, cc, si, o2c, nsk
|
||||
|
||||
def speciate_by_threshold(carry):
|
||||
i, i2s, cn, cc, si, o2c = carry
|
||||
@@ -344,15 +411,11 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
|
||||
return i2s, o2c
|
||||
|
||||
species_keys = species_info[:, 0]
|
||||
current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1
|
||||
|
||||
|
||||
# update idx2specie
|
||||
_, idx2specie, center_nodes, center_cons, species_info, _, _ = jax.lax.while_loop(
|
||||
_, idx2specie, center_nodes, center_cons, species_info, _, next_species_key = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, current_new_key)
|
||||
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, next_species_key)
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
@@ -369,10 +432,9 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
||||
species_member_counts = vmap(count_members)(jnp.arange(species_size))
|
||||
species_info = species_info.at[:, 3].set(species_member_counts)
|
||||
|
||||
return idx2specie, center_nodes, center_cons, species_info
|
||||
return idx2specie, center_nodes, center_cons, species_info, next_species_key
|
||||
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
|
||||
Reference in New Issue
Block a user