modifying
This commit is contained in:
6
algorithms/neat/__init__.py
Normal file
6
algorithms/neat/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||
"""
|
||||
from .genome import create_forward, topological_sort, unflatten_connections, initialize_genomes, expand, expand_single
|
||||
from .operations import create_next_generation_then_speciate
|
||||
from .species import SpeciesController
|
||||
0
algorithms/neat/genome/debug/__init__.py
Normal file
0
algorithms/neat/genome/debug/__init__.py
Normal file
@@ -5,8 +5,11 @@ from jax import jit, vmap
|
||||
from .utils import I_INT
|
||||
|
||||
|
||||
# TODO: enabled information doesn't influence forward. That is wrong!
|
||||
def create_forward(config):
|
||||
"""
|
||||
meta method to create forward function
|
||||
"""
|
||||
|
||||
def act(idx, z):
|
||||
"""
|
||||
calculate activation function for each node
|
||||
@@ -4,11 +4,11 @@ Only used in feed-forward networks.
|
||||
"""
|
||||
|
||||
import jax
|
||||
from jax import jit, vmap, Array
|
||||
from jax import jit, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
# from .configs import fetch_first, I_INT
|
||||
from neat.genome.utils import fetch_first, I_INT, unflatten_connections
|
||||
from algorithms.neat.genome.utils import fetch_first, I_INT
|
||||
|
||||
|
||||
@jit
|
||||
@@ -1,3 +1,5 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array
|
||||
@@ -30,6 +32,7 @@ def unflatten_connections(nodes: Array, cons: Array):
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def key_to_indices(key, keys):
|
||||
return fetch_first(key == keys)
|
||||
|
||||
@@ -56,4 +59,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
@partial(jit, static_argnames=['reverse'])
|
||||
def rank_elements(array, reverse=False):
|
||||
"""
|
||||
rank the element in the array.
|
||||
if reverse is True, the rank is from large to small.
|
||||
"""
|
||||
if reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
160
algorithms/neat/jit_species.py
Normal file
160
algorithms/neat/jit_species.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import jit, numpy as jnp, vmap
|
||||
|
||||
from .genome.utils import rank_elements
|
||||
|
||||
|
||||
@jit
|
||||
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
args:
|
||||
randkey: random key
|
||||
fitness: Array[(pop_size,), float], the fitness of each individual
|
||||
species_keys: Array[(species_size, 3), float], the information of each species
|
||||
[species_key, best_score, last_update]
|
||||
idx2species: Array[(pop_size,), int], map the individual to its species
|
||||
center_nodes: Array[(species_size, N, 4), float], the center nodes of each species
|
||||
center_cons: Array[(species_size, C, 4), float], the center connections of each species
|
||||
generation: int, current generation
|
||||
jit_config: Dict, the configuration of jit functions
|
||||
"""
|
||||
|
||||
# update the fitness of each species
|
||||
species_fitness = update_species_fitness(species_info, idx2species, fitness)
|
||||
|
||||
# stagnation species
|
||||
species_fitness, species_info, center_nodes, center_cons = \
|
||||
stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config)
|
||||
|
||||
# sort species_info by their fitness. (push nan to the end)
|
||||
sort_indices = jnp.argsort(species_fitness)[::-1]
|
||||
species_info = species_info[sort_indices]
|
||||
center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices]
|
||||
|
||||
# decide the number of members of each species by their fitness
|
||||
spawn_number = cal_spawn_numbers(species_info, jit_config)
|
||||
|
||||
# crossover info
|
||||
winner, loser, elite_mask = \
|
||||
create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)
|
||||
|
||||
jax.debug.print("{}, {}", fitness, winner)
|
||||
jax.debug.print("{}", fitness[winner])
|
||||
|
||||
return species_info, center_nodes, center_cons, winner, loser, elite_mask
|
||||
|
||||
|
||||
def update_species_fitness(species_info, idx2species, fitness):
|
||||
"""
|
||||
obtain the fitness of the species by the fitness of each individual.
|
||||
use max criterion.
|
||||
"""
|
||||
|
||||
def aux_func(idx):
|
||||
species_key = species_info[idx, 0]
|
||||
s_fitness = jnp.where(idx2species == species_key, fitness, -jnp.inf)
|
||||
f = jnp.max(s_fitness)
|
||||
return f
|
||||
|
||||
return vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
||||
|
||||
|
||||
def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config):
|
||||
"""
|
||||
stagnation species.
|
||||
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
|
||||
elitism species never stagnation
|
||||
"""
|
||||
|
||||
def aux_func(idx):
|
||||
s_fitness = species_fitness[idx]
|
||||
species_key, best_score, last_update = species_info[idx]
|
||||
# stagnation condition
|
||||
return (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
|
||||
|
||||
st = vmap(aux_func)(jnp.arange(species_info.shape[0]))
|
||||
|
||||
# elite species will not be stagnation
|
||||
species_rank = rank_elements(species_fitness)
|
||||
st = jnp.where(species_rank < jit_config['species_elitism'], False, st) # elitism never stagnation
|
||||
|
||||
# set stagnation species to nan
|
||||
species_info = jnp.where(st[:, None], jnp.nan, species_info)
|
||||
center_nodes = jnp.where(st[:, None, None], jnp.nan, center_nodes)
|
||||
center_cons = jnp.where(st[:, None, None], jnp.nan, center_cons)
|
||||
species_fitness = jnp.where(st, jnp.nan, species_fitness)
|
||||
|
||||
return species_fitness, species_info, center_nodes, center_cons
|
||||
|
||||
|
||||
def cal_spawn_numbers(species_info, jit_config):
|
||||
"""
|
||||
decide the number of members of each species by their fitness rank.
|
||||
the species with higher fitness will have more members
|
||||
Linear ranking selection
|
||||
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
|
||||
"""
|
||||
|
||||
is_species_valid = ~jnp.isnan(species_info[:, 0])
|
||||
valid_species_num = jnp.sum(is_species_valid)
|
||||
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
|
||||
|
||||
rank_score = valid_species_num - jnp.arange(species_info.shape[0]) # obtain [3, 2, 1]
|
||||
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
|
||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||
|
||||
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = jit_config['pop_size'] - jnp.sum(spawn_number)
|
||||
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
|
||||
|
||||
return spawn_number
|
||||
|
||||
|
||||
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
|
||||
|
||||
species_size = species_info.shape[0]
|
||||
pop_size = fitness.shape[0]
|
||||
s_idx = jnp.arange(species_size)
|
||||
p_idx = jnp.arange(pop_size)
|
||||
|
||||
def aux_func(key, idx):
|
||||
members = idx2species == species_info[idx, 0]
|
||||
members_num = jnp.sum(members)
|
||||
|
||||
members_fitness = jnp.where(members, fitness, jnp.nan)
|
||||
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
|
||||
|
||||
elite_size = jit_config['genome_elitism']
|
||||
survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32)
|
||||
|
||||
select_pro = (p_idx < survive_size) / survive_size
|
||||
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)
|
||||
|
||||
# elite
|
||||
fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
|
||||
ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
|
||||
elite = jnp.where(p_idx < elite_size, True, False)
|
||||
return fa, ma, elite
|
||||
|
||||
fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)
|
||||
|
||||
spawn_number_cum = jnp.cumsum(spawn_number)
|
||||
|
||||
def aux_func(idx):
|
||||
loc = jnp.argmax(idx < spawn_number_cum)
|
||||
|
||||
# elite genomes are at the beginning of the species
|
||||
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
|
||||
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
|
||||
|
||||
part1, part2, elite_mask = vmap(aux_func)(p_idx)
|
||||
|
||||
is_part1_win = fitness[part1] >= fitness[part2]
|
||||
winner = jnp.where(is_part1_win, part1, part2)
|
||||
loser = jnp.where(is_part1_win, part2, part1)
|
||||
|
||||
return winner, loser, elite_mask
|
||||
@@ -1,13 +1,8 @@
|
||||
"""
|
||||
contains operations on the population: creating the next generation and population speciation.
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit, vmap
|
||||
|
||||
from jax import Array
|
||||
from jax import jit, vmap, Array, numpy as jnp
|
||||
|
||||
from .genome import distance, mutate, crossover
|
||||
from .genome.utils import I_INT, fetch_first
|
||||
@@ -8,7 +8,6 @@ See
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Dict
|
||||
from itertools import count
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
@@ -4,8 +4,8 @@ import configparser
|
||||
|
||||
import numpy as np
|
||||
|
||||
from neat.genome.activations import act_name2func
|
||||
from neat.genome.aggregations import agg_name2func
|
||||
from algorithms.neat.genome.activations import act_name2func
|
||||
from algorithms.neat.genome.aggregations import agg_name2func
|
||||
|
||||
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||
jit_config_keys = [
|
||||
@@ -41,6 +41,11 @@ jit_config_keys = [
|
||||
"weight_mutate_rate",
|
||||
"weight_replace_rate",
|
||||
"enable_mutate_rate",
|
||||
"max_stagnation",
|
||||
"pop_size",
|
||||
"genome_elitism",
|
||||
"survival_threshold",
|
||||
"species_elitism"
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
[basic]
|
||||
num_inputs = 2
|
||||
num_outputs = 1
|
||||
init_maximum_nodes = 50
|
||||
init_maximum_connections = 50
|
||||
init_maximum_nodes = 200
|
||||
init_maximum_connections = 200
|
||||
init_maximum_species = 10
|
||||
expand_coe = 2.0
|
||||
expand_coe = 1.5
|
||||
pre_expand_threshold = 0.75
|
||||
forward_way = "pop"
|
||||
batch_size = 4
|
||||
|
||||
@@ -12,7 +13,7 @@ batch_size = 4
|
||||
fitness_threshold = 100000
|
||||
generation_limit = 100
|
||||
fitness_criterion = "max"
|
||||
pop_size = 15000
|
||||
pop_size = 150
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
|
||||
a = {1:2, 2:3, 4:5}
|
||||
print(a.values())
|
||||
|
||||
a = jnp.array([1, 0, 1, 0, np.nan])
|
||||
b = jnp.array([1, 1, 1, 1, 1])
|
||||
c = jnp.array([1, 1, 1, 1, 1])
|
||||
|
||||
full = jnp.array([
|
||||
[1, 1, 1],
|
||||
[0, 1, 1],
|
||||
[1, 1, 1],
|
||||
[0, 1, 1],
|
||||
])
|
||||
|
||||
print(jnp.column_stack([a[:, None], b[:, None], c[:, None]]))
|
||||
|
||||
aux0 = full[:, 0, None]
|
||||
aux1 = full[:, 1, None]
|
||||
|
||||
print(aux0, aux0.shape)
|
||||
|
||||
print(jnp.concatenate([aux0, aux1], axis=1))
|
||||
|
||||
f_a = jnp.array([False, False, True, True])
|
||||
f_b = jnp.array([True, False, False, False])
|
||||
|
||||
print(jnp.logical_and(f_a, f_b))
|
||||
print(f_a & f_b)
|
||||
|
||||
print(f_a + jnp.nan * 0.0)
|
||||
print(f_a + 1 * 0.0)
|
||||
|
||||
|
||||
@jax.jit
|
||||
def main():
|
||||
return func('happy') + func('sad')
|
||||
|
||||
|
||||
def func(x):
|
||||
if x == 'happy':
|
||||
return 1
|
||||
else:
|
||||
return 2
|
||||
|
||||
a = jnp.zeros((3, 3))
|
||||
print(a.dtype)
|
||||
|
||||
c = None
|
||||
b = 1 or c
|
||||
print(b)
|
||||
26
examples/evox_test.py
Normal file
26
examples/evox_test.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
|
||||
from evox import algorithms, problems, pipelines
|
||||
from evox.monitors import StdSOMonitor
|
||||
|
||||
monitor = StdSOMonitor()
|
||||
|
||||
pso = algorithms.PSO(
|
||||
lb=jnp.full(shape=(2,), fill_value=-32),
|
||||
ub=jnp.full(shape=(2,), fill_value=32),
|
||||
pop_size=100,
|
||||
)
|
||||
|
||||
ackley = problems.classic.Ackley()
|
||||
|
||||
pipeline = pipelines.StdPipeline(pso, ackley, fitness_transform=monitor.record_fit)
|
||||
|
||||
key = jax.random.PRNGKey(42)
|
||||
state = pipeline.init(key)
|
||||
|
||||
# run the pipeline for 100 steps
|
||||
for i in range(100):
|
||||
state = pipeline.step(state)
|
||||
|
||||
print(monitor.get_min_fitness())
|
||||
@@ -1,27 +1,18 @@
|
||||
import numpy as np
|
||||
from jax import jit
|
||||
from functools import partial
|
||||
|
||||
from configs import Configer
|
||||
from neat.pipeline import Pipeline
|
||||
import jax
|
||||
from jax import numpy as jnp, jit
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
print(config)
|
||||
pipeline = Pipeline(config)
|
||||
forward_func = pipeline.ask()
|
||||
# inputs = np.tile(xor_inputs, (150, 1, 1))
|
||||
outputs = forward_func(xor_inputs)
|
||||
print(outputs)
|
||||
@partial(jit, static_argnames=['reverse'])
|
||||
def rank_element(array, reverse=False):
|
||||
"""
|
||||
rank the element in the array.
|
||||
if reverse is True, the rank is from large to small.
|
||||
"""
|
||||
if reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
|
||||
|
||||
|
||||
@jit
|
||||
def f(x, jit_config):
|
||||
return x + jit_config["bias_mutate_rate"]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
a = jnp.array([1 ,5, 3, 5, 2, 1, 0])
|
||||
print(rank_element(a, reverse=True))
|
||||
28
examples/jit_xor.py
Normal file
28
examples/jit_xor.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
from jit_pipeline import Pipeline
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return np.array(fitnesses) # returns a list
|
||||
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
pipeline = Pipeline(config, seed=6)
|
||||
nodes, cons = pipeline.auto_run(evaluate)
|
||||
print(nodes, cons)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
from neat.pipeline import Pipeline
|
||||
from pipeline import Pipeline
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
@@ -21,6 +21,8 @@ def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
pipeline = Pipeline(config, seed=6)
|
||||
nodes, cons = pipeline.auto_run(evaluate)
|
||||
print(nodes, cons)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import numpy as np
|
||||
from jax import jit, vmap
|
||||
|
||||
from .genome import create_forward, topological_sort, unflatten_connections
|
||||
from .operations import create_next_generation_then_speciate
|
||||
from algorithms.neat import create_forward, topological_sort, \
|
||||
unflatten_connections, create_next_generation_then_speciate
|
||||
|
||||
|
||||
def hash_symbols(symbols):
|
||||
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
|
||||
@@ -32,7 +33,6 @@ class FunctionFactory:
|
||||
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
||||
|
||||
|
||||
self.function_info = {
|
||||
"pop_unflatten_connections": {
|
||||
'func': vmap(unflatten_connections),
|
||||
@@ -54,7 +54,7 @@ class FunctionFactory:
|
||||
'func': batch_forward,
|
||||
'lowers': [
|
||||
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||
{'shape': ('N', ), 'type': np.int32},
|
||||
{'shape': ('N',), 'type': np.int32},
|
||||
{'shape': ('N', 5), 'type': np.float32},
|
||||
{'shape': (2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
@@ -83,23 +83,22 @@ class FunctionFactory:
|
||||
'create_next_generation_then_speciate': {
|
||||
'func': create_next_generation_then_speciate,
|
||||
'lowers': [
|
||||
{'shape': (2, ), 'type': np.uint32}, # rand_key
|
||||
{'shape': (2,), 'type': np.uint32}, # rand_key
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes
|
||||
{'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons
|
||||
{'shape': ('P', ), 'type': np.int32}, # winner
|
||||
{'shape': ('P', ), 'type': np.int32}, # loser
|
||||
{'shape': ('P', ), 'type': bool}, # elite_mask
|
||||
{'shape': ('P',), 'type': np.int32}, # winner
|
||||
{'shape': ('P',), 'type': np.int32}, # loser
|
||||
{'shape': ('P',), 'type': bool}, # elite_mask
|
||||
{'shape': ('P',), 'type': np.int32}, # new_node_keys
|
||||
{'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes
|
||||
{'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons
|
||||
{'shape': ('S', ), 'type': np.int32}, # species_keys
|
||||
{'shape': ('S',), 'type': np.int32}, # species_keys
|
||||
{'shape': (), 'type': np.int32}, # new_species_key_start
|
||||
"jit_config"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get(self, name, symbols):
|
||||
if (name, hash_symbols(symbols)) not in self.func_dict:
|
||||
self.compile(name, symbols)
|
||||
159
jit_pipeline.py
Normal file
159
jit_pipeline.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import time
|
||||
from typing import Union, Callable
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
|
||||
from configs import Configer
|
||||
from function_factory import FunctionFactory
|
||||
from algorithms.neat import initialize_genomes, expand, expand_single
|
||||
|
||||
from algorithms.neat.jit_species import update_species
|
||||
from algorithms.neat.operations import create_next_generation_then_speciate
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Neat algorithm pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config, function_factory=None, seed=42):
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
self.config = config # global config
|
||||
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||
self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config)
|
||||
|
||||
self.symbols = {
|
||||
'P': self.config['pop_size'],
|
||||
'N': self.config['init_maximum_nodes'],
|
||||
'C': self.config['init_maximum_connections'],
|
||||
'S': self.config['init_maximum_species'],
|
||||
}
|
||||
|
||||
self.generation = 0
|
||||
self.best_genome = None
|
||||
|
||||
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
|
||||
self.species_info = np.full((self.symbols['S'], 3), np.nan)
|
||||
self.species_info[0, :] = 0, -np.inf, 0
|
||||
self.idx2species = np.zeros(self.symbols['P'], dtype=np.int32)
|
||||
self.center_nodes = np.full((self.symbols['S'], self.symbols['N'], 5), np.nan)
|
||||
self.center_cons = np.full((self.symbols['S'], self.symbols['C'], 4), np.nan)
|
||||
self.center_nodes[0, :, :] = self.pop_nodes[0, :, :]
|
||||
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
self.best_genome = None
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
self.evaluate_time = 0
|
||||
print(self.config)
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
Creates a function that receives a genome and returns a forward function.
|
||||
There are 3 types of config['forward_way']: {'single', 'pop', 'common'}
|
||||
|
||||
single:
|
||||
Create pop_size number of forward functions.
|
||||
Each function receive (batch_size, input_size) and returns (batch_size, output_size)
|
||||
e.g. RL task
|
||||
|
||||
pop:
|
||||
Create a single forward function, which use only once calculation for the population.
|
||||
The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
||||
|
||||
common:
|
||||
Special case of pop. The population has the same inputs.
|
||||
The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
||||
e.g. numerical regression; Hyper-NEAT
|
||||
|
||||
"""
|
||||
u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons)
|
||||
pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons)
|
||||
|
||||
if self.config['forward_way'] == 'single':
|
||||
forward_funcs = []
|
||||
for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons):
|
||||
func = lambda x: self.get_func('forward')(x, seq, nodes, cons)
|
||||
forward_funcs.append(func)
|
||||
return forward_funcs
|
||||
|
||||
elif self.config['forward_way'] == 'pop':
|
||||
func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
return func
|
||||
|
||||
elif self.config['forward_way'] == 'common':
|
||||
func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
return func
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def tell(self, fitnesses):
|
||||
self.generation += 1
|
||||
|
||||
species_info, center_nodes, center_cons, winner, loser, elite_mask = \
|
||||
update_species(self.randkey, fitnesses, self.species_info, self.idx2species, self.center_nodes,
|
||||
self.center_cons, self.generation, self.jit_config)
|
||||
|
||||
# node keys to be used in the mutation process
|
||||
new_node_keys = np.arange(self.generation * self.config['pop_size'],
|
||||
self.generation * self.config['pop_size'] + self.config['pop_size'])
|
||||
|
||||
# create the next generation and then speciate the population
|
||||
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||
create_next_generation_then_speciate(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes,
|
||||
center_cons, species_keys, species_key_start, self.jit_config)
|
||||
|
||||
# carry data to cpu
|
||||
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
|
||||
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys])
|
||||
|
||||
# update randkey
|
||||
self.randkey = jax.random.split(self.randkey)[0]
|
||||
|
||||
def get_func(self, name):
|
||||
return self.function_factory.get(name, self.symbols)
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config['generation_limit']):
|
||||
forward_func = self.ask()
|
||||
|
||||
tic = time.time()
|
||||
fitnesses = fitness_func(forward_func)
|
||||
self.evaluate_time += time.time() - tic
|
||||
|
||||
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||
|
||||
if analysis is not None:
|
||||
if analysis == "default":
|
||||
self.default_analysis(fitnesses)
|
||||
else:
|
||||
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
||||
analysis(fitnesses)
|
||||
|
||||
if max(fitnesses) >= self.config['fitness_threshold']:
|
||||
print("Fitness limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
self.tell(fitnesses)
|
||||
print("Generation limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
|
||||
new_timestamp = time.time()
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
self.generation_timestamp = new_timestamp
|
||||
|
||||
max_idx = np.argmax(fitnesses)
|
||||
if fitnesses[max_idx] > self.best_fitness:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||
|
||||
print(f"Generation: {self.generation}",
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
|
||||
@@ -1,3 +0,0 @@
|
||||
"""
|
||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||
"""
|
||||
@@ -5,9 +5,8 @@ import numpy as np
|
||||
import jax
|
||||
|
||||
from configs import Configer
|
||||
from .genome import initialize_genomes, expand, expand_single
|
||||
from .function_factory import FunctionFactory
|
||||
from .species import SpeciesController
|
||||
from function_factory import FunctionFactory
|
||||
from algorithms.neat import initialize_genomes, expand, expand_single, SpeciesController
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -119,25 +118,27 @@ class Pipeline:
|
||||
when the maximum node number >= N or the maximum connection number of >= C
|
||||
the population will expand
|
||||
"""
|
||||
changed = False
|
||||
|
||||
# analysis nodes
|
||||
pop_node_keys = self.pop_nodes[:, :, 0]
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
|
||||
max_node_size = np.max(pop_node_sizes)
|
||||
if max_node_size >= self.symbols['N']:
|
||||
self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
|
||||
print(f"node expand to {self.symbols['N']}!")
|
||||
changed = True
|
||||
|
||||
# analysis connections
|
||||
pop_con_keys = self.pop_cons[:, :, 0]
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||
max_con_size = np.max(pop_node_sizes)
|
||||
if max_con_size >= self.symbols['C']:
|
||||
self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
|
||||
print(f"connection expand to {self.symbols['C']}!")
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
# expand if needed
|
||||
if max_node_size >= self.symbols['N'] or max_con_size >= self.symbols['C']:
|
||||
if max_node_size > self.symbols['N'] * self.config['pre_expand_threshold']:
|
||||
self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
|
||||
print(f"pre node expand to {self.symbols['N']}!")
|
||||
|
||||
if max_con_size > self.symbols['C'] * self.config['pre_expand_threshold']:
|
||||
self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
|
||||
print(f"pre connection expand to {self.symbols['C']}!")
|
||||
|
||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C'])
|
||||
# don't forget to expand representation genome in species
|
||||
for s in self.species_controller.species.values():
|
||||
@@ -160,7 +161,7 @@ class Pipeline:
|
||||
if analysis == "default":
|
||||
self.default_analysis(fitnesses)
|
||||
else:
|
||||
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
||||
assert callable(analysis), f"Callable is needed here😅😅😅 A {analysis}?"
|
||||
analysis(fitnesses)
|
||||
|
||||
if max(fitnesses) >= self.config['fitness_threshold']:
|
||||
Reference in New Issue
Block a user