modifying

This commit is contained in:
wls2002
2023-06-27 18:47:47 +08:00
parent ba369db0b2
commit 114ff2b0cc
28 changed files with 451 additions and 123 deletions

View File

@@ -0,0 +1,6 @@
"""
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""
from .genome import create_forward, topological_sort, unflatten_connections, initialize_genomes, expand, expand_single
from .operations import create_next_generation_then_speciate
from .species import SpeciesController

View File

View File

@@ -5,8 +5,11 @@ from jax import jit, vmap
from .utils import I_INT from .utils import I_INT
# TODO: enabled information doesn't influence forward. That is wrong!
def create_forward(config): def create_forward(config):
"""
meta method to create forward function
"""
def act(idx, z): def act(idx, z):
""" """
calculate activation function for each node calculate activation function for each node

View File

@@ -4,11 +4,11 @@ Only used in feed-forward networks.
""" """
import jax import jax
from jax import jit, vmap, Array from jax import jit, Array
from jax import numpy as jnp from jax import numpy as jnp
# from .configs import fetch_first, I_INT # from .configs import fetch_first, I_INT
from neat.genome.utils import fetch_first, I_INT, unflatten_connections from algorithms.neat.genome.utils import fetch_first, I_INT
@jit @jit

View File

@@ -1,3 +1,5 @@
from functools import partial
import numpy as np import numpy as np
import jax import jax
from jax import numpy as jnp, Array from jax import numpy as jnp, Array
@@ -30,6 +32,7 @@ def unflatten_connections(nodes: Array, cons: Array):
return res return res
def key_to_indices(key, keys): def key_to_indices(key, keys):
return fetch_first(key == keys) return fetch_first(key == keys)
@@ -56,4 +59,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1) target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
mask = jnp.where(true_cnt == 0, False, cumsum >= target) mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default) return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
"""
rank the element in the array.
if reverse is True, the rank is from large to small.
"""
if reverse:
array = -array
return jnp.argsort(jnp.argsort(array))

View File

@@ -0,0 +1,160 @@
from functools import partial
import jax
from jax import jit, numpy as jnp, vmap
from .genome.utils import rank_elements
@jit
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
"""
args:
randkey: random key
fitness: Array[(pop_size,), float], the fitness of each individual
species_keys: Array[(species_size, 3), float], the information of each species
[species_key, best_score, last_update]
idx2species: Array[(pop_size,), int], map the individual to its species
center_nodes: Array[(species_size, N, 4), float], the center nodes of each species
center_cons: Array[(species_size, C, 4), float], the center connections of each species
generation: int, current generation
jit_config: Dict, the configuration of jit functions
"""
# update the fitness of each species
species_fitness = update_species_fitness(species_info, idx2species, fitness)
# stagnation species
species_fitness, species_info, center_nodes, center_cons = \
stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config)
# sort species_info by their fitness. (push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
species_info = species_info[sort_indices]
center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices]
# decide the number of members of each species by their fitness
spawn_number = cal_spawn_numbers(species_info, jit_config)
# crossover info
winner, loser, elite_mask = \
create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)
jax.debug.print("{}, {}", fitness, winner)
jax.debug.print("{}", fitness[winner])
return species_info, center_nodes, center_cons, winner, loser, elite_mask
def update_species_fitness(species_info, idx2species, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
"""
def aux_func(idx):
species_key = species_info[idx, 0]
s_fitness = jnp.where(idx2species == species_key, fitness, -jnp.inf)
f = jnp.max(s_fitness)
return f
return vmap(aux_func)(jnp.arange(species_info.shape[0]))
def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
elitism species never stagnation
"""
def aux_func(idx):
s_fitness = species_fitness[idx]
species_key, best_score, last_update = species_info[idx]
# stagnation condition
return (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
st = vmap(aux_func)(jnp.arange(species_info.shape[0]))
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
st = jnp.where(species_rank < jit_config['species_elitism'], False, st) # elitism never stagnation
# set stagnation species to nan
species_info = jnp.where(st[:, None], jnp.nan, species_info)
center_nodes = jnp.where(st[:, None, None], jnp.nan, center_nodes)
center_cons = jnp.where(st[:, None, None], jnp.nan, center_cons)
species_fitness = jnp.where(st, jnp.nan, species_fitness)
return species_fitness, species_info, center_nodes, center_cons
def cal_spawn_numbers(species_info, jit_config):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
Linear ranking selection
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
is_species_valid = ~jnp.isnan(species_info[:, 0])
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - jnp.arange(species_info.shape[0]) # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
# must control the sum of spawn_number to be equal to pop_size
error = jit_config['pop_size'] - jnp.sum(spawn_number)
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
return spawn_number
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
species_size = species_info.shape[0]
pop_size = fitness.shape[0]
s_idx = jnp.arange(species_size)
p_idx = jnp.arange(pop_size)
def aux_func(key, idx):
members = idx2species == species_info[idx, 0]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, jnp.nan)
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
elite_size = jit_config['genome_elitism']
survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32)
select_pro = (p_idx < survive_size) / survive_size
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)
# elite
fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
elite = jnp.where(p_idx < elite_size, True, False)
return fa, ma, elite
fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
loc = jnp.argmax(idx < spawn_number_cum)
# elite genomes are at the beginning of the species
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
part1, part2, elite_mask = vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return winner, loser, elite_mask

View File

@@ -1,13 +1,8 @@
""" """
contains operations on the population: creating the next generation and population speciation. contains operations on the population: creating the next generation and population speciation.
""" """
from functools import partial
import jax import jax
import jax.numpy as jnp from jax import jit, vmap, Array, numpy as jnp
from jax import jit, vmap
from jax import Array
from .genome import distance, mutate, crossover from .genome import distance, mutate, crossover
from .genome.utils import I_INT, fetch_first from .genome.utils import I_INT, fetch_first

View File

@@ -8,7 +8,6 @@ See
""" """
from typing import List, Tuple, Dict from typing import List, Tuple, Dict
from itertools import count
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray

View File

@@ -4,8 +4,8 @@ import configparser
import numpy as np import numpy as np
from neat.genome.activations import act_name2func from algorithms.neat.genome.activations import act_name2func
from neat.genome.aggregations import agg_name2func from algorithms.neat.genome.aggregations import agg_name2func
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX. # Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
jit_config_keys = [ jit_config_keys = [
@@ -41,6 +41,11 @@ jit_config_keys = [
"weight_mutate_rate", "weight_mutate_rate",
"weight_replace_rate", "weight_replace_rate",
"enable_mutate_rate", "enable_mutate_rate",
"max_stagnation",
"pop_size",
"genome_elitism",
"survival_threshold",
"species_elitism"
] ]

View File

@@ -1,10 +1,11 @@
[basic] [basic]
num_inputs = 2 num_inputs = 2
num_outputs = 1 num_outputs = 1
init_maximum_nodes = 50 init_maximum_nodes = 200
init_maximum_connections = 50 init_maximum_connections = 200
init_maximum_species = 10 init_maximum_species = 10
expand_coe = 2.0 expand_coe = 1.5
pre_expand_threshold = 0.75
forward_way = "pop" forward_way = "pop"
batch_size = 4 batch_size = 4
@@ -12,7 +13,7 @@ batch_size = 4
fitness_threshold = 100000 fitness_threshold = 100000
generation_limit = 100 generation_limit = 100
fitness_criterion = "max" fitness_criterion = "max"
pop_size = 15000 pop_size = 150
[genome] [genome]
compatibility_disjoint = 1.0 compatibility_disjoint = 1.0

View File

@@ -1,55 +0,0 @@
import numpy as np
import jax.numpy as jnp
import jax
a = {1:2, 2:3, 4:5}
print(a.values())
a = jnp.array([1, 0, 1, 0, np.nan])
b = jnp.array([1, 1, 1, 1, 1])
c = jnp.array([1, 1, 1, 1, 1])
full = jnp.array([
[1, 1, 1],
[0, 1, 1],
[1, 1, 1],
[0, 1, 1],
])
print(jnp.column_stack([a[:, None], b[:, None], c[:, None]]))
aux0 = full[:, 0, None]
aux1 = full[:, 1, None]
print(aux0, aux0.shape)
print(jnp.concatenate([aux0, aux1], axis=1))
f_a = jnp.array([False, False, True, True])
f_b = jnp.array([True, False, False, False])
print(jnp.logical_and(f_a, f_b))
print(f_a & f_b)
print(f_a + jnp.nan * 0.0)
print(f_a + 1 * 0.0)
@jax.jit
def main():
return func('happy') + func('sad')
def func(x):
if x == 'happy':
return 1
else:
return 2
a = jnp.zeros((3, 3))
print(a.dtype)
c = None
b = 1 or c
print(b)

26
examples/evox_test.py Normal file
View File

@@ -0,0 +1,26 @@
import jax
from jax import numpy as jnp
from evox import algorithms, problems, pipelines
from evox.monitors import StdSOMonitor
monitor = StdSOMonitor()
pso = algorithms.PSO(
lb=jnp.full(shape=(2,), fill_value=-32),
ub=jnp.full(shape=(2,), fill_value=32),
pop_size=100,
)
ackley = problems.classic.Ackley()
pipeline = pipelines.StdPipeline(pso, ackley, fitness_transform=monitor.record_fit)
key = jax.random.PRNGKey(42)
state = pipeline.init(key)
# run the pipeline for 100 steps
for i in range(100):
state = pipeline.step(state)
print(monitor.get_min_fitness())

View File

@@ -1,27 +1,18 @@
import numpy as np from functools import partial
from jax import jit
from configs import Configer import jax
from neat.pipeline import Pipeline from jax import numpy as jnp, jit
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) @partial(jit, static_argnames=['reverse'])
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) def rank_element(array, reverse=False):
"""
def main(): rank the element in the array.
config = Configer.load_config("xor.ini") if reverse is True, the rank is from large to small.
print(config) """
pipeline = Pipeline(config) if reverse:
forward_func = pipeline.ask() array = -array
# inputs = np.tile(xor_inputs, (150, 1, 1)) return jnp.argsort(jnp.argsort(array))
outputs = forward_func(xor_inputs)
print(outputs)
a = jnp.array([1 ,5, 3, 5, 2, 1, 0])
@jit print(rank_element(a, reverse=True))
def f(x, jit_config):
return x + jit_config["bias_mutate_rate"]
if __name__ == '__main__':
main()

28
examples/jit_xor.py Normal file
View File

@@ -0,0 +1,28 @@
import numpy as np
from configs import Configer
from jit_pipeline import Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return np.array(fitnesses) # returns a list
def main():
config = Configer.load_config("xor.ini")
pipeline = Pipeline(config, seed=6)
nodes, cons = pipeline.auto_run(evaluate)
print(nodes, cons)
if __name__ == '__main__':
main()

View File

@@ -1,7 +1,7 @@
import numpy as np import numpy as np
from configs import Configer from configs import Configer
from neat.pipeline import Pipeline from pipeline import Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
@@ -21,6 +21,8 @@ def main():
config = Configer.load_config("xor.ini") config = Configer.load_config("xor.ini")
pipeline = Pipeline(config, seed=6) pipeline = Pipeline(config, seed=6)
nodes, cons = pipeline.auto_run(evaluate) nodes, cons = pipeline.auto_run(evaluate)
print(nodes, cons)
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@@ -1,8 +1,9 @@
import numpy as np import numpy as np
from jax import jit, vmap from jax import jit, vmap
from .genome import create_forward, topological_sort, unflatten_connections from algorithms.neat import create_forward, topological_sort, \
from .operations import create_next_generation_then_speciate unflatten_connections, create_next_generation_then_speciate
def hash_symbols(symbols): def hash_symbols(symbols):
return symbols['P'], symbols['N'], symbols['C'], symbols['S'] return symbols['P'], symbols['N'], symbols['C'], symbols['S']
@@ -32,7 +33,6 @@ class FunctionFactory:
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums) # (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0)) common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
self.function_info = { self.function_info = {
"pop_unflatten_connections": { "pop_unflatten_connections": {
'func': vmap(unflatten_connections), 'func': vmap(unflatten_connections),
@@ -54,7 +54,7 @@ class FunctionFactory:
'func': batch_forward, 'func': batch_forward,
'lowers': [ 'lowers': [
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32}, {'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
{'shape': ('N', ), 'type': np.int32}, {'shape': ('N',), 'type': np.int32},
{'shape': ('N', 5), 'type': np.float32}, {'shape': ('N', 5), 'type': np.float32},
{'shape': (2, 'N', 'N'), 'type': np.float32} {'shape': (2, 'N', 'N'), 'type': np.float32}
] ]
@@ -83,23 +83,22 @@ class FunctionFactory:
'create_next_generation_then_speciate': { 'create_next_generation_then_speciate': {
'func': create_next_generation_then_speciate, 'func': create_next_generation_then_speciate,
'lowers': [ 'lowers': [
{'shape': (2, ), 'type': np.uint32}, # rand_key {'shape': (2,), 'type': np.uint32}, # rand_key
{'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes {'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes
{'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons {'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons
{'shape': ('P', ), 'type': np.int32}, # winner {'shape': ('P',), 'type': np.int32}, # winner
{'shape': ('P', ), 'type': np.int32}, # loser {'shape': ('P',), 'type': np.int32}, # loser
{'shape': ('P', ), 'type': bool}, # elite_mask {'shape': ('P',), 'type': bool}, # elite_mask
{'shape': ('P',), 'type': np.int32}, # new_node_keys {'shape': ('P',), 'type': np.int32}, # new_node_keys
{'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes {'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes
{'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons {'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons
{'shape': ('S', ), 'type': np.int32}, # species_keys {'shape': ('S',), 'type': np.int32}, # species_keys
{'shape': (), 'type': np.int32}, # new_species_key_start {'shape': (), 'type': np.int32}, # new_species_key_start
"jit_config" "jit_config"
] ]
} }
} }
def get(self, name, symbols): def get(self, name, symbols):
if (name, hash_symbols(symbols)) not in self.func_dict: if (name, hash_symbols(symbols)) not in self.func_dict:
self.compile(name, symbols) self.compile(name, symbols)

159
jit_pipeline.py Normal file
View File

@@ -0,0 +1,159 @@
import time
from typing import Union, Callable
import numpy as np
import jax
from configs import Configer
from function_factory import FunctionFactory
from algorithms.neat import initialize_genomes, expand, expand_single
from algorithms.neat.jit_species import update_species
from algorithms.neat.operations import create_next_generation_then_speciate
class Pipeline:
"""
Neat algorithm pipeline.
"""
def __init__(self, config, function_factory=None, seed=42):
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
self.config = config # global config
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
self.function_factory = function_factory or FunctionFactory(self.config, self.jit_config)
self.symbols = {
'P': self.config['pop_size'],
'N': self.config['init_maximum_nodes'],
'C': self.config['init_maximum_connections'],
'S': self.config['init_maximum_species'],
}
self.generation = 0
self.best_genome = None
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
self.species_info = np.full((self.symbols['S'], 3), np.nan)
self.species_info[0, :] = 0, -np.inf, 0
self.idx2species = np.zeros(self.symbols['P'], dtype=np.int32)
self.center_nodes = np.full((self.symbols['S'], self.symbols['N'], 5), np.nan)
self.center_cons = np.full((self.symbols['S'], self.symbols['C'], 4), np.nan)
self.center_nodes[0, :, :] = self.pop_nodes[0, :, :]
self.center_cons[0, :, :] = self.pop_cons[0, :, :]
self.best_fitness = float('-inf')
self.best_genome = None
self.generation_timestamp = time.time()
self.evaluate_time = 0
print(self.config)
def ask(self):
"""
Creates a function that receives a genome and returns a forward function.
There are 3 types of config['forward_way']: {'single', 'pop', 'common'}
single:
Create pop_size number of forward functions.
Each function receive (batch_size, input_size) and returns (batch_size, output_size)
e.g. RL task
pop:
Create a single forward function, which use only once calculation for the population.
The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)
common:
Special case of pop. The population has the same inputs.
The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size)
e.g. numerical regression; Hyper-NEAT
"""
u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons)
pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons)
if self.config['forward_way'] == 'single':
forward_funcs = []
for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons):
func = lambda x: self.get_func('forward')(x, seq, nodes, cons)
forward_funcs.append(func)
return forward_funcs
elif self.config['forward_way'] == 'pop':
func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
return func
elif self.config['forward_way'] == 'common':
func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
return func
else:
raise NotImplementedError
def tell(self, fitnesses):
self.generation += 1
species_info, center_nodes, center_cons, winner, loser, elite_mask = \
update_species(self.randkey, fitnesses, self.species_info, self.idx2species, self.center_nodes,
self.center_cons, self.generation, self.jit_config)
# node keys to be used in the mutation process
new_node_keys = np.arange(self.generation * self.config['pop_size'],
self.generation * self.config['pop_size'] + self.config['pop_size'])
# create the next generation and then speciate the population
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
create_next_generation_then_speciate(self.randkey, self.pop_nodes, self.pop_cons, winner, loser, elite_mask, new_node_keys, center_nodes,
center_cons, species_keys, species_key_start, self.jit_config)
# carry data to cpu
self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys = \
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, center_nodes, center_cons, species_keys])
# update randkey
self.randkey = jax.random.split(self.randkey)[0]
def get_func(self, name):
return self.function_factory.get(name, self.symbols)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']):
forward_func = self.ask()
tic = time.time()
fitnesses = fitness_func(forward_func)
self.evaluate_time += time.time() - tic
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
if analysis is not None:
if analysis == "default":
self.default_analysis(fitnesses)
else:
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
analysis(fitnesses)
if max(fitnesses) >= self.config['fitness_threshold']:
print("Fitness limit reached!")
return self.best_genome
self.tell(fitnesses)
print("Generation limit reached!")
return self.best_genome
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
self.generation_timestamp = new_timestamp
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")

View File

@@ -1,3 +0,0 @@
"""
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""

View File

@@ -5,9 +5,8 @@ import numpy as np
import jax import jax
from configs import Configer from configs import Configer
from .genome import initialize_genomes, expand, expand_single from function_factory import FunctionFactory
from .function_factory import FunctionFactory from algorithms.neat import initialize_genomes, expand, expand_single, SpeciesController
from .species import SpeciesController
class Pipeline: class Pipeline:
@@ -119,25 +118,27 @@ class Pipeline:
when the maximum node number >= N or the maximum connection number of >= C when the maximum node number >= N or the maximum connection number of >= C
the population will expand the population will expand
""" """
changed = False
# analysis nodes
pop_node_keys = self.pop_nodes[:, :, 0] pop_node_keys = self.pop_nodes[:, :, 0]
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1) pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
max_node_size = np.max(pop_node_sizes) max_node_size = np.max(pop_node_sizes)
if max_node_size >= self.symbols['N']:
self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
print(f"node expand to {self.symbols['N']}!")
changed = True
# analysis connections
pop_con_keys = self.pop_cons[:, :, 0] pop_con_keys = self.pop_cons[:, :, 0]
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1) pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
max_con_size = np.max(pop_node_sizes) max_con_size = np.max(pop_node_sizes)
if max_con_size >= self.symbols['C']:
self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
print(f"connection expand to {self.symbols['C']}!")
changed = True
if changed: # expand if needed
if max_node_size >= self.symbols['N'] or max_con_size >= self.symbols['C']:
if max_node_size > self.symbols['N'] * self.config['pre_expand_threshold']:
self.symbols['N'] = int(self.symbols['N'] * self.config['expand_coe'])
print(f"pre node expand to {self.symbols['N']}!")
if max_con_size > self.symbols['C'] * self.config['pre_expand_threshold']:
self.symbols['C'] = int(self.symbols['C'] * self.config['expand_coe'])
print(f"pre connection expand to {self.symbols['C']}!")
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C']) self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.symbols['N'], self.symbols['C'])
# don't forget to expand representation genome in species # don't forget to expand representation genome in species
for s in self.species_controller.species.values(): for s in self.species_controller.species.values():
@@ -160,7 +161,7 @@ class Pipeline:
if analysis == "default": if analysis == "default":
self.default_analysis(fitnesses) self.default_analysis(fitnesses)
else: else:
assert callable(analysis), f"What the fuck you passed in? A {analysis}?" assert callable(analysis), f"Callable is needed here😅😅😅 A {analysis}?"
analysis(fitnesses) analysis(fitnesses)
if max(fitnesses) >= self.config['fitness_threshold']: if max(fitnesses) >= self.config['fitness_threshold']: