perfect! fix bug about jax auto recompile

add task xor-3d
This commit is contained in:
wls2002
2023-07-02 22:15:26 +08:00
parent e711146f41
commit c4d34e877b
11 changed files with 234 additions and 104 deletions

View File

@@ -2,7 +2,7 @@
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""
from .genome import create_forward_function, topological_sort, unflatten_connections, initialize_genomes
from .population import update_species, create_next_generation, speciate, tell
from .population import update_species, create_next_generation, speciate, tell, initialize
from .genome.activations import act_name2func
from .genome.aggregations import agg_name2func

View File

@@ -36,8 +36,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"
pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
pop_cons = np.full((config['pop_size'], C, 4), np.nan)
pop_nodes = np.full((config['pop_size'], N, 5), np.nan, dtype=np.float32)
pop_cons = np.full((config['pop_size'], C, 4), np.nan, dtype=np.float32)
input_idx = config['input_idx']
output_idx = config['output_idx']

View File

@@ -1,20 +1,88 @@
"""
Contains operations on the population: creating the next generation and population speciation.
These im.....
The value tuple (P, N, C, S) is determined when the algorithm is initialized.
P: population size
N: maximum number of nodes in any genome
C: maximum number of connections in any genome
S: maximum number of species in NEAT
These arrays are used in the algorithm:
fitness: Array[(P,), float], the fitness of each individual
randkey: Array[2, uint], the random key
pop_nodes: Array[(P, N, 5), float], nodes part of the population. [key, bias, response, act, agg]
pop_cons: Array[(P, C, 4), float], connections part of the population. [in_node, out_node, weight, enabled]
species_info: Array[(S, 4), float], the information of each species. [key, best_score, last_update, members_count]
idx2species: Array[(P,), float], map the individual to its species keys
center_nodes: Array[(S, N, 5), float], the center nodes of each species
center_cons: Array[(S, C, 4), float], the center connections of each species
generation: int, the current generation
next_node_key: float, the next of the next node
next_species_key: float, the next of the next species
jit_config: Configer, the config used in jit-able functions
"""
# TODO: Complete python doc
import numpy as np
import jax
from jax import jit, vmap, Array, numpy as jnp
from .genome import distance, mutate, crossover, I_INT, fetch_first, rank_elements
from .genome import initialize_genomes, distance, mutate, crossover, fetch_first, rank_elements
def initialize(config):
"""
initialize the states of NEAT.
"""
P = config['pop_size']
N = config['maximum_nodes']
C = config['maximum_connections']
S = config['maximum_species']
randkey = jax.random.PRNGKey(config['random_seed'])
np.random.seed(config['random_seed'])
pop_nodes, pop_cons = initialize_genomes(N, C, config)
species_info = np.full((S, 4), np.nan, dtype=np.float32)
species_info[0, :] = 0, -np.inf, 0, P
idx2species = np.zeros(P, dtype=np.float32)
center_nodes = np.full((S, N, 5), np.nan, dtype=np.float32)
center_cons = np.full((S, C, 4), np.nan, dtype=np.float32)
center_nodes[0, :, :] = pop_nodes[0, :, :]
center_cons[0, :, :] = pop_cons[0, :, :]
generation = np.asarray(0, dtype=np.int32)
next_node_key = np.asarray(config['num_inputs'] + config['num_outputs'], dtype=np.float32)
next_species_key = np.asarray(1, dtype=np.float32)
return jax.device_put([
randkey,
pop_nodes,
pop_cons,
species_info,
idx2species,
center_nodes,
center_cons,
generation,
next_node_key,
next_species_key,
])
@jit
def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
def tell(fitness,
randkey,
pop_nodes,
pop_cons,
species_info,
idx2species,
center_nodes,
center_cons,
generation,
next_node_key,
next_species_key,
jit_config):
"""
Main update function in NEAT.
"""
generation += 1
k1, k2, randkey = jax.random.split(randkey, 3)
@@ -23,19 +91,15 @@ def tell(fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, cente
update_species(k1, fitness, species_info, idx2species, center_nodes,
center_cons, generation, jit_config)
pop_nodes, pop_cons, next_node_key = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
elite_mask, next_node_key, jit_config)
pop_nodes, pop_cons = create_next_generation(k2, pop_nodes, pop_cons, winner, loser,
elite_mask, generation, jit_config)
idx2species, center_nodes, center_cons, species_info, next_species_key = speciate(
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config)
idx2species, center_nodes, center_cons, species_info = speciate(
pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation,
jit_config)
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation, next_node_key, next_species_key
return randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation
@jit
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
"""
args:
@@ -199,11 +263,10 @@ def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitn
return winner, loser, elite_mask
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, generation, jit_config):
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, next_node_key, jit_config):
# prepare random keys
pop_size = pop_nodes.shape[0]
new_node_keys = jnp.arange(pop_size) + generation * pop_size
new_node_keys = jnp.arange(pop_size) + next_node_key
k1, k2 = jax.random.split(rand_key, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
@@ -222,11 +285,15 @@ def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_m
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
return pop_nodes, pop_cons
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return pop_nodes, pop_cons, next_node_key
@jit
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, jit_config):
def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, generation, next_species_key, jit_config):
"""
args:
pop_nodes: (pop_size, N, 5)
@@ -243,7 +310,7 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
idx2specie = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((pop_size, ), jnp.inf)
o2c_distances = jnp.full((pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
@@ -277,35 +344,35 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cn, cc, si, o2c, ck = carry # si is short for species_info, ck is short for current key
i, i2s, cn, cc, si, o2c, nsk = carry # si is short for species_info, nsk is short for next_species_key
# jax.debug.print("i:\n{}\ni2s:\n{}\nsi:\n{}", i, i2s, si)
current_species_existed = ~jnp.isnan(si[i, 0])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return current_species_existed | (not_all_assigned & not_reach_species_upper_bounds)
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cn, cc, si, o2c, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
i, i2s, cn, cc, si, o2c, nsk = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
_, i2s, scn, scc, si, o2c, ck = jax.lax.cond(
_, i2s, scn, scc, si, o2c, nsk = jax.lax.cond(
jnp.isnan(si[i, 0]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cn, cc, si, o2c, ck)
(i, i2s, cn, cc, si, o2c, nsk)
)
return i + 1, i2s, scn, scc, si, o2c, ck
return i + 1, i2s, scn, scc, si, o2c, nsk
def create_new_species(carry):
i, i2s, cn, cc, si, o2c, ck = carry
i, i2s, cn, cc, si, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, members_count]
si = si.at[i].set(jnp.array([ck, -jnp.inf, generation, 0]))
i2s = i2s.at[idx].set(ck)
si = si.at[i].set(jnp.array([nsk, -jnp.inf, generation, 0]))
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
@@ -315,14 +382,14 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
# when a new species is created, it needs to be updated, thus do not change i
return i + 1, i2s, cn, cc, si, o2c, ck + 1 # change to next new speciate key
return i + 1, i2s, cn, cc, si, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cn, cc, si, o2c, ck = carry
i, i2s, cn, cc, si, o2c, nsk = carry
i2s, o2c = speciate_by_threshold((i, i2s, cn, cc, si, o2c))
# turn to next species
return i + 1, i2s, cn, cc, si, o2c, ck
return i + 1, i2s, cn, cc, si, o2c, nsk
def speciate_by_threshold(carry):
i, i2s, cn, cc, si, o2c = carry
@@ -344,15 +411,11 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
return i2s, o2c
species_keys = species_info[:, 0]
current_new_key = jnp.max(jnp.where(jnp.isnan(species_keys), -jnp.inf, species_keys)) + 1
# update idx2specie
_, idx2specie, center_nodes, center_cons, species_info, _, _ = jax.lax.while_loop(
_, idx2specie, center_nodes, center_cons, species_info, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, current_new_key)
(0, idx2specie, center_nodes, center_cons, species_info, o2c_distances, next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
@@ -369,10 +432,9 @@ def speciate(pop_nodes, pop_cons, species_info, center_nodes, center_cons, gener
species_member_counts = vmap(count_members)(jnp.arange(species_size))
species_info = species_info.at[:, 3].set(species_member_counts)
return idx2specie, center_nodes, center_cons, species_info
return idx2specie, center_nodes, center_cons, species_info, next_species_key
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)

View File

@@ -4,7 +4,8 @@ import configparser
import numpy as np
from algorithms.neat import act_name2func, agg_name2func
from algorithms.neat.genome.activations import act_name2func
from algorithms.neat.genome.aggregations import agg_name2func
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
jit_config_keys = [

View File

@@ -1,19 +1,18 @@
[basic]
num_inputs = 2
num_outputs = 1
init_maximum_nodes = 50
init_maximum_connections = 50
init_maximum_species = 10
expand_coe = 1.5
pre_expand_threshold = 0.75
maximum_nodes = 50
maximum_connections = 50
maximum_species = 10
forward_way = "pop"
batch_size = 4
random_seed = 0
[population]
fitness_threshold = 100000
fitness_threshold = 3.99999
generation_limit = 1000
fitness_criterion = "max"
pop_size = 50
pop_size = 100000
[genome]
compatibility_disjoint = 1.0

View File

@@ -34,8 +34,6 @@ def get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topologica
return evaluate(func)
def equal(ar1, ar2):
if ar1.shape != ar2.shape:
return False

View File

@@ -2,4 +2,4 @@
forward_way = "common"
[population]
fitness_threshold = 3.9999
fitness_threshold = 4

View File

@@ -2,7 +2,6 @@ import jax
import numpy as np
from configs import Configer
from algorithms.neat import Genome
from pipeline import Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
@@ -22,10 +21,10 @@ def evaluate(forward_func):
def main():
config = Configer.load_config("xor.ini")
pipeline = Pipeline(config, seed=6)
pipeline = Pipeline(config)
nodes, cons = pipeline.auto_run(evaluate)
g = Genome(nodes, cons, config)
print(g)
# g = Genome(nodes, cons, config)
# print(g)
if __name__ == '__main__':

47
examples/xor3d.ini Normal file
View File

@@ -0,0 +1,47 @@
[basic]
num_inputs = 3
num_outputs = 1
maximum_nodes = 50
maximum_connections = 50
maximum_species = 10
forward_way = "common"
batch_size = 4
random_seed = 42
[population]
fitness_threshold = 8
generation_limit = 1000
fitness_criterion = "max"
pop_size = 100000
[genome]
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.4
conn_add_trials = 1
conn_delete_prob = 0
node_add_prob = 0.2
node_delete_prob = 0
[species]
compatibility_threshold = 3
species_elitism = 1
max_stagnation = 15
genome_elitism = 2
survival_threshold = 0.2
min_species_size = 1
spawn_number_move_rate = 0.5
[gene-bias]
bias_init_mean = 0.0
bias_init_std = 1.0
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
[gene-weight]
weight_init_mean = 0.0
weight_init_std = 1.0
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1

31
examples/xor3d.py Normal file
View File

@@ -0,0 +1,31 @@
import jax
import numpy as np
from configs import Configer
from pipeline import Pipeline
xor_inputs = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0], [1], [0], [0], [1]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 8 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
def main():
config = Configer.load_config("xor3d.ini")
pipeline = Pipeline(config)
nodes, cons = pipeline.auto_run(evaluate)
# g = Genome(nodes, cons, config)
# print(g)
if __name__ == '__main__':
main()

View File

@@ -5,8 +5,8 @@ import numpy as np
import jax
from jax import jit, vmap
from configs import Configer
from algorithms import neat
from configs.configer import Configer
class Pipeline:
@@ -14,57 +14,39 @@ class Pipeline:
Neat algorithm pipeline.
"""
def __init__(self, config, seed=42):
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
def __init__(self, config):
self.config = config # global config
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
self.jit_config = Configer.create_jit_config(config)
self.P = config['pop_size']
self.N = config['init_maximum_nodes']
self.C = config['init_maximum_connections']
self.S = config['init_maximum_species']
self.generation = 0
self.best_genome = None
self.pop_nodes, self.pop_cons = neat.initialize_genomes(self.N, self.C, self.config)
self.species_info = np.full((self.S, 4), np.nan)
self.species_info[0, :] = 0, -np.inf, 0, self.P
self.idx2species = np.zeros(self.P, dtype=np.float32)
self.center_nodes = np.full((self.S, self.N, 5), np.nan)
self.center_cons = np.full((self.S, self.C, 4), np.nan)
self.center_nodes[0, :, :] = self.pop_nodes[0, :, :]
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
self.neat_states = neat.initialize(config)
self.best_fitness = float('-inf')
self.generation_timestamp = time.time()
self.evaluate_time = 0
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.initialize(config)
self.forward = neat.create_forward_function(config)
self.pop_unflatten_connections = jit(vmap(neat.unflatten_connections))
self.pop_topological_sort = jit(vmap(neat.topological_sort))
self.forward = neat.create_forward_function(config)
# fitness_lower = np.zeros(self.P, dtype=np.float32)
# randkey_lower = np.zeros(2, dtype=np.uint32)
# pop_nodes_lower = np.zeros((self.P, self.N, 5), dtype=np.float32)
# pop_cons_lower = np.zeros((self.P, self.C, 4), dtype=np.float32)
# species_info_lower = np.zeros((self.S, 4), dtype=np.float32)
# idx2species_lower = np.zeros(self.P, dtype=np.float32)
# center_nodes_lower = np.zeros((self.S, self.N, 5), dtype=np.float32)
# center_cons_lower = np.zeros((self.S, self.C, 4), dtype=np.float32)
#
# self.tell_func = jit(neat.tell).lower(fitness_lower,
# randkey_lower,
# pop_nodes_lower,
# pop_cons_lower,
# species_info_lower,
# idx2species_lower,
# center_nodes_lower,
# center_cons_lower,
# 0,
# self.tell_func = neat.tell.lower(np.zeros(config['pop_size'], dtype=np.float32),
# self.randkey,
# self.pop_nodes,
# self.pop_cons,
# self.species_info,
# self.idx2species,
# self.center_nodes,
# self.center_cons,
# self.generation,
# self.next_node_key,
# self.next_species_key,
# self.jit_config).compile()
def ask(self):
@@ -97,9 +79,19 @@ class Pipeline:
def tell(self, fitness):
self.randkey, self.pop_nodes, self.pop_cons, self.species_info, self.idx2species, self.center_nodes, \
self.center_cons, self.generation = neat.tell(fitness, self.randkey, self.pop_nodes, self.pop_cons,
self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config)
self.center_cons, self.generation, self.next_node_key, self.next_species_key = neat.tell(fitness,
self.randkey,
self.pop_nodes,
self.pop_cons,
self.species_info,
self.idx2species,
self.center_nodes,
self.center_cons,
self.generation,
self.next_node_key,
self.next_species_key,
self.jit_config)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']):
@@ -109,7 +101,7 @@ class Pipeline:
fitnesses = fitness_func(forward_func)
self.evaluate_time += time.time() - tic
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
# assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
if analysis is not None:
if analysis == "default":
@@ -138,7 +130,8 @@ class Pipeline:
self.best_fitness = fitnesses[max_idx]
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
species_sizes = [int(i) for i in self.species_info[:, 3] if i > 0]
member_count = jax.device_get(self.species_info[:, 3])
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.generation}",
f"species: {len(species_sizes)}, {species_sizes}",