modifying
This commit is contained in:
6
algorithms/neat/__init__.py
Normal file
6
algorithms/neat/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""
|
||||||
|
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||||
|
"""
|
||||||
|
from .genome import create_forward, topological_sort, unflatten_connections, initialize_genomes, expand, expand_single
|
||||||
|
from .operations import create_next_generation_then_speciate
|
||||||
|
from .species import SpeciesController
|
||||||
0
algorithms/neat/genome/debug/__init__.py
Normal file
0
algorithms/neat/genome/debug/__init__.py
Normal file
@@ -5,8 +5,11 @@ from jax import jit, vmap
|
|||||||
from .utils import I_INT
|
from .utils import I_INT
|
||||||
|
|
||||||
|
|
||||||
# TODO: enabled information doesn't influence forward. That is wrong!
|
|
||||||
def create_forward(config):
|
def create_forward(config):
|
||||||
|
"""
|
||||||
|
meta method to create forward function
|
||||||
|
"""
|
||||||
|
|
||||||
def act(idx, z):
|
def act(idx, z):
|
||||||
"""
|
"""
|
||||||
calculate activation function for each node
|
calculate activation function for each node
|
||||||
@@ -4,11 +4,11 @@ Only used in feed-forward networks.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import jit, vmap, Array
|
from jax import jit, Array
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
|
|
||||||
# from .configs import fetch_first, I_INT
|
# from .configs import fetch_first, I_INT
|
||||||
from neat.genome.utils import fetch_first, I_INT, unflatten_connections
|
from algorithms.neat.genome.utils import fetch_first, I_INT
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import jax
|
import jax
|
||||||
from jax import numpy as jnp, Array
|
from jax import numpy as jnp, Array
|
||||||
@@ -30,6 +32,7 @@ def unflatten_connections(nodes: Array, cons: Array):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def key_to_indices(key, keys):
|
def key_to_indices(key, keys):
|
||||||
return fetch_first(key == keys)
|
return fetch_first(key == keys)
|
||||||
|
|
||||||
@@ -56,4 +59,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
|||||||
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
|
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
|
||||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||||
return fetch_first(mask, default)
|
return fetch_first(mask, default)
|
||||||
|
@partial(jit, static_argnames=['reverse'])
|
||||||
|
def rank_elements(array, reverse=False):
|
||||||
|
"""
|
||||||
|
rank the element in the array.
|
||||||
|
if reverse is True, the rank is from large to small.
|
||||||
|
"""
|
||||||
|
if reverse:
|
||||||
|
array = -array
|
||||||
|
return jnp.argsort(jnp.argsort(array))
|
||||||
160
algorithms/neat/jit_species.py
Normal file
160
algorithms/neat/jit_species.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import jit, numpy as jnp, vmap
|
||||||
|
|
||||||
|
from .genome.utils import rank_elements
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
||||||
|
"""
|
||||||
|
args:
|
||||||
|
randkey: random key
|
||||||
|
fitness: Array[(pop_size,), float], the fitness of each individual
|
||||||
|
species_keys: Array[(species_size, 3), float], the information of each species
|
||||||
|
[species_key, best_score, last_update]
|
||||||
|
idx2species: Array[(pop_size,), int], map the individual to its species
|
||||||
|
center_nodes: Array[(species_size, N, 4), float], the center nodes of each species
|
||||||
|
center_cons: Array[(species_size, C, 4), float], the center connections of each species
|
||||||
|
generation: int, current generation
|
||||||
|
jit_config: Dict, the configuration of jit functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
# update the fitness of each species
|
||||||
|
species_fitness = update_species_fitness(species_info, idx2species, fitness)
|
||||||
|
|
||||||
|
# stagnation species
|
||||||
|
species_fitness, species_info, center_nodes, center_cons = \
|
||||||
|
stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config)
|
||||||
|
|
||||||
|
# sort species_info by their fitness. (push nan to the end)
|
||||||
|
sort_indices = jnp.argsort(species_fitness)[::-1]
|
||||||
|
species_info = species_info[sort_indices]
|
||||||
|
center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices]
|
||||||
|
|
||||||
|
# decide the number of members of each species by their fitness
|
||||||
|
spawn_number = cal_spawn_numbers(species_info, jit_config)
|
||||||
|
|
||||||
|
# crossover info
|
||||||
|
winner, loser, elite_mask = \
|
||||||
|
create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)
|
||||||
|
|
||||||
|
jax.debug.print("{}, {}", fitness, winner)
|
||||||
|
jax.debug.print("{}", fitness[winner])
|
||||||
|
|
||||||
|
return species_info, center_nodes, center_cons, winner, loser, elite_mask
|
||||||
|
|
||||||
|
|
||||||
|
def update_species_fitness(species_info, idx2species, fitness):
|
||||||
|
"""
|
||||||
|
obtain the fitness of the species by the fitness of each individual.
|
||||||
|
use max criterion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def aux_func(idx):
|
||||||
|
species_key = species_info[idx, 0]
|
||||||
|
s_fitness = jnp.where(idx2species == species_key, fitness, -jnp.inf)
|
||||||
|
f = jnp.max(s_fitness)
|
||||||
|
return f
|
||||||
|
|
||||||
|
return vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config):
|
||||||
|
"""
|
||||||
|
stagnation species.
|
||||||
|
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
|
||||||
|
elitism species never stagnation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def aux_func(idx):
|
||||||
|
s_fitness = species_fitness[idx]
|
||||||
|
species_key, best_score, last_update = species_info[idx]
|
||||||
|
# stagnation condition
|
||||||
|
return (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
|
||||||
|
|
||||||
|
st = vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
||||||
|
|
||||||
|
# elite species will not be stagnation
|
||||||
|
species_rank = rank_elements(species_fitness)
|
||||||
|
st = jnp.where(species_rank < jit_config['species_elitism'], False, st) # elitism never stagnation
|
||||||
|
|
||||||
|
# set stagnation species to nan
|
||||||
|
species_info = jnp.where(st[:, None], jnp.nan, species_info)
|
||||||
|
center_nodes = jnp.where(st[:, None, None], jnp.nan, center_nodes)
|
||||||
|
center_cons = jnp.where(st[:, None, None], jnp.nan, center_cons)
|
||||||
|
species_fitness = jnp.where(st, jnp.nan, species_fitness)
|
||||||
|
|
||||||
|
return species_fitness, species_info, center_nodes, center_cons
|
||||||
|
|
||||||
|
|
||||||
|
def cal_spawn_numbers(species_info, jit_config):
|
||||||
|
"""
|
||||||
|
decide the number of members of each species by their fitness rank.
|
||||||
|
the species with higher fitness will have more members
|
||||||
|
Linear ranking selection
|
||||||
|
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
|
||||||
|
"""
|
||||||
|
|
||||||
|
is_species_valid = ~jnp.isnan(species_info[:, 0])
|
||||||
|
valid_species_num = jnp.sum(is_species_valid)
|
||||||
|
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
|
||||||
|
|
||||||
|
rank_score = valid_species_num - jnp.arange(species_info.shape[0]) # obtain [3, 2, 1]
|
||||||
|
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||||
|
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||||
|
|
||||||
|
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
|
||||||
|
|
||||||
|
# must control the sum of spawn_number to be equal to pop_size
|
||||||
|
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
||||||
|
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
|
||||||
|
|
||||||
|
return spawn_number
|
||||||
|
|
||||||
|
|
||||||
|
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
|
||||||
|
|
||||||
|
species_size = species_info.shape[0]
|
||||||
|
pop_size = fitness.shape[0]
|
||||||
|
s_idx = jnp.arange(species_size)
|
||||||
|
p_idx = jnp.arange(pop_size)
|
||||||
|
|
||||||
|
def aux_func(key, idx):
|
||||||
|
members = idx2species == species_info[idx, 0]
|
||||||
|
members_num = jnp.sum(members)
|
||||||
|
|
||||||
|
members_fitness = jnp.where(members, fitness, jnp.nan)
|
||||||
|
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
|
||||||
|
|
||||||
|
elite_size = jit_config['genome_elitism']
|
||||||
|
survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32)
|
||||||
|
|
||||||
|
select_pro = (p_idx < survive_size) / survive_size
|
||||||
|
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)
|
||||||
|
|
||||||
|
# elite
|
||||||
|
fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
|
||||||
|
ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
|
||||||
|
elite = jnp.where(p_idx < elite_size, True, False)
|
||||||
|
return fa, ma, elite
|
||||||
|
|
||||||
|
fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)
|
||||||
|
|
||||||
|
spawn_number_cum = jnp.cumsum(spawn_number)
|
||||||
|
|
||||||
|
def aux_func(idx):
|
||||||
|
loc = jnp.argmax(idx < spawn_number_cum)
|
||||||
|
|
||||||
|
# elite genomes are at the beginning of the species
|
||||||
|
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
|
||||||
|
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
|
||||||
|
|
||||||
|
part1, part2, elite_mask = vmap(aux_func)(p_idx)
|
||||||
|
|
||||||
|
is_part1_win = fitness[part1] >= fitness[part2]
|
||||||
|
winner = jnp.where(is_part1_win, part1, part2)
|
||||||
|
loser = jnp.where(is_part1_win, part2, part1)
|
||||||
|
|
||||||
|
return winner, loser, elite_mask
|
||||||
@@ -1,13 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
contains operations on the population: creating the next generation and population speciation.
|
contains operations on the population: creating the next generation and population speciation.
|
||||||
"""
|
"""
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
from jax import jit, vmap, Array, numpy as jnp
|
||||||
from jax import jit, vmap
|
|
||||||
|
|
||||||
from jax import Array
|
|
||||||
|
|
||||||
from .genome import distance, mutate, crossover
|
from .genome import distance, mutate, crossover
|
||||||
from .genome.utils import I_INT, fetch_first
|
from .genome.utils import I_INT, fetch_first
|
||||||
@@ -8,7 +8,6 @@ See
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Tuple, Dict
|
from typing import List, Tuple, Dict
|
||||||
from itertools import count
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
@@ -4,8 +4,8 @@ import configparser
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from neat.genome.activations import act_name2func
|
from algorithms.neat.genome.activations import act_name2func
|
||||||
from neat.genome.aggregations import agg_name2func
|
from algorithms.neat.genome.aggregations import agg_name2func
|
||||||
|
|
||||||
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||||
jit_config_keys = [
|
jit_config_keys = [
|
||||||
@@ -41,6 +41,11 @@ jit_config_keys = [
|
|||||||
"weight_mutate_rate",
|
"weight_mutate_rate",
|
||||||
"weight_replace_rate",
|
"weight_replace_rate",
|
||||||
"enable_mutate_rate",
|
"enable_mutate_rate",
|
||||||
|
"max_stagnation",
|
||||||
|
"pop_size",
|
||||||
|
"genome_elitism",
|
||||||
|
"survival_threshold",
|
||||||
|
"species_elitism"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
[basic]
|
[basic]
|
||||||
num_inputs = 2
|
num_inputs = 2
|
||||||
num_outputs = 1
|
num_outputs = 1
|
||||||
init_maximum_nodes = 50
|
init_maximum_nodes = 200
|
||||||
init_maximum_connections = 50
|
init_maximum_connections = 200
|
||||||
init_maximum_species = 10
|
init_maximum_species = 10
|
||||||
expand_coe = 2.0
|
expand_coe = 1.5
|
||||||
|
pre_expand_threshold = 0.75
|
||||||
forward_way = "pop"
|
forward_way = "pop"
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
|
|
||||||
@@ -12,7 +13,7 @@ batch_size = 4
|
|||||||
fitness_threshold = 100000
|
fitness_threshold = 100000
|
||||||
generation_limit = 100
|
generation_limit = 100
|
||||||
fitness_criterion = "max"
|
fitness_criterion = "max"
|
||||||
pop_size = 15000
|
pop_size = 150
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import jax
|
|
||||||
|
|
||||||
a = {1:2, 2:3, 4:5}
|
|
||||||
print(a.values())
|
|
||||||
|
|
||||||
a = jnp.array([1, 0, 1, 0, np.nan])
|
|
||||||
b = jnp.array([1, 1, 1, 1, 1])
|
|
||||||
c = jnp.array([1, 1, 1, 1, 1])
|
|
||||||
|
|
||||||
full = jnp.array([
|
|
||||||
[1, 1, 1],
|
|
||||||
[0, 1, 1],
|
|
||||||
[1, 1, 1],
|
|
||||||
[0, 1, 1],
|
|
||||||
])
|
|
||||||
|
|
||||||
print(jnp.column_stack([a[:, None], b[:, None], c[:, None]]))
|
|
||||||
|
|
||||||
aux0 = full[:, 0, None]
|
|
||||||
aux1 = full[:, 1, None]
|
|
||||||
|
|
||||||
print(aux0, aux0.shape)
|
|
||||||
|
|
||||||
print(jnp.concatenate([aux0, aux1], axis=1))
|
|
||||||
|
|
||||||
f_a = jnp.array([False, False, True, True])
|
|
||||||
f_b = jnp.array([True, False, False, False])
|
|
||||||
|
|
||||||
print(jnp.logical_and(f_a, f_b))
|
|
||||||
print(f_a & f_b)
|
|
||||||
|
|
||||||
print(f_a + jnp.nan * 0.0)
|
|
||||||
print(f_a + 1 * 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
|
||||||
def main():
|
|
||||||
return func('happy') + func('sad')
|
|
||||||
|
|
||||||
|
|
||||||
def func(x):
|
|
||||||
if x == 'happy':
|
|
||||||
return 1
|
|
||||||
else:
|
|
||||||
return 2
|
|
||||||
|
|
||||||
a = jnp.zeros((3, 3))
|
|
||||||
print(a.dtype)
|
|
||||||
|
|
||||||
c = None
|
|
||||||
b = 1 or c
|
|
||||||
print(b)
|
|
||||||
26
examples/evox_test.py
Normal file
26
examples/evox_test.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import jax
|
||||||
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
from evox import algorithms, problems, pipelines
|
||||||
|
from evox.monitors import StdSOMonitor
|
||||||
|
|
||||||
|
monitor = StdSOMonitor()
|
||||||
|
|
||||||
|
pso = algorithms.PSO(
|
||||||
|
lb=jnp.full(shape=(2,), fill_value=-32),
|
||||||
|
ub=jnp.full(shape=(2,), fill_value=32),
|
||||||
|
pop_size=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
ackley = problems.classic.Ackley()
|
||||||
|
|
||||||
|
pipeline = pipelines.StdPipeline(pso, ackley, fitness_transform=monitor.record_fit)
|
||||||
|
|
||||||
|
key = jax.random.PRNGKey(42)
|
||||||
|
state = pipeline.init(key)
|
||||||
|
|
||||||
|
# run the pipeline for 100 steps
|
||||||
|
for i in range(100):
|
||||||
|
state = pipeline.step(state)
|
||||||
|
|
||||||
|
print(monitor.get_min_fitness())
|
||||||
@@ -1,27 +1,18 @@
|
|||||||
import numpy as np
|
from functools import partial
|
||||||
from jax import jit
|
|
||||||
|
|
||||||
from configs import Configer
|
import jax
|
||||||
from neat.pipeline import Pipeline
|
from jax import numpy as jnp, jit
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
@partial(jit, static_argnames=['reverse'])
|
||||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
def rank_element(array, reverse=False):
|
||||||
|
"""
|
||||||
def main():
|
rank the element in the array.
|
||||||
config = Configer.load_config("xor.ini")
|
if reverse is True, the rank is from large to small.
|
||||||
print(config)
|
"""
|
||||||
pipeline = Pipeline(config)
|
if reverse:
|
||||||
forward_func = pipeline.ask()
|
array = -array
|
||||||
# inputs = np.tile(xor_inputs, (150, 1, 1))
|
return jnp.argsort(jnp.argsort(array))
|
||||||
outputs = forward_func(xor_inputs)
|
|
||||||
print(outputs)
|
|
||||||
|
|
||||||
|
|
||||||
|
a = jnp.array([1 ,5, 3, 5, 2, 1, 0])
|
||||||
@jit
|
print(rank_element(a, reverse=True))
|
||||||
def f(x, jit_config):
|
|
||||||
return x + jit_config["bias_mutate_rate"]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
28
examples/jit_xor.py
Normal file
28
examples/jit_xor.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from configs import Configer
|
||||||
|
from jit_pipeline import Pipeline
|
||||||
|
|
||||||
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
|
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(forward_func):
|
||||||
|
"""
|
||||||
|
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
outs = forward_func(xor_inputs)
|
||||||
|
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||||
|
return np.array(fitnesses) # returns a list
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = Configer.load_config("xor.ini")
|
||||||
|
pipeline = Pipeline(config, seed=6)
|
||||||
|
nodes, cons = pipeline.auto_run(evaluate)
|
||||||
|
print(nodes, cons)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from configs import Configer
|
from configs import Configer
|
||||||
from neat.pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||||
@@ -21,6 +21,8 @@ def main():
|
|||||||
config = Configer.load_config("xor.ini")
|
config = Configer.load_config("xor.ini")
|
||||||
pipeline = Pipeline(config, seed=6)
|
pipeline = Pipeline(config, seed=6)
|
||||||
nodes, cons = pipeline.auto_run(evaluate)
|
nodes, cons = pipeline.auto_run(evaluate)
|
||||||
|
print(nodes, cons)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from jax import jit, vmap
|
from jax import jit, vmap
|
||||||
|
|
||||||
from .genome import create_forward, topological_sort, unflatten_connections
|
from algorithms.neat import create_forward, topological_sort, \
|
||||||
from .operations import create_next_generation_then_speciate
|
unflatten_connections, create_next_generation_then_speciate
|
||||||
|
|
||||||
|
|
||||||
def hash_symbols(symbols):
|
def hash_symbols(symbols):
|
||||||
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
|
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
|
||||||
@@ -32,7 +33,6 @@ class FunctionFactory:
|
|||||||
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||||
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
||||||
|
|
||||||
|
|
||||||
self.function_info = {
|
self.function_info = {
|
||||||
"pop_unflatten_connections": {
|
"pop_unflatten_connections": {
|
||||||
'func': vmap(unflatten_connections),
|
'func': vmap(unflatten_connections),
|
||||||
@@ -54,7 +54,7 @@ class FunctionFactory:
|
|||||||
'func': batch_forward,
|
'func': batch_forward,
|
||||||
'lowers': [
|
'lowers': [
|
||||||
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
|
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||||
{'shape': ('N', ), 'type': np.int32},
|
{'shape': ('N',), 'type': np.int32},
|
||||||
{'shape': ('N', 5), 'type': np.float32},
|
{'shape': ('N', 5), 'type': np.float32},
|
||||||
{'shape': (2, 'N', 'N'), 'type': np.float32}
|
{'shape': (2, 'N', 'N'), 'type': np.float32}
|
||||||
]
|
]
|
||||||
@@ -83,23 +83,22 @@ class FunctionFactory:
|
|||||||
'create_next_generation_then_speciate': {
|
'create_next_generation_then_speciate': {
|
||||||
'func': create_next_generation_then_speciate,
|
'func': create_next_generation_then_speciate,
|
||||||
'lowers': [
|
'lowers': [
|
||||||
{'shape': (2, ), 'type': np.uint32}, # rand_key
|
{'shape': (2,), 'type': np.uint32}, # rand_key
|
||||||
{'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes
|
{'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes
|
||||||
{'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons
|
{'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons
|
||||||
{'shape': ('P', ), 'type': np.int32}, # winner
|
{'shape': ('P',), 'type': np.int32}, # winner
|
||||||
{'shape': ('P', ), 'type': np.int32}, # loser
|
{'shape': ('P',), 'type': np.int32}, # loser
|
||||||
{'shape': ('P', ), 'type': bool}, # elite_mask
|
{'shape': ('P',), 'type': bool}, # elite_mask
|
||||||
{'shape': ('P',), 'type': np.int32}, # new_node_keys
|
{'shape': ('P',), 'type': np.int32}, # new_node_keys
|
||||||
{'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes
|
{'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes
|
||||||
{'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons
|
{'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons
|
||||||
{'shape': ('S', ), 'type': np.int32}, # species_keys
|
{'shape': ('S',), 'type': np.int32}, # species_keys
|
||||||
{'shape': (), 'type': np.int32}, # new_species_key_start
|
{'shape': (), 'type': np.int32}, # new_species_key_start
|
||||||
"jit_config"
|
"jit_config"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get(self, name, symbols):
|
def get(self, name, symbols):
|
||||||
if (name, hash_symbols(symbols)) not in self.func_dict:
|
if (name, hash_symbols(symbols)) not in self.func_dict:
|
||||||
self.compile(name, symbols)
|
self.compile(name, symbols)
|
||||||
159
jit_pipeline.py
Normal file
159
jit_pipeline.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
import time
|
||||||
|
from typing import Union, Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
|
||||||
|
from configs import Configer
|
||||||
|
from function_factory import FunctionFactory
|
||||||
|
from algorithms.neat import initialize_genomes, expand, expand_single
|
||||||
|
|
||||||
|
from algorithms.neat.jit_species import update_species
|
||||||
|
from algorithms.neat.operations import create_next_generation_then_speciate
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
|
||||||
|
"""
|
||||||
|
Neat algorithm pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config, function_factory=None, seed=42):
|
||||||
|
self.randkey = jax.random.PRNGKey(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
self.config = config # global config
|
||||||
|
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||||
|
self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config)
|
||||||
|
|
||||||
|
self.symbols = {
|
||||||
|
'P': self.config['pop_size'],
|
||||||
|
'N': self.config['init_maximum_nodes'],
|
||||||
|
'C': self.config['init_maximum_connections'],
|
||||||
|
'S': self.config['init_maximum_species'],
|
||||||
|
}
|
||||||
|
|
||||||
|
self.generation = 0
|
||||||
|
self.best_genome = None
|
||||||
|
|
||||||
|
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
|
||||||
|
self.species_info = np.full((self.symbols['S'], 3), np.nan)
|
||||||
|
self.species_info[0, :] = 0, -np.inf, 0
|
||||||
|
self.idx2species = np.zeros(self.symbols['P'], dtype=np.int32)
|
||||||
|
self.center_nodes = np.full((self.symbols['S'], self.symbols['N'], 5), np.nan)
|
||||||
|
self.center_cons = np.full((self.symbols['S'], self.symbols['C'], 4), np.nan)
|
||||||
|
self.center_nodes[0, :, :] = self.pop_nodes[0, :, :]
|
||||||
|
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
|
||||||
|
|
||||||
|
self.best_fitness = float('-inf')
|
||||||
|
self.best_genome = None
|
||||||
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
|
self.evaluate_time = 0
|
||||||
|
print(self.config)
|
||||||
|
|
||||||
|
def ask(self):
|
||||||
|
"""
|
||||||
|
Creates a function that receives a genome and returns a forward function.
|
||||||
|
There are 3 types of config['forward_way']: {'single', 'pop', 'common'}
|
||||||
|
|
||||||
|
single:
|
||||||
|
Create pop_size number of forward functions.
|
||||||
|
Each function receive (batch_size, input_size) and returns (batch_size, output_size)
|
||||||
|
e.g. RL task
|
||||||
|
|
||||||
|
pop:
|
||||||
|
Create a single forward function, which use only once calculation for the population.
|
||||||
|
The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
||||||
|
|
||||||
|
common:
|
||||||
|
Special case of pop. The population has the same inputs.
|
||||||
|
The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
||||||
|
e.g. numerical regression; Hyper-NEAT
|
||||||
|
|
||||||
|
"""
|
||||||
|
u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons)
|
||||||
|
pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons)
|
||||||
|
|
||||||
|
if self.config['forward_way'] == 'single':
|
||||||
|
forward_funcs = []
|
||||||
|
for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons):
|
||||||
|
func = lambda x: self.get_func('forward')(x, seq, nodes, cons)
|
||||||
|
forward_funcs.append(func)
|
||||||
|
return forward_funcs
|
||||||
|
|
||||||
|
elif self.config['forward_way'] == 'pop':
|
||||||
|
func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||||
|
return func
|
||||||
|
|
||||||
|
elif self.config['forward_way'] == 'common':
|
||||||
|
func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||||
|
return func
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def tell(self, fitnesses):
|
||||||
|
self.generation += 1
|
||||||
|
|
||||||
|
species_info, center_nodes, center_cons, winner, loser, elite_mask = \
|
||||||
|
update_species(self.randkey, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
||||||
|
self.center_cons, self.generation, self.jit_config)
|
||||||
|
|
||||||
|
# node keys to be used in the mutation process
|
||||||
|
new_node_keys = np.arange(self.generation * self.config['pop_size'],
|
||||||
|
self.generation * self.config['pop_size'] + self.config['pop_size'])
|
||||||
|
|
||||||
|
# create the next generation and then speciate the population
|
||||||
|
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||||
|
create_next_generation_then_speciate(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes,
|
||||||
|
center_cons, species_keys, species_key_start, self.jit_config)
|
||||||
|
|
||||||
|
# carry data to cpu
|
||||||
|
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||||
|
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys])
|
||||||
|
|
||||||
|
# update randkey
|
||||||
|
self.randkey = jax.random.split(self.randkey)[0]
|
||||||
|
|
||||||
|
def get_func(self, name):
|
||||||
|
return self.function_factory.get(name, self.symbols)
|
||||||
|
|
||||||
|
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||||
|
for _ in range(self.config['generation_limit']):
|
||||||
|
forward_func = self.ask()
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
fitnesses = fitness_func(forward_func)
|
||||||
|
self.evaluate_time += time.time() - tic
|
||||||
|
|
||||||
|
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||||
|
|
||||||
|
if analysis is not None:
|
||||||
|
if analysis == "default":
|
||||||
|
self.default_analysis(fitnesses)
|
||||||
|
else:
|
||||||
|
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
||||||
|
analysis(fitnesses)
|
||||||
|
|
||||||
|
if max(fitnesses) >= self.config['fitness_threshold']:
|
||||||
|
print("Fitness limit reached!")
|
||||||
|
return self.best_genome
|
||||||
|
|
||||||
|
self.tell(fitnesses)
|
||||||
|
print("Generation limit reached!")
|
||||||
|
return self.best_genome
|
||||||
|
|
||||||
|
def default_analysis(self, fitnesses):
|
||||||
|
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||||
|
|
||||||
|
new_timestamp = time.time()
|
||||||
|
cost_time = new_timestamp - self.generation_timestamp
|
||||||
|
self.generation_timestamp = new_timestamp
|
||||||
|
|
||||||
|
max_idx = np.argmax(fitnesses)
|
||||||
|
if fitnesses[max_idx] > self.best_fitness:
|
||||||
|
self.best_fitness = fitnesses[max_idx]
|
||||||
|
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||||
|
|
||||||
|
print(f"Generation: {self.generation}",
|
||||||
|
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
"""
|
|
||||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
|
||||||
"""
|
|
||||||
@@ -5,9 +5,8 @@ import numpy as np
|
|||||||
import jax
|
import jax
|
||||||
|
|
||||||
from configs import Configer
|
from configs import Configer
|
||||||
from .genome import initialize_genomes, expand, expand_single
|
from function_factory import FunctionFactory
|
||||||
from .function_factory import FunctionFactory
|
from algorithms.neat import initialize_genomes, expand, expand_single, SpeciesController
|
||||||
from .species import SpeciesController
|
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
@@ -119,25 +118,27 @@ class Pipeline:
|
|||||||
when the maximum node number >= N or the maximum connection number of >= C
|
when the maximum node number >= N or the maximum connection number of >= C
|
||||||
the population will expand
|
the population will expand
|
||||||
"""
|
"""
|
||||||
changed = False
|
|
||||||
|
|
||||||
|
# analysis nodes
|
||||||
pop_node_keys = self.pop_nodes[:, :, 0]
|
pop_node_keys = self.pop_nodes[:, :, 0]
|
||||||
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
|
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
|
||||||
max_node_size = np.max(pop_node_sizes)
|
max_node_size = np.max(pop_node_sizes)
|
||||||
if max_node_size >= self.symbols['N']:
|
|
||||||
self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
|
|
||||||
print(f"node expand to {self.symbols['N']}!")
|
|
||||||
changed = True
|
|
||||||
|
|
||||||
|
# analysis connections
|
||||||
pop_con_keys = self.pop_cons[:, :, 0]
|
pop_con_keys = self.pop_cons[:, :, 0]
|
||||||
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||||
max_con_size = np.max(pop_node_sizes)
|
max_con_size = np.max(pop_node_sizes)
|
||||||
if max_con_size >= self.symbols['C']:
|
|
||||||
self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
|
|
||||||
print(f"connection expand to {self.symbols['C']}!")
|
|
||||||
changed = True
|
|
||||||
|
|
||||||
if changed:
|
# expand if needed
|
||||||
|
if max_node_size >= self.symbols['N'] or max_con_size >= self.symbols['C']:
|
||||||
|
if max_node_size > self.symbols['N'] * self.config['pre_expand_threshold']:
|
||||||
|
self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
|
||||||
|
print(f"pre node expand to {self.symbols['N']}!")
|
||||||
|
|
||||||
|
if max_con_size > self.symbols['C'] * self.config['pre_expand_threshold']:
|
||||||
|
self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
|
||||||
|
print(f"pre connection expand to {self.symbols['C']}!")
|
||||||
|
|
||||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C'])
|
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C'])
|
||||||
# don't forget to expand representation genome in species
|
# don't forget to expand representation genome in species
|
||||||
for s in self.species_controller.species.values():
|
for s in self.species_controller.species.values():
|
||||||
@@ -160,7 +161,7 @@ class Pipeline:
|
|||||||
if analysis == "default":
|
if analysis == "default":
|
||||||
self.default_analysis(fitnesses)
|
self.default_analysis(fitnesses)
|
||||||
else:
|
else:
|
||||||
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
assert callable(analysis), f"Callable is needed here😅😅😅 A {analysis}?"
|
||||||
analysis(fitnesses)
|
analysis(fitnesses)
|
||||||
|
|
||||||
if max(fitnesses) >= self.config['fitness_threshold']:
|
if max(fitnesses) >= self.config['fitness_threshold']:
|
||||||
Reference in New Issue
Block a user