perfect! fix bug about jax auto recompile
add task xor-3d
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||||
"""
|
"""
|
||||||
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
|
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
|
||||||
from .population import update_species, create_next_generation, speciate, tell
|
from .population import update_species, create_next_generation, speciate, tell, initialize
|
||||||
|
|
||||||
from .genome.activations import act_name2func
|
from .genome.activations import act_name2func
|
||||||
from .genome.aggregations import agg_name2func
|
from .genome.aggregations import agg_name2func
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
|
|||||||
assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
|
assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
|
||||||
f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"
|
f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"
|
||||||
|
|
||||||
pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
|
pop_nodes = np.full((config['pop_size'], N, 5), np.nan, dtype=np.float32)
|
||||||
pop_cons = np.full((config['pop_size'], C, 4), np.nan)
|
pop_cons = np.full((config['pop_size'], C, 4), np.nan, dtype=np.float32)
|
||||||
input_idx = config['input_idx']
|
input_idx = config['input_idx']
|
||||||
output_idx = config['output_idx']
|
output_idx = config['output_idx']
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
|
|||||||
pop_cons[:, :p, 0] = grid_a
|
pop_cons[:, :p, 0] = grid_a
|
||||||
pop_cons[:, :p, 1] = grid_b
|
pop_cons[:, :p, 1] = grid_b
|
||||||
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
|
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
|
||||||
size=(config['pop_size'], p))
|
size=(config['pop_size'], p))
|
||||||
pop_cons[:, :p, 3] = 1
|
pop_cons[:, :p, 3] = 1
|
||||||
|
|
||||||
return pop_nodes, pop_cons
|
return pop_nodes, pop_cons
|
||||||
|
|||||||
@@ -1,20 +1,88 @@
|
|||||||
"""
|
"""
|
||||||
Contains operations on the population: creating the next generation and population speciation.
|
Contains operations on the population: creating the next generation and population speciation.
|
||||||
These im.....
|
The value tuple (P, N, C, S) is determined when the algorithm is initialized.
|
||||||
|
P: population size
|
||||||
|
N: maximum number of nodes in any genome
|
||||||
|
C: maximum number of connections in any genome
|
||||||
|
S: maximum number of species in NEAT
|
||||||
|
|
||||||
|
These arrays are used in the algorithm:
|
||||||
|
fitness: Array[(P,), float], the fitness of each individual
|
||||||
|
randkey: Array[2, uint], the random key
|
||||||
|
pop_nodes: Array[(P, N, 5), float], nodes part of the population. [key, bias, response, act, agg]
|
||||||
|
pop_cons: Array[(P, C, 4), float], connections part of the population. [in_node, out_node, weight, enabled]
|
||||||
|
species_info: Array[(S, 4), float], the information of each species. [key, best_score, last_update, members_count]
|
||||||
|
idx2species: Array[(P,), float], map the individual to its species keys
|
||||||
|
center_nodes: Array[(S, N, 5), float], the center nodes of each species
|
||||||
|
center_cons: Array[(S, C, 4), float], the center connections of each species
|
||||||
|
generation: int, the current generation
|
||||||
|
next_node_key: float, the next of the next node
|
||||||
|
next_species_key: float, the next of the next species
|
||||||
|
jit_config: Configer, the config used in jit-able functions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: Complete python doc
|
# TODO: Complete python doc
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import jax
|
import jax
|
||||||
from jax import jit, vmap, Array, numpy as jnp
|
from jax import jit, vmap, Array, numpy as jnp
|
||||||
|
|
||||||
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
|
from .genome import initialize_genomes, distance, mutate, crossover, fetch_first, rank_elements
|
||||||
|
|
||||||
|
|
||||||
|
def initialize(config):
|
||||||
|
"""
|
||||||
|
initialize the states of NEAT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
P = config['pop_size']
|
||||||
|
N = config['maximum_nodes']
|
||||||
|
C = config['maximum_connections']
|
||||||
|
S = config['maximum_species']
|
||||||
|
|
||||||
|
randkey = jax.random.PRNGKey(config['random_seed'])
|
||||||
|
np.random.seed(config['random_seed'])
|
||||||
|
pop_nodes, pop_cons = initialize_genomes(N, C, config)
|
||||||
|
species_info = np.full((S, 4), np.nan, dtype=np.float32)
|
||||||
|
species_info[0, :] = 0, -np.inf, 0, P
|
||||||
|
idx2species = np.zeros(P, dtype=np.float32)
|
||||||
|
center_nodes = np.full((S, N, 5), np.nan, dtype=np.float32)
|
||||||
|
center_cons = np.full((S, C, 4), np.nan, dtype=np.float32)
|
||||||
|
center_nodes[0, :, :] = pop_nodes[0, :, :]
|
||||||
|
center_cons[0, :, :] = pop_cons[0, :, :]
|
||||||
|
generation = np.asarray(0, dtype=np.int32)
|
||||||
|
next_node_key = np.asarray(config['num_inputs'] + config['num_outputs'], dtype=np.float32)
|
||||||
|
next_species_key = np.asarray(1, dtype=np.float32)
|
||||||
|
|
||||||
|
return jax.device_put([
|
||||||
|
randkey,
|
||||||
|
pop_nodes,
|
||||||
|
pop_cons,
|
||||||
|
species_info,
|
||||||
|
idx2species,
|
||||||
|
center_nodes,
|
||||||
|
center_cons,
|
||||||
|
generation,
|
||||||
|
next_node_key,
|
||||||
|
next_species_key,
|
||||||
|
])
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
|
def tell(fitness,
|
||||||
|
randkey,
|
||||||
|
pop_nodes,
|
||||||
|
pop_cons,
|
||||||
|
species_info,
|
||||||
|
idx2species,
|
||||||
|
center_nodes,
|
||||||
|
center_cons,
|
||||||
|
generation,
|
||||||
|
next_node_key,
|
||||||
|
next_species_key,
|
||||||
jit_config):
|
jit_config):
|
||||||
|
"""
|
||||||
|
Main update function in NEAT.
|
||||||
|
"""
|
||||||
generation += 1
|
generation += 1
|
||||||
|
|
||||||
k1, k2, randkey = jax.random.split(randkey, 3)
|
k1, k2, randkey = jax.random.split(randkey, 3)
|
||||||
@@ -23,19 +91,15 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
|
|||||||
update_species(k1, fitness, species_info, idx2species, center_nodes,
|
update_species(k1, fitness, species_info, idx2species, center_nodes,
|
||||||
center_cons, generation, jit_config)
|
center_cons, generation, jit_config)
|
||||||
|
|
||||||
|
pop_nodes, pop_cons, next_node_key = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
||||||
|
elite_mask, next_node_key, jit_config)
|
||||||
|
|
||||||
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
|
idx2species, center_nodes, center_cons, species_info, next_species_key = speciate(
|
||||||
elite_mask, generation, jit_config)
|
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config)
|
||||||
|
|
||||||
idx2species, center_nodes, center_cons, species_info = speciate(
|
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, next_node_key, next_species_key
|
||||||
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
|
|
||||||
jit_config)
|
|
||||||
|
|
||||||
|
|
||||||
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
||||||
"""
|
"""
|
||||||
args:
|
args:
|
||||||
@@ -199,11 +263,10 @@ def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitn
|
|||||||
return winner, loser, elite_mask
|
return winner, loser, elite_mask
|
||||||
|
|
||||||
|
|
||||||
@jit
|
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, next_node_key, jit_config):
|
||||||
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, generation, jit_config):
|
|
||||||
# prepare random keys
|
# prepare random keys
|
||||||
pop_size = pop_nodes.shape[0]
|
pop_size = pop_nodes.shape[0]
|
||||||
new_node_keys = jnp.arange(pop_size) + generation * pop_size
|
new_node_keys = jnp.arange(pop_size) + next_node_key
|
||||||
|
|
||||||
k1, k2 = jax.random.split(rand_key, 2)
|
k1, k2 = jax.random.split(rand_key, 2)
|
||||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||||
@@ -222,11 +285,15 @@ def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_m
|
|||||||
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
|
||||||
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
|
||||||
|
|
||||||
return pop_nodes, pop_cons
|
# update next node key
|
||||||
|
all_nodes_keys = pop_nodes[:, :, 0]
|
||||||
|
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
|
||||||
|
next_node_key = max_node_key + 1
|
||||||
|
|
||||||
|
return pop_nodes, pop_cons, next_node_key
|
||||||
|
|
||||||
|
|
||||||
@jit
|
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config):
|
||||||
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, jit_config):
|
|
||||||
"""
|
"""
|
||||||
args:
|
args:
|
||||||
pop_nodes: (pop_size, N, 5)
|
pop_nodes: (pop_size, N, 5)
|
||||||
@@ -243,7 +310,7 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
|
idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
|
||||||
|
|
||||||
# the distance between genomes to its center genomes
|
# the distance between genomes to its center genomes
|
||||||
o2c_distances = jnp.full((pop_size, ), jnp.inf)
|
o2c_distances = jnp.full((pop_size,), jnp.inf)
|
||||||
|
|
||||||
# step 1: find new centers
|
# step 1: find new centers
|
||||||
def cond_func(carry):
|
def cond_func(carry):
|
||||||
@@ -277,35 +344,35 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
|
|
||||||
# part 2: assign members to each species
|
# part 2: assign members to each species
|
||||||
def cond_func(carry):
|
def cond_func(carry):
|
||||||
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
|
i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key
|
||||||
# jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
|
# jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
|
||||||
current_species_existed = ~jnp.isnan(si[i, 0])
|
current_species_existed = ~jnp.isnan(si[i, 0])
|
||||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||||
not_reach_species_upper_bounds = i < species_size
|
not_reach_species_upper_bounds = i < species_size
|
||||||
return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds)
|
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
|
||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
|
||||||
|
|
||||||
_, i2s, scn, scc, si, o2c, ck = jax.lax.cond(
|
_, i2s, scn, scc, si, o2c, nsk = jax.lax.cond(
|
||||||
jnp.isnan(si[i, 0]), # whether the current species is existing or not
|
jnp.isnan(si[i, 0]), # whether the current species is existing or not
|
||||||
create_new_species, # if not existing, create a new specie
|
create_new_species, # if not existing, create a new specie
|
||||||
update_exist_specie, # if existing, update the specie
|
update_exist_specie, # if existing, update the specie
|
||||||
(i, i2s, cn, cc, si, o2c, ck)
|
(i, i2s, cn, cc, si, o2c, nsk)
|
||||||
)
|
)
|
||||||
|
|
||||||
return i + 1, i2s, scn, scc, si, o2c, ck
|
return i + 1, i2s, scn, scc, si, o2c, nsk
|
||||||
|
|
||||||
def create_new_species(carry):
|
def create_new_species(carry):
|
||||||
i, i2s, cn, cc, si, o2c, ck = carry
|
i, i2s, cn, cc, si, o2c, nsk = carry
|
||||||
|
|
||||||
# pick the first one who has not been assigned to any species
|
# pick the first one who has not been assigned to any species
|
||||||
idx = fetch_first(jnp.isnan(i2s))
|
idx = fetch_first(jnp.isnan(i2s))
|
||||||
|
|
||||||
# assign it to the new species
|
# assign it to the new species
|
||||||
# [key, best score, last update generation, members_count]
|
# [key, best score, last update generation, members_count]
|
||||||
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0]))
|
si = si.at[i].set(jnp.array([nsk, -jnp.inf, generation, 0]))
|
||||||
i2s = i2s.at[idx].set(ck)
|
i2s = i2s.at[idx].set(nsk)
|
||||||
o2c = o2c.at[idx].set(0)
|
o2c = o2c.at[idx].set(0)
|
||||||
|
|
||||||
# update center genomes
|
# update center genomes
|
||||||
@@ -315,14 +382,14 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
|
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
|
||||||
|
|
||||||
# when a new species is created, it needs to be updated, thus do not change i
|
# when a new species is created, it needs to be updated, thus do not change i
|
||||||
return i + 1, i2s, cn, cc, si, o2c, ck + 1 # change to next new speciate key
|
return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key
|
||||||
|
|
||||||
def update_exist_specie(carry):
|
def update_exist_specie(carry):
|
||||||
i, i2s, cn, cc, si, o2c, ck = carry
|
i, i2s, cn, cc, si, o2c, nsk = carry
|
||||||
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
|
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
|
||||||
|
|
||||||
# turn to next species
|
# turn to next species
|
||||||
return i + 1, i2s, cn, cc, si, o2c, ck
|
return i + 1, i2s, cn, cc, si, o2c, nsk
|
||||||
|
|
||||||
def speciate_by_threshold(carry):
|
def speciate_by_threshold(carry):
|
||||||
i, i2s, cn, cc, si, o2c = carry
|
i, i2s, cn, cc, si, o2c = carry
|
||||||
@@ -344,15 +411,11 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
|
|
||||||
return i2s, o2c
|
return i2s, o2c
|
||||||
|
|
||||||
species_keys = species_info[:, 0]
|
|
||||||
current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1
|
|
||||||
|
|
||||||
|
|
||||||
# update idx2specie
|
# update idx2specie
|
||||||
_, idx2specie, center_nodes, center_cons, species_info, _, _ = jax.lax.while_loop(
|
_, idx2specie, center_nodes, center_cons, species_info, _, next_species_key = jax.lax.while_loop(
|
||||||
cond_func,
|
cond_func,
|
||||||
body_func,
|
body_func,
|
||||||
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, current_new_key)
|
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, next_species_key)
|
||||||
)
|
)
|
||||||
|
|
||||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||||
@@ -369,10 +432,9 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
|
|||||||
species_member_counts = vmap(count_members)(jnp.arange(species_size))
|
species_member_counts = vmap(count_members)(jnp.arange(species_size))
|
||||||
species_info = species_info.at[:, 3].set(species_member_counts)
|
species_info = species_info.at[:, 3].set(species_member_counts)
|
||||||
|
|
||||||
return idx2specie, center_nodes, center_cons, species_info
|
return idx2specie, center_nodes, center_cons, species_info, next_species_key
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||||
min_idx = jnp.argmin(masked_arr)
|
min_idx = jnp.argmin(masked_arr)
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ import configparser
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithms.neat import act_name2func, agg_name2func
|
from algorithms.neat.genome.activations import act_name2func
|
||||||
|
from algorithms.neat.genome.aggregations import agg_name2func
|
||||||
|
|
||||||
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||||
jit_config_keys = [
|
jit_config_keys = [
|
||||||
|
|||||||
@@ -1,19 +1,18 @@
|
|||||||
[basic]
|
[basic]
|
||||||
num_inputs = 2
|
num_inputs = 2
|
||||||
num_outputs = 1
|
num_outputs = 1
|
||||||
init_maximum_nodes = 50
|
maximum_nodes = 50
|
||||||
init_maximum_connections = 50
|
maximum_connections = 50
|
||||||
init_maximum_species = 10
|
maximum_species = 10
|
||||||
expand_coe = 1.5
|
|
||||||
pre_expand_threshold = 0.75
|
|
||||||
forward_way = "pop"
|
forward_way = "pop"
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
|
random_seed = 0
|
||||||
|
|
||||||
[population]
|
[population]
|
||||||
fitness_threshold = 100000
|
fitness_threshold = 3.99999
|
||||||
generation_limit = 1000
|
generation_limit = 1000
|
||||||
fitness_criterion = "max"
|
fitness_criterion = "max"
|
||||||
pop_size = 50
|
pop_size = 100000
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
|
|||||||
@@ -34,8 +34,6 @@ def get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topologica
|
|||||||
return evaluate(func)
|
return evaluate(func)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def equal(ar1, ar2):
|
def equal(ar1, ar2):
|
||||||
if ar1.shape != ar2.shape:
|
if ar1.shape != ar2.shape:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
forward_way = "common"
|
forward_way = "common"
|
||||||
|
|
||||||
[population]
|
[population]
|
||||||
fitness_threshold = 3.9999
|
fitness_threshold = 4
|
||||||
@@ -2,7 +2,6 @@ import jax
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from configs import Configer
|
from configs import Configer
|
||||||
from algorithms.neat import Genome
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
@@ -22,10 +21,10 @@ def evaluate(forward_func):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
config = Configer.load_config("xor.ini")
|
config = Configer.load_config("xor.ini")
|
||||||
pipeline = Pipeline(config, seed=6)
|
pipeline = Pipeline(config)
|
||||||
nodes, cons = pipeline.auto_run(evaluate)
|
nodes, cons = pipeline.auto_run(evaluate)
|
||||||
g = Genome(nodes, cons, config)
|
# g = Genome(nodes, cons, config)
|
||||||
print(g)
|
# print(g)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
47
examples/xor3d.ini
Normal file
47
examples/xor3d.ini
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
[basic]
|
||||||
|
num_inputs = 3
|
||||||
|
num_outputs = 1
|
||||||
|
maximum_nodes = 50
|
||||||
|
maximum_connections = 50
|
||||||
|
maximum_species = 10
|
||||||
|
forward_way = "common"
|
||||||
|
batch_size = 4
|
||||||
|
random_seed = 42
|
||||||
|
|
||||||
|
[population]
|
||||||
|
fitness_threshold = 8
|
||||||
|
generation_limit = 1000
|
||||||
|
fitness_criterion = "max"
|
||||||
|
pop_size = 100000
|
||||||
|
|
||||||
|
[genome]
|
||||||
|
compatibility_disjoint = 1.0
|
||||||
|
compatibility_weight = 0.5
|
||||||
|
conn_add_prob = 0.4
|
||||||
|
conn_add_trials = 1
|
||||||
|
conn_delete_prob = 0
|
||||||
|
node_add_prob = 0.2
|
||||||
|
node_delete_prob = 0
|
||||||
|
|
||||||
|
[species]
|
||||||
|
compatibility_threshold = 3
|
||||||
|
species_elitism = 1
|
||||||
|
max_stagnation = 15
|
||||||
|
genome_elitism = 2
|
||||||
|
survival_threshold = 0.2
|
||||||
|
min_species_size = 1
|
||||||
|
spawn_number_move_rate = 0.5
|
||||||
|
|
||||||
|
[gene-bias]
|
||||||
|
bias_init_mean = 0.0
|
||||||
|
bias_init_std = 1.0
|
||||||
|
bias_mutate_power = 0.5
|
||||||
|
bias_mutate_rate = 0.7
|
||||||
|
bias_replace_rate = 0.1
|
||||||
|
|
||||||
|
[gene-weight]
|
||||||
|
weight_init_mean = 0.0
|
||||||
|
weight_init_std = 1.0
|
||||||
|
weight_mutate_power = 0.5
|
||||||
|
weight_mutate_rate = 0.8
|
||||||
|
weight_replace_rate = 0.1
|
||||||
31
examples/xor3d.py
Normal file
31
examples/xor3d.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from configs import Configer
|
||||||
|
from pipeline import Pipeline
|
||||||
|
|
||||||
|
xor_inputs = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=np.float32)
|
||||||
|
xor_outputs = np.array([[0], [1], [1], [0], [1], [0], [0], [1]], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(forward_func):
|
||||||
|
"""
|
||||||
|
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
outs = forward_func(xor_inputs)
|
||||||
|
outs = jax.device_get(outs)
|
||||||
|
fitnesses = 8 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||||
|
return fitnesses
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = Configer.load_config("xor3d.ini")
|
||||||
|
pipeline = Pipeline(config)
|
||||||
|
nodes, cons = pipeline.auto_run(evaluate)
|
||||||
|
# g = Genome(nodes, cons, config)
|
||||||
|
# print(g)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
83
pipeline.py
83
pipeline.py
@@ -5,8 +5,8 @@ import numpy as np
|
|||||||
import jax
|
import jax
|
||||||
from jax import jit, vmap
|
from jax import jit, vmap
|
||||||
|
|
||||||
from configs import Configer
|
|
||||||
from algorithms import neat
|
from algorithms import neat
|
||||||
|
from configs.configer import Configer
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
@@ -14,58 +14,40 @@ class Pipeline:
|
|||||||
Neat algorithm pipeline.
|
Neat algorithm pipeline.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, seed=42):
|
def __init__(self, config):
|
||||||
self.randkey = jax.random.PRNGKey(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
|
|
||||||
self.config = config # global config
|
self.config = config # global config
|
||||||
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
self.jit_config = Configer.create_jit_config(config)
|
||||||
|
|
||||||
self.P = config['pop_size']
|
|
||||||
self.N = config['init_maximum_nodes']
|
|
||||||
self.C = config['init_maximum_connections']
|
|
||||||
self.S = config['init_maximum_species']
|
|
||||||
|
|
||||||
self.generation = 0
|
|
||||||
self.best_genome = None
|
self.best_genome = None
|
||||||
|
|
||||||
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
|
self.neat_states = neat.initialize(config)
|
||||||
self.species_info = np.full((self.S, 4), np.nan)
|
|
||||||
self.species_info[0, :] = 0, -np.inf, 0, self.P
|
|
||||||
self.idx2species = np.zeros(self.P, dtype=np.float32)
|
|
||||||
self.center_nodes = np.full((self.S, self.N, 5), np.nan)
|
|
||||||
self.center_cons = np.full((self.S, self.C, 4), np.nan)
|
|
||||||
self.center_nodes[0, :, :] = self.pop_nodes[0, :, :]
|
|
||||||
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
|
|
||||||
|
|
||||||
self.best_fitness = float('-inf')
|
self.best_fitness = float('-inf')
|
||||||
self.generation_timestamp = time.time()
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
self.evaluate_time = 0
|
self.evaluate_time = 0
|
||||||
|
|
||||||
|
|
||||||
|
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
|
||||||
|
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.initialize(config)
|
||||||
|
|
||||||
|
|
||||||
|
self.forward = neat.create_forward_function(config)
|
||||||
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
|
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
|
||||||
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
self.pop_topological_sort = jit(vmap(neat.topological_sort))
|
||||||
self.forward = neat.create_forward_function(config)
|
|
||||||
|
|
||||||
# fitness_lower = np.zeros(self.P, dtype=np.float32)
|
# self.tell_func = neat.tell.lower(np.zeros(config['pop_size'], dtype=np.float32),
|
||||||
# randkey_lower = np.zeros(2, dtype=np.uint32)
|
# self.randkey,
|
||||||
# pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32)
|
# self.pop_nodes,
|
||||||
# pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32)
|
# self.pop_cons,
|
||||||
# species_info_lower = np.zeros((self.S, 4), dtype=np.float32)
|
# self.species_info,
|
||||||
# idx2species_lower = np.zeros(self.P, dtype=np.float32)
|
# self.idx2species,
|
||||||
# center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32)
|
# self.center_nodes,
|
||||||
# center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32)
|
# self.center_cons,
|
||||||
#
|
# self.generation,
|
||||||
# self.tell_func = jit(neat.tell).lower(fitness_lower,
|
# self.next_node_key,
|
||||||
# randkey_lower,
|
# self.next_species_key,
|
||||||
# pop_nodes_lower,
|
# self.jit_config).compile()
|
||||||
# pop_cons_lower,
|
|
||||||
# species_info_lower,
|
|
||||||
# idx2species_lower,
|
|
||||||
# center_nodes_lower,
|
|
||||||
# center_cons_lower,
|
|
||||||
# 0,
|
|
||||||
# self.jit_config).compile()
|
|
||||||
|
|
||||||
def ask(self):
|
def ask(self):
|
||||||
"""
|
"""
|
||||||
@@ -97,9 +79,19 @@ class Pipeline:
|
|||||||
def tell(self, fitness):
|
def tell(self, fitness):
|
||||||
|
|
||||||
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
|
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
|
||||||
self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons,
|
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.tell(fitness,
|
||||||
self.species_info, self.idx2species, self.center_nodes,
|
self.randkey,
|
||||||
self.center_cons, self.generation, self.jit_config)
|
self.pop_nodes,
|
||||||
|
self.pop_cons,
|
||||||
|
self.species_info,
|
||||||
|
self.idx2species,
|
||||||
|
self.center_nodes,
|
||||||
|
self.center_cons,
|
||||||
|
self.generation,
|
||||||
|
self.next_node_key,
|
||||||
|
self.next_species_key,
|
||||||
|
self.jit_config)
|
||||||
|
|
||||||
|
|
||||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||||
for _ in range(self.config['generation_limit']):
|
for _ in range(self.config['generation_limit']):
|
||||||
@@ -109,7 +101,7 @@ class Pipeline:
|
|||||||
fitnesses = fitness_func(forward_func)
|
fitnesses = fitness_func(forward_func)
|
||||||
self.evaluate_time += time.time() - tic
|
self.evaluate_time += time.time() - tic
|
||||||
|
|
||||||
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
# assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||||
|
|
||||||
if analysis is not None:
|
if analysis is not None:
|
||||||
if analysis == "default":
|
if analysis == "default":
|
||||||
@@ -138,7 +130,8 @@ class Pipeline:
|
|||||||
self.best_fitness = fitnesses[max_idx]
|
self.best_fitness = fitnesses[max_idx]
|
||||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||||
|
|
||||||
species_sizes = [int(i) for i in self.species_info[:, 3] if i > 0]
|
member_count = jax.device_get(self.species_info[:, 3])
|
||||||
|
species_sizes = [int(i) for i in member_count if i > 0]
|
||||||
|
|
||||||
print(f"Generation: {self.generation}",
|
print(f"Generation: {self.generation}",
|
||||||
f"species: {len(species_sizes)}, {species_sizes}",
|
f"species: {len(species_sizes)}, {species_sizes}",
|
||||||
|
|||||||
Reference in New Issue
Block a user