finish ask part of the algorithm;
use jax.lax.while_loop in graph algorithms and forward function; fix "enabled not care" bug in forward
This commit is contained in:
@@ -1,32 +0,0 @@
|
||||
from neat.genome.activations import *
|
||||
|
||||
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
|
||||
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
|
||||
|
||||
act_name2key = {
|
||||
'sigmoid': 0,
|
||||
'tanh': 1,
|
||||
'sin': 2,
|
||||
'gauss': 3,
|
||||
'relu': 4,
|
||||
'elu': 5,
|
||||
'lelu': 6,
|
||||
'selu': 7,
|
||||
'softplus': 8,
|
||||
'identity': 9,
|
||||
'clamped': 10,
|
||||
'inv': 11,
|
||||
'log': 12,
|
||||
'exp': 13,
|
||||
'abs': 14,
|
||||
'hat': 15,
|
||||
'square': 16,
|
||||
'cube': 17,
|
||||
}
|
||||
|
||||
|
||||
def refactor_act(config):
|
||||
config['activation_default'] = act_name2key[config['activation_default']]
|
||||
config['activation_options'] = [
|
||||
act_name2key[act_name] for act_name in config['activation_options']
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
from neat.genome.aggregations import *
|
||||
|
||||
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
|
||||
|
||||
agg_name2key = {
|
||||
'sum': 0,
|
||||
'product': 1,
|
||||
'max': 2,
|
||||
'min': 3,
|
||||
'maxabs': 4,
|
||||
'median': 5,
|
||||
'mean': 6,
|
||||
}
|
||||
|
||||
|
||||
def refactor_agg(config):
|
||||
config['aggregation_default'] = agg_name2key[config['aggregation_default']]
|
||||
config['aggregation_options'] = [
|
||||
agg_name2key[act_name] for act_name in config['aggregation_options']
|
||||
]
|
||||
@@ -4,8 +4,8 @@ import configparser
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .activations import refactor_act
|
||||
from .aggregations import refactor_agg
|
||||
from neat.genome.activations import act_name2func
|
||||
from neat.genome.aggregations import agg_name2func
|
||||
|
||||
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||
jit_config_keys = [
|
||||
@@ -20,12 +20,12 @@ jit_config_keys = [
|
||||
"node_delete_prob",
|
||||
"compatibility_threshold",
|
||||
"bias_init_mean",
|
||||
"bias_init_stdev",
|
||||
"bias_init_std",
|
||||
"bias_mutate_power",
|
||||
"bias_mutate_rate",
|
||||
"bias_replace_rate",
|
||||
"response_init_mean",
|
||||
"response_init_stdev",
|
||||
"response_init_std",
|
||||
"response_mutate_power",
|
||||
"response_mutate_rate",
|
||||
"response_replace_rate",
|
||||
@@ -36,7 +36,7 @@ jit_config_keys = [
|
||||
"aggregation_options",
|
||||
"aggregation_replace_rate",
|
||||
"weight_init_mean",
|
||||
"weight_init_stdev",
|
||||
"weight_init_std",
|
||||
"weight_mutate_power",
|
||||
"weight_mutate_rate",
|
||||
"weight_replace_rate",
|
||||
@@ -90,14 +90,26 @@ class Configer:
|
||||
cls.__check_redundant_config(default_config, config)
|
||||
cls.__complete_config(default_config, config)
|
||||
|
||||
refactor_act(config)
|
||||
refactor_agg(config)
|
||||
input_idx = np.arange(config['num_inputs'])
|
||||
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
|
||||
config['input_idx'] = input_idx
|
||||
config['output_idx'] = output_idx
|
||||
cls.refactor_activation(config)
|
||||
cls.refactor_aggregation(config)
|
||||
|
||||
config['input_idx'] = np.arange(config['num_inputs'])
|
||||
config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def refactor_activation(cls, config):
|
||||
config['activation_default'] = 0
|
||||
config['activation_options'] = np.arange(len(config['activation_option_names']))
|
||||
config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
|
||||
|
||||
@classmethod
|
||||
def refactor_aggregation(cls, config):
|
||||
config['aggregation_default'] = 0
|
||||
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
|
||||
config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]
|
||||
|
||||
@classmethod
|
||||
def create_jit_config(cls, config):
|
||||
jit_config = {k: config[k] for k in jit_config_keys}
|
||||
|
||||
@@ -5,7 +5,8 @@ init_maximum_nodes = 20
|
||||
init_maximum_connections = 20
|
||||
init_maximum_species = 10
|
||||
expands_coe = 2.0
|
||||
forward_way = "pop_batch"
|
||||
forward_way = "pop"
|
||||
batch_size = 4
|
||||
|
||||
[population]
|
||||
fitness_threshold = 100000
|
||||
@@ -46,12 +47,12 @@ response_replace_rate = 0.0
|
||||
|
||||
[gene-activation]
|
||||
activation_default = "sigmoid"
|
||||
activation_options = ["sigmoid"]
|
||||
activation_option_names = ["sigmoid"]
|
||||
activation_replace_rate = 0.0
|
||||
|
||||
[gene-aggregation]
|
||||
aggregation_default = "sum"
|
||||
aggregation_options = ["sum"]
|
||||
aggregation_option_names = ["sum"]
|
||||
aggregation_replace_rate = 0.0
|
||||
|
||||
[gene-weight]
|
||||
|
||||
@@ -3,6 +3,9 @@ import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
|
||||
a = {1:2, 2:3, 4:5}
|
||||
print(a.values())
|
||||
|
||||
a = jnp.array([1, 0, 1, 0, np.nan])
|
||||
b = jnp.array([1, 1, 1, 1, 1])
|
||||
c = jnp.array([1, 1, 1, 1, 1])
|
||||
@@ -44,5 +47,9 @@ def func(x):
|
||||
else:
|
||||
return 2
|
||||
|
||||
a = jnp.zeros((3, 3))
|
||||
print(a.dtype)
|
||||
|
||||
print(main())
|
||||
c = None
|
||||
b = 1 or c
|
||||
print(b)
|
||||
@@ -1,16 +1,25 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import jit
|
||||
|
||||
from configs import Configer
|
||||
from neat.pipeline_ import Pipeline
|
||||
from neat.pipeline import Pipeline
|
||||
from neat.function_factory import FunctionFactory
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
print(config)
|
||||
pipeline = Pipeline(config)
|
||||
function_factory = FunctionFactory(config)
|
||||
pipeline = Pipeline(config, function_factory)
|
||||
forward_func = pipeline.ask()
|
||||
# inputs = np.tile(xor_inputs, (150, 1, 1))
|
||||
outputs = forward_func(xor_inputs)
|
||||
print(outputs)
|
||||
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
import cProfile
|
||||
from io import StringIO
|
||||
import pstats
|
||||
|
||||
|
||||
def using_cprofile(func, root_abs_path=None, replace_pattern=None, save_path=None):
|
||||
def inner(*args, **kwargs):
|
||||
pr = cProfile.Profile()
|
||||
pr.enable()
|
||||
ret = func(*args, **kwargs)
|
||||
pr.disable()
|
||||
profile_stats = StringIO()
|
||||
stats = pstats.Stats(pr, stream=profile_stats)
|
||||
if root_abs_path is not None:
|
||||
stats.sort_stats('cumulative').print_stats(root_abs_path)
|
||||
else:
|
||||
stats.sort_stats('cumulative').print_stats()
|
||||
output = profile_stats.getvalue()
|
||||
if replace_pattern is not None:
|
||||
output = output.replace(replace_pattern, "")
|
||||
if save_path is None:
|
||||
print(output)
|
||||
else:
|
||||
with open(save_path, "w") as f:
|
||||
f.write(output)
|
||||
return ret
|
||||
|
||||
return inner
|
||||
@@ -1,2 +1,5 @@
|
||||
[basic]
|
||||
forward_way = "common"
|
||||
|
||||
[population]
|
||||
fitness_threshold = -1e-2
|
||||
@@ -1,323 +1,108 @@
|
||||
"""
|
||||
Lowers, compiles, and creates functions used in the NEAT pipeline.
|
||||
"""
|
||||
from functools import partial
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from jax import jit, vmap
|
||||
|
||||
from .genome import act_name2key, agg_name2key, initialize_genomes
|
||||
from .genome import topological_sort, forward_single, unflatten_connections
|
||||
from .population import create_next_generation_then_speciate
|
||||
from .genome.forward import create_forward
|
||||
from .genome.utils import unflatten_connections
|
||||
from .genome.graph import topological_sort
|
||||
|
||||
|
||||
def hash_symbols(symbols):
|
||||
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
|
||||
|
||||
|
||||
class FunctionFactory:
|
||||
"""
|
||||
Creates and compiles functions used in the NEAT pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.func_dict = {}
|
||||
self.function_info = {}
|
||||
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.precompile_times = config.basic.pre_compile_times
|
||||
self.compiled_function = {}
|
||||
self.compile_time = 0
|
||||
# (inputs_nums, ) -> (outputs_nums, )
|
||||
forward = create_forward(config) # input size (inputs_nums, )
|
||||
|
||||
self.load_config_vals(config)
|
||||
# (batch_size, inputs_nums) -> (batch_size, outputs_nums)
|
||||
batch_forward = vmap(forward, in_axes=(0, None, None, None))
|
||||
|
||||
self.create_topological_sort_with_args()
|
||||
self.create_single_forward_with_args()
|
||||
self.create_update_speciate_with_args()
|
||||
# (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
|
||||
|
||||
def load_config_vals(self, config):
|
||||
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
|
||||
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
||||
|
||||
self.problem_batch = config.basic.problem_batch
|
||||
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
self.function_info = {
|
||||
"pop_unflatten_connections": {
|
||||
'func': vmap(unflatten_connections),
|
||||
'lowers': [
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 'C', 4), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||
"pop_topological_sort": {
|
||||
'func': vmap(topological_sort),
|
||||
'lowers': [
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32},
|
||||
]
|
||||
},
|
||||
|
||||
self.num_inputs = config.basic.num_inputs
|
||||
self.num_outputs = config.basic.num_outputs
|
||||
self.input_idx = np.arange(self.num_inputs)
|
||||
self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs)
|
||||
"batch_forward": {
|
||||
'func': batch_forward,
|
||||
'lowers': [
|
||||
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||
{'shape': ('N', ), 'type': np.int32},
|
||||
{'shape': ('N', 5), 'type': np.float32},
|
||||
{'shape': (2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
bias = config.neat.gene.bias
|
||||
self.bias_mean = bias.init_mean
|
||||
self.bias_std = bias.init_stdev
|
||||
self.bias_mutate_strength = bias.mutate_power
|
||||
self.bias_mutate_rate = bias.mutate_rate
|
||||
self.bias_replace_rate = bias.replace_rate
|
||||
"pop_batch_forward": {
|
||||
'func': pop_batch_forward,
|
||||
'lowers': [
|
||||
{'shape': ('P', config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||
{'shape': ('P', 'N'), 'type': np.int32},
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
response = config.neat.gene.response
|
||||
self.response_mean = response.init_mean
|
||||
self.response_std = response.init_stdev
|
||||
self.response_mutate_strength = response.mutate_power
|
||||
self.response_mutate_rate = response.mutate_rate
|
||||
self.response_replace_rate = response.replace_rate
|
||||
|
||||
weight = config.neat.gene.weight
|
||||
self.weight_mean = weight.init_mean
|
||||
self.weight_std = weight.init_stdev
|
||||
self.weight_mutate_strength = weight.mutate_power
|
||||
self.weight_mutate_rate = weight.mutate_rate
|
||||
self.weight_replace_rate = weight.replace_rate
|
||||
|
||||
activation = config.neat.gene.activation
|
||||
self.act_default = act_name2key[activation.default]
|
||||
self.act_list = np.array([act_name2key[name] for name in activation.options])
|
||||
self.act_replace_rate = activation.mutate_rate
|
||||
|
||||
aggregation = config.neat.gene.aggregation
|
||||
self.agg_default = agg_name2key[aggregation.default]
|
||||
self.agg_list = np.array([agg_name2key[name] for name in aggregation.options])
|
||||
self.agg_replace_rate = aggregation.mutate_rate
|
||||
|
||||
enabled = config.neat.gene.enabled
|
||||
self.enabled_reverse_rate = enabled.mutate_rate
|
||||
|
||||
genome = config.neat.genome
|
||||
self.add_node_rate = genome.node_add_prob
|
||||
self.delete_node_rate = genome.node_delete_prob
|
||||
self.add_connection_rate = genome.conn_add_prob
|
||||
self.delete_connection_rate = genome.conn_delete_prob
|
||||
self.single_structure_mutate = genome.single_structural_mutation
|
||||
|
||||
def create_initialize(self, N, C):
|
||||
func = partial(
|
||||
initialize_genomes,
|
||||
pop_size=self.pop_size,
|
||||
N=N,
|
||||
C=C,
|
||||
num_inputs=self.num_inputs,
|
||||
num_outputs=self.num_outputs,
|
||||
default_bias=self.bias_mean,
|
||||
default_response=self.response_mean,
|
||||
default_act=self.act_default,
|
||||
default_agg=self.agg_default,
|
||||
default_weight=self.weight_mean
|
||||
)
|
||||
return func
|
||||
|
||||
def create_update_speciate_with_args(self):
|
||||
species_kwargs = {
|
||||
"disjoint_coe": self.disjoint_coe,
|
||||
"compatibility_coe": self.compatibility_coe,
|
||||
"compatibility_threshold": self.compatibility_threshold
|
||||
'common_forward': {
|
||||
'func': common_forward,
|
||||
'lowers': [
|
||||
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
|
||||
{'shape': ('P', 'N'), 'type': np.int32},
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mutate_kwargs = {
|
||||
"input_idx": self.input_idx,
|
||||
"output_idx": self.output_idx,
|
||||
"bias_mean": self.bias_mean,
|
||||
"bias_std": self.bias_std,
|
||||
"bias_mutate_strength": self.bias_mutate_strength,
|
||||
"bias_mutate_rate": self.bias_mutate_rate,
|
||||
"bias_replace_rate": self.bias_replace_rate,
|
||||
"response_mean": self.response_mean,
|
||||
"response_std": self.response_std,
|
||||
"response_mutate_strength": self.response_mutate_strength,
|
||||
"response_mutate_rate": self.response_mutate_rate,
|
||||
"response_replace_rate": self.response_replace_rate,
|
||||
"weight_mean": self.weight_mean,
|
||||
"weight_std": self.weight_std,
|
||||
"weight_mutate_strength": self.weight_mutate_strength,
|
||||
"weight_mutate_rate": self.weight_mutate_rate,
|
||||
"weight_replace_rate": self.weight_replace_rate,
|
||||
"act_default": self.act_default,
|
||||
"act_list": self.act_list,
|
||||
"act_replace_rate": self.act_replace_rate,
|
||||
"agg_default": self.agg_default,
|
||||
"agg_list": self.agg_list,
|
||||
"agg_replace_rate": self.agg_replace_rate,
|
||||
"enabled_reverse_rate": self.enabled_reverse_rate,
|
||||
"add_node_rate": self.add_node_rate,
|
||||
"delete_node_rate": self.delete_node_rate,
|
||||
"add_connection_rate": self.add_connection_rate,
|
||||
"delete_connection_rate": self.delete_connection_rate,
|
||||
}
|
||||
|
||||
self.update_speciate_with_args = partial(
|
||||
create_next_generation_then_speciate,
|
||||
species_kwargs=species_kwargs,
|
||||
mutate_kwargs=mutate_kwargs
|
||||
)
|
||||
def get(self, name, symbols):
|
||||
if (name, hash_symbols(symbols)) not in self.func_dict:
|
||||
self.compile(name, symbols)
|
||||
return self.func_dict[name, hash_symbols(symbols)]
|
||||
|
||||
def create_update_speciate(self, N, C, S):
|
||||
key = ("update_speciate", N, C, S)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_update_speciate(N, C, S)
|
||||
return self.compiled_function[key]
|
||||
def compile(self, name, symbols):
|
||||
# prepare function prototype
|
||||
func = self.function_info[name]['func']
|
||||
|
||||
def compile_update_speciate(self, N, C, S):
|
||||
s = time.time()
|
||||
# prepare lower operands
|
||||
lowers_operands = []
|
||||
for lower in self.function_info[name]['lowers']:
|
||||
shape = list(lower['shape'])
|
||||
for i, s in enumerate(shape):
|
||||
if s in symbols:
|
||||
shape[i] = symbols[s]
|
||||
assert isinstance(shape[i], int)
|
||||
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
|
||||
|
||||
func = self.update_speciate_with_args
|
||||
randkey_lower = np.zeros((2,), dtype=np.uint32)
|
||||
pop_nodes_lower = np.zeros((self.pop_size, N, 5))
|
||||
pop_cons_lower = np.zeros((self.pop_size, C, 4))
|
||||
winner_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
loser_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
elite_mask_lower = np.zeros((self.pop_size,), dtype=bool)
|
||||
new_node_keys_start_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
pre_spe_center_nodes_lower = np.zeros((S, N, 5))
|
||||
pre_spe_center_cons_lower = np.zeros((S, C, 4))
|
||||
species_keys = np.zeros((S,), dtype=np.int32)
|
||||
new_species_keys_lower = 0
|
||||
compiled_func = jit(func).lower(
|
||||
randkey_lower,
|
||||
pop_nodes_lower,
|
||||
pop_cons_lower,
|
||||
winner_part_lower,
|
||||
loser_part_lower,
|
||||
elite_mask_lower,
|
||||
new_node_keys_start_lower,
|
||||
pre_spe_center_nodes_lower,
|
||||
pre_spe_center_cons_lower,
|
||||
species_keys,
|
||||
new_species_keys_lower,
|
||||
).compile()
|
||||
self.compiled_function[("update_speciate", N, C, S)] = compiled_func
|
||||
# compile
|
||||
compiled_func = jit(func).lower(*lowers_operands).compile()
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_topological_sort_with_args(self):
|
||||
self.topological_sort_with_args = topological_sort
|
||||
|
||||
def compile_topological_sort(self, n):
|
||||
s = time.time()
|
||||
|
||||
func = self.topological_sort_with_args
|
||||
nodes_lower = np.zeros((n, 5))
|
||||
connections_lower = np.zeros((2, n, n))
|
||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('topological_sort', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_topological_sort(self, n):
|
||||
key = ('topological_sort', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_topological_sort(n)
|
||||
return self.compiled_function[key]
|
||||
|
||||
def compile_topological_sort_batch(self, n):
|
||||
s = time.time()
|
||||
|
||||
func = self.topological_sort_with_args
|
||||
func = vmap(func)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('topological_sort_batch', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_topological_sort_batch(self, n):
|
||||
key = ('topological_sort_batch', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_topological_sort_batch(n)
|
||||
return self.compiled_function[key]
|
||||
|
||||
def create_single_forward_with_args(self):
|
||||
func = partial(
|
||||
forward_single,
|
||||
input_idx=self.input_idx,
|
||||
output_idx=self.output_idx
|
||||
)
|
||||
self.single_forward_with_args = func
|
||||
|
||||
|
||||
def compile_batch_forward(self, n):
|
||||
s = time.time()
|
||||
|
||||
func = self.single_forward_with_args
|
||||
func = vmap(func, in_axes=(0, None, None, None))
|
||||
|
||||
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
|
||||
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
|
||||
nodes_lower = np.zeros((n, 5))
|
||||
connections_lower = np.zeros((2, n, n))
|
||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('batch_forward', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_batch_forward(self, n):
|
||||
key = ('batch_forward', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_batch_forward(n)
|
||||
|
||||
return self.compiled_function[key]
|
||||
|
||||
def compile_pop_batch_forward(self, n):
|
||||
|
||||
s = time.time()
|
||||
|
||||
func = self.single_forward_with_args
|
||||
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
|
||||
func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward
|
||||
|
||||
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
|
||||
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
|
||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||
self.compiled_function[('pop_batch_forward', n)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_pop_batch_forward(self, n):
|
||||
key = ('pop_batch_forward', n)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_pop_batch_forward(n)
|
||||
|
||||
return self.compiled_function[key]
|
||||
|
||||
def ask_pop_batch_forward(self, pop_nodes, pop_cons):
|
||||
n, c = pop_nodes.shape[1], pop_cons.shape[1]
|
||||
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
|
||||
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
|
||||
ts = self.create_topological_sort_batch(n)
|
||||
|
||||
# for connections with enabled is false, set weight to 0)
|
||||
pop_cal_seqs = ts(pop_nodes, pop_cons)
|
||||
# print(pop_cal_seqs)
|
||||
forward_func = self.create_pop_batch_forward(n)
|
||||
|
||||
def debug_forward(inputs):
|
||||
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_cons)
|
||||
|
||||
return debug_forward
|
||||
|
||||
def ask_batch_forward(self, nodes, connections):
|
||||
n = nodes.shape[0]
|
||||
ts = self.create_topological_sort(n)
|
||||
cal_seqs = ts(nodes, connections)
|
||||
forward_func = self.create_batch_forward(n)
|
||||
|
||||
def debug_forward(inputs):
|
||||
return forward_func(inputs, cal_seqs, nodes, connections)
|
||||
|
||||
return debug_forward
|
||||
|
||||
def compile_batch_unflatten_connections(self, n, c):
|
||||
|
||||
s = time.time()
|
||||
|
||||
func = unflatten_connections
|
||||
func = vmap(func)
|
||||
pop_nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
pop_connections_lower = np.zeros((self.pop_size, c, 4))
|
||||
func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile()
|
||||
self.compiled_function[('batch_unflatten_connections', n, c)] = func
|
||||
|
||||
self.compile_time += time.time() - s
|
||||
|
||||
def create_batch_unflatten_connections(self, n, c):
|
||||
key = ('batch_unflatten_connections', n, c)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_batch_unflatten_connections(n, c)
|
||||
|
||||
return self.compiled_function[key]
|
||||
# save for reuse
|
||||
self.func_dict[name, hash_symbols(symbols)] = compiled_func
|
||||
|
||||
@@ -104,11 +104,23 @@ def cube_act(z):
|
||||
return z ** 3
|
||||
|
||||
|
||||
@jit
|
||||
def act(idx, z):
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||
return jnp.where(jnp.isnan(res), jnp.nan, res)
|
||||
|
||||
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||
act_name2func = {
|
||||
'sigmoid': sigmoid_act,
|
||||
'tanh': tanh_act,
|
||||
'sin': sin_act,
|
||||
'gauss': gauss_act,
|
||||
'relu': relu_act,
|
||||
'elu': elu_act,
|
||||
'lelu': lelu_act,
|
||||
'selu': selu_act,
|
||||
'softplus': softplus_act,
|
||||
'identity': identity_act,
|
||||
'clamped': clamped_act,
|
||||
'inv': inv_act,
|
||||
'log': log_act,
|
||||
'exp': exp_act,
|
||||
'abs': abs_act,
|
||||
'hat': hat_act,
|
||||
'square': square_act,
|
||||
'cube': cube_act,
|
||||
}
|
||||
|
||||
@@ -1,9 +1,3 @@
|
||||
"""
|
||||
aggregations, two special case need to consider:
|
||||
1. extra 0s
|
||||
2. full of 0s
|
||||
"""
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
@@ -44,19 +38,13 @@ def maxabs_agg(z):
|
||||
|
||||
@jit
|
||||
def median_agg(z):
|
||||
non_zero_mask = ~jnp.isnan(z)
|
||||
n = jnp.sum(non_zero_mask, axis=0)
|
||||
non_nan_mask = ~jnp.isnan(z)
|
||||
n = jnp.sum(non_nan_mask, axis=0)
|
||||
|
||||
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||
sorted_valid_values = jnp.sort(z)
|
||||
z = jnp.sort(z) # sort
|
||||
|
||||
def _even_case():
|
||||
return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2
|
||||
|
||||
def _odd_case():
|
||||
return sorted_valid_values[n // 2]
|
||||
|
||||
median = jax.lax.cond(n % 2 == 0, _even_case, _odd_case)
|
||||
idx1, idx2 = (n - 1) // 2, n // 2
|
||||
median = (z[idx1] + z[idx2]) / 2
|
||||
|
||||
return median
|
||||
|
||||
@@ -70,25 +58,12 @@ def mean_agg(z):
|
||||
return mean_without_zeros
|
||||
|
||||
|
||||
@jit
|
||||
def agg(idx, z):
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
|
||||
def full_nan():
|
||||
return 0.
|
||||
|
||||
def not_full_nan():
|
||||
return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
|
||||
|
||||
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)
|
||||
for names in agg_name2key.keys():
|
||||
print(names, agg(agg_name2key[names], array))
|
||||
|
||||
array2 = jnp.asarray([0, 0, 0, 0], dtype=jnp.float32)
|
||||
for names in agg_name2key.keys():
|
||||
print(names, agg(agg_name2key[names], array2))
|
||||
agg_name2func = {
|
||||
'sum': sum_agg,
|
||||
'product': product_agg,
|
||||
'max': max_agg,
|
||||
'min': min_agg,
|
||||
'maxabs': maxabs_agg,
|
||||
'median': median_agg,
|
||||
'mean': mean_agg,
|
||||
}
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
from functools import partial
|
||||
"""
|
||||
Crossover two genomes to generate a new genome.
|
||||
The calculation method is the same as the crossover operation in NEAT-python.
|
||||
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
|
||||
"""
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jax import jit, vmap, Array
|
||||
from jax import jit, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
|
||||
@jit
|
||||
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
|
||||
-> Tuple[Array, Array]:
|
||||
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
@@ -23,7 +26,11 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2:
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
# make homologous genes align in nodes2 align with nodes1
|
||||
nodes2 = align_array(keys1, keys2, nodes2, 'node')
|
||||
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
|
||||
# crossover connections
|
||||
@@ -34,7 +41,6 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2:
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
# @partial(jit, static_argnames=['gene_type'])
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||
@@ -62,7 +68,6 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
return refactor_ar2
|
||||
|
||||
|
||||
# @jit
|
||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
"""
|
||||
crossover two genes
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
"""
|
||||
Crossover two genomes to generate a new genome.
|
||||
The calculation method is the same as the crossover operation in NEAT-python.
|
||||
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
|
||||
"""
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jax import jit, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
|
||||
@jit
|
||||
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
:param randkey:
|
||||
:param nodes1:
|
||||
:param cons1:
|
||||
:param nodes2:
|
||||
:param cons2:
|
||||
:return:
|
||||
"""
|
||||
randkey_1, randkey_2 = jax.random.split(randkey)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
# make homologous genes align in nodes2 align with nodes1
|
||||
nodes2 = align_array(keys1, keys2, nodes2, 'node')
|
||||
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
||||
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
|
||||
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||
make ar2 align with ar1.
|
||||
:param seq1:
|
||||
:param seq2:
|
||||
:param ar2:
|
||||
:param gene_type:
|
||||
:return:
|
||||
align means to intersect part of ar2 will be at the same position as ar1,
|
||||
non-intersect part of ar2 will be set to Nan
|
||||
"""
|
||||
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
|
||||
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
|
||||
|
||||
if gene_type == 'connection':
|
||||
mask = jnp.all(mask, axis=2)
|
||||
|
||||
intersect_mask = mask.any(axis=1)
|
||||
idx = jnp.arange(0, len(seq1))
|
||||
idx_fixed = jnp.dot(mask, idx)
|
||||
|
||||
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
|
||||
|
||||
return refactor_ar2
|
||||
|
||||
|
||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
"""
|
||||
crossover two genes
|
||||
:param rand_key:
|
||||
:param g1:
|
||||
:param g2:
|
||||
:return:
|
||||
only gene with the same key will be crossover, thus don't need to consider change key
|
||||
"""
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
@@ -1,6 +1,9 @@
|
||||
"""
|
||||
Calculate the distance between two genomes.
|
||||
The calculation method is the same as the distance calculation in NEAT-python.
|
||||
See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
@@ -9,26 +12,34 @@ from .utils import EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
@jit
|
||||
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, disjoint_coe: float = 1.,
|
||||
compatibility_coe: float = 0.5) -> Array:
|
||||
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
|
||||
"""
|
||||
Calculate the distance between two genomes.
|
||||
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
|
||||
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
|
||||
args:
|
||||
nodes1: Array(N, 5)
|
||||
cons1: Array(C, 4)
|
||||
nodes2: Array(N, 5)
|
||||
cons2: Array(C, 4)
|
||||
returns:
|
||||
distance: Array(, )
|
||||
"""
|
||||
|
||||
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
|
||||
|
||||
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
|
||||
nd = node_distance(nodes1, nodes2, jit_config) # node distance
|
||||
cd = connection_distance(cons1, cons2, jit_config) # connection distance
|
||||
return nd + cd
|
||||
|
||||
|
||||
@jit
|
||||
def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
|
||||
"""
|
||||
Calculate the distance between nodes of two genomes.
|
||||
"""
|
||||
# statistics nodes count of two genomes
|
||||
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||
|
||||
# align homologous nodes
|
||||
# this process is similar to np.intersect1d.
|
||||
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||
keys = nodes[:, 0]
|
||||
sorted_indices = jnp.argsort(keys, axis=0)
|
||||
@@ -36,19 +47,29 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
|
||||
# flag location of homologous nodes
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||
|
||||
# calculate the count of non_homologous of two genomes
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
nd = batch_homologous_node_distance(fr, sr)
|
||||
nd = jnp.where(jnp.isnan(nd), 0, nd)
|
||||
homologous_distance = jnp.sum(nd * intersect_mask)
|
||||
|
||||
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
# calculate the distance of homologous nodes
|
||||
hnd = vmap(homologous_node_distance)(fr, sr)
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
|
||||
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
|
||||
'compatibility_weight']
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
|
||||
|
||||
|
||||
@jit
|
||||
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
def connection_distance(cons1: Array, cons2: Array, jit_config: Dict):
|
||||
"""
|
||||
Calculate the distance between connections of two genomes.
|
||||
Similar process as node_distance.
|
||||
"""
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
|
||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||
@@ -64,37 +85,34 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
cd = batch_homologous_connection_distance(fr, sr)
|
||||
cd = jnp.where(jnp.isnan(cd), 0, cd)
|
||||
homologous_distance = jnp.sum(cd * intersect_mask)
|
||||
hcd = vmap(homologous_connection_distance)(fr, sr)
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
|
||||
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
||||
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
|
||||
'compatibility_weight']
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_homologous_node_distance(b_n1, b_n2):
|
||||
return homologous_node_distance(b_n1, b_n2)
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_homologous_connection_distance(b_c1, b_c2):
|
||||
return homologous_connection_distance(b_c1, b_c2)
|
||||
|
||||
|
||||
@jit
|
||||
def homologous_node_distance(n1, n2):
|
||||
def homologous_node_distance(n1: Array, n2: Array):
|
||||
"""
|
||||
Calculate the distance between two homologous nodes.
|
||||
"""
|
||||
d = 0
|
||||
d += jnp.abs(n1[1] - n2[1]) # bias
|
||||
d += jnp.abs(n1[2] - n2[2]) # response
|
||||
d += n1[3] != n2[3] # activation
|
||||
d += n1[4] != n2[4]
|
||||
d += n1[4] != n2[4] # aggregation
|
||||
return d
|
||||
|
||||
|
||||
@jit
|
||||
def homologous_connection_distance(c1, c2):
|
||||
def homologous_connection_distance(c1: Array, c2: Array):
|
||||
"""
|
||||
Calculate the distance between two homologous connections.
|
||||
"""
|
||||
d = 0
|
||||
d += jnp.abs(c1[2] - c2[2]) # weight
|
||||
d += c1[3] != c2[3] # enable
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
"""
|
||||
Calculate the distance between two genomes.
|
||||
The calculation method is the same as the distance calculation in NEAT-python.
|
||||
See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
from .utils import EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
@jit
|
||||
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
|
||||
"""
|
||||
Calculate the distance between two genomes.
|
||||
args:
|
||||
nodes1: Array(N, 5)
|
||||
cons1: Array(C, 4)
|
||||
nodes2: Array(N, 5)
|
||||
cons2: Array(C, 4)
|
||||
returns:
|
||||
distance: Array(, )
|
||||
"""
|
||||
nd = node_distance(nodes1, nodes2, jit_config) # node distance
|
||||
cd = connection_distance(cons1, cons2, jit_config) # connection distance
|
||||
return nd + cd
|
||||
|
||||
|
||||
@jit
|
||||
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
|
||||
"""
|
||||
Calculate the distance between nodes of two genomes.
|
||||
"""
|
||||
# statistics nodes count of two genomes
|
||||
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||
|
||||
# align homologous nodes
|
||||
# this process is similar to np.intersect1d.
|
||||
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||
keys = nodes[:, 0]
|
||||
sorted_indices = jnp.argsort(keys, axis=0)
|
||||
nodes = nodes[sorted_indices]
|
||||
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
|
||||
# flag location of homologous nodes
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||
|
||||
# calculate the count of non_homologous of two genomes
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
# calculate the distance of homologous nodes
|
||||
hnd = vmap(homologous_node_distance)(fr, sr)
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
|
||||
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
|
||||
'compatibility_weight']
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
|
||||
|
||||
|
||||
@jit
|
||||
def connection_distance(cons1: Array, cons2: Array, jit_config: Dict):
|
||||
"""
|
||||
Calculate the distance between connections of two genomes.
|
||||
Similar process as node_distance.
|
||||
"""
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
|
||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||
|
||||
cons = jnp.concatenate((cons1, cons2), axis=0)
|
||||
keys = cons[:, :2]
|
||||
sorted_indices = jnp.lexsort(keys.T[::-1])
|
||||
cons = cons[sorted_indices]
|
||||
cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end
|
||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||
|
||||
# both genome has such connection
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
hcd = vmap(homologous_connection_distance)(fr, sr)
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
|
||||
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
|
||||
'compatibility_weight']
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
|
||||
@jit
|
||||
def homologous_node_distance(n1: Array, n2: Array):
|
||||
"""
|
||||
Calculate the distance between two homologous nodes.
|
||||
"""
|
||||
d = 0
|
||||
d += jnp.abs(n1[1] - n2[1]) # bias
|
||||
d += jnp.abs(n1[2] - n2[2]) # response
|
||||
d += n1[3] != n2[3] # activation
|
||||
d += n1[4] != n2[4] # aggregation
|
||||
return d
|
||||
|
||||
|
||||
@jit
|
||||
def homologous_connection_distance(c1: Array, c2: Array):
|
||||
"""
|
||||
Calculate the distance between two homologous connections.
|
||||
"""
|
||||
d = 0
|
||||
d += jnp.abs(c1[2] - c2[2]) # weight
|
||||
d += c1[3] != c2[3] # enable
|
||||
return d
|
||||
@@ -2,47 +2,82 @@ import jax
|
||||
from jax import Array, numpy as jnp
|
||||
from jax import jit, vmap
|
||||
|
||||
from .aggregations import agg
|
||||
from .activations import act
|
||||
from .utils import I_INT
|
||||
|
||||
|
||||
# TODO: enabled information doesn't influence forward. That is wrong!
|
||||
@jit
|
||||
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
|
||||
input_idx: Array, output_idx: Array) -> Array:
|
||||
def create_forward(config):
|
||||
def act(idx, z):
|
||||
"""
|
||||
calculate activation function for each node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
res = jax.lax.switch(idx, config['activation_funcs'], z)
|
||||
return res
|
||||
|
||||
def agg(idx, z):
|
||||
"""
|
||||
calculate activation function for inputs of node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
|
||||
def all_nan():
|
||||
return 0.
|
||||
|
||||
def not_all_nan():
|
||||
return jax.lax.switch(idx, config['aggregation_funcs'], z)
|
||||
|
||||
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
|
||||
|
||||
def forward(inputs: Array, cal_seqs: Array, nodes: Array, cons: Array) -> Array:
|
||||
"""
|
||||
jax forward for single input shaped (input_num, )
|
||||
nodes, connections are single genome
|
||||
nodes, connections are a single genome
|
||||
|
||||
:argument inputs: (input_num, )
|
||||
:argument input_idx: (input_num, )
|
||||
:argument output_idx: (output_num, )
|
||||
:argument cal_seqs: (N, )
|
||||
:argument nodes: (N, 5)
|
||||
:argument connections: (2, N, N)
|
||||
|
||||
:return (output_num, )
|
||||
"""
|
||||
|
||||
input_idx = config['input_idx']
|
||||
output_idx = config['output_idx']
|
||||
|
||||
N = nodes.shape[0]
|
||||
ini_vals = jnp.full((N,), jnp.nan)
|
||||
ini_vals = ini_vals.at[input_idx].set(inputs)
|
||||
|
||||
def scan_body(carry, i):
|
||||
def hit():
|
||||
ins = carry * connections[0, :, i]
|
||||
z = agg(nodes[i, 4], ins)
|
||||
z = z * nodes[i, 2] + nodes[i, 1]
|
||||
z = act(nodes[i, 3], z)
|
||||
weights = jnp.where(jnp.isnan(cons[1, :, :]), jnp.nan, cons[0, :, :]) # enabled
|
||||
|
||||
new_vals = carry.at[i].set(z)
|
||||
return new_vals
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
return (idx < N) & (cal_seqs[idx] != I_INT)
|
||||
|
||||
def body_func(carry):
|
||||
values, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def hit():
|
||||
ins = values * weights[:, i]
|
||||
z = agg(nodes[i, 4], ins) # z = agg(ins)
|
||||
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
||||
z = act(nodes[i, 3], z) # z = act(z)
|
||||
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
def miss():
|
||||
return carry
|
||||
return values
|
||||
|
||||
return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
|
||||
return values, idx + 1
|
||||
|
||||
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
|
||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
return forward
|
||||
|
||||
@@ -44,10 +44,13 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
|
||||
pop_nodes[:, input_idx, 0] = input_idx
|
||||
pop_nodes[:, output_idx, 0] = output_idx
|
||||
|
||||
pop_nodes[:, output_idx, 1] = config['bias_init_mean']
|
||||
pop_nodes[:, output_idx, 2] = config['response_init_mean']
|
||||
pop_nodes[:, output_idx, 3] = config['activation_default']
|
||||
pop_nodes[:, output_idx, 4] = config['aggregation_default']
|
||||
# pop_nodes[:, output_idx, 1] = config['bias_init_mean']
|
||||
pop_nodes[:, output_idx, 1] = np.random.normal(loc=config['bias_init_mean'], scale=config['bias_init_std'],
|
||||
size=(config['pop_size'], 1))
|
||||
pop_nodes[:, output_idx, 2] = np.random.normal(loc=config['response_init_mean'], scale=config['response_init_std'],
|
||||
size=(config['pop_size'], 1))
|
||||
pop_nodes[:, output_idx, 3] = np.random.choice(config['activation_options'], size=(config['pop_size'], 1))
|
||||
pop_nodes[:, output_idx, 4] = np.random.choice(config['aggregation_options'], size=(config['pop_size'], 1))
|
||||
|
||||
grid_a, grid_b = np.meshgrid(input_idx, output_idx)
|
||||
grid_a, grid_b = grid_a.flatten(), grid_b.flatten()
|
||||
@@ -55,7 +58,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
|
||||
p = config['num_inputs'] * config['num_outputs']
|
||||
pop_cons[:, :p, 0] = grid_a
|
||||
pop_cons[:, :p, 1] = grid_b
|
||||
pop_cons[:, :p, 2] = config['weight_init_mean']
|
||||
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
|
||||
size=(config['pop_size'], p))
|
||||
pop_cons[:, :p, 3] = 1
|
||||
|
||||
return pop_nodes, pop_cons
|
||||
@@ -8,8 +8,7 @@ from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
# from .configs import fetch_first, I_INT
|
||||
from neat.genome.utils import fetch_first, I_INT
|
||||
from .utils import unflatten_connections
|
||||
from neat.genome.utils import fetch_first, I_INT, unflatten_connections
|
||||
|
||||
|
||||
@jit
|
||||
@@ -44,49 +43,32 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
|
||||
|
||||
topological_sort(nodes, connections) -> [0, 1, 2, 3]
|
||||
"""
|
||||
connections_enable = connections[1, :, :] == 1
|
||||
connections_enable = connections[1, :, :] == 1 # forward function. thus use enable
|
||||
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
|
||||
res = jnp.full(in_degree.shape, I_INT)
|
||||
idx = 0
|
||||
|
||||
def scan_body(carry, _):
|
||||
def cond_fun(carry):
|
||||
res_, idx_, in_degree_ = carry
|
||||
i = fetch_first(in_degree_ == 0.)
|
||||
return i != I_INT
|
||||
|
||||
def body_func(carry):
|
||||
res_, idx_, in_degree_ = carry
|
||||
i = fetch_first(in_degree_ == 0.)
|
||||
|
||||
def hit():
|
||||
# add to res and flag it is already in it
|
||||
new_res = res_.at[idx_].set(i)
|
||||
new_idx = idx_ + 1
|
||||
new_in_degree = in_degree_.at[i].set(-1)
|
||||
res_ = res_.at[idx_].set(i)
|
||||
in_degree_ = in_degree_.at[i].set(-1)
|
||||
|
||||
# decrease in_degree of all its children
|
||||
children = connections_enable[i, :]
|
||||
new_in_degree = jnp.where(children, new_in_degree - 1, new_in_degree)
|
||||
return new_res, new_idx, new_in_degree
|
||||
|
||||
def miss():
|
||||
return res_, idx_, in_degree_
|
||||
|
||||
return jax.lax.cond(i == I_INT, miss, hit), None
|
||||
|
||||
scan_res, _ = jax.lax.scan(scan_body, (res, idx, in_degree), None, length=in_degree.shape[0])
|
||||
res, _, _ = scan_res
|
||||
in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
|
||||
return res_, idx_ + 1, in_degree_
|
||||
|
||||
res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
|
||||
return res
|
||||
|
||||
|
||||
@jit
|
||||
@vmap
|
||||
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
"""
|
||||
batch version of topological_sort
|
||||
:param pop_nodes:
|
||||
:param pop_connections:
|
||||
:return:
|
||||
"""
|
||||
return topological_sort(pop_nodes, pop_connections)
|
||||
|
||||
|
||||
@jit
|
||||
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
|
||||
"""
|
||||
@@ -131,22 +113,26 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
||||
check_cycles(nodes, connections, 1, 0) -> False
|
||||
"""
|
||||
|
||||
connections = unflatten_connections(nodes, connections)
|
||||
|
||||
connections_enable = ~jnp.isnan(connections[0, :, :])
|
||||
|
||||
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
|
||||
nodes_visited = jnp.full(nodes.shape[0], False)
|
||||
nodes_visited = nodes_visited.at[to_idx].set(True)
|
||||
|
||||
def scan_body(visited, _):
|
||||
new_visited = jnp.dot(visited, connections_enable)
|
||||
new_visited = jnp.logical_or(visited, new_visited)
|
||||
return new_visited, None
|
||||
visited = jnp.full(nodes.shape[0], False)
|
||||
new_visited = visited.at[to_idx].set(True)
|
||||
|
||||
nodes_visited, _ = jax.lax.scan(scan_body, nodes_visited, None, length=nodes_visited.shape[0])
|
||||
def cond_func(carry):
|
||||
visited_, new_visited_ = carry
|
||||
end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited
|
||||
end_cond2 = new_visited_[from_idx] # the starting node has been visited
|
||||
return jnp.logical_not(end_cond1 | end_cond2)
|
||||
|
||||
return nodes_visited[from_idx]
|
||||
def body_func(carry):
|
||||
_, visited_ = carry
|
||||
new_visited_ = jnp.dot(visited_, connections_enable)
|
||||
new_visited_ = jnp.logical_or(visited_, new_visited_)
|
||||
return visited_, new_visited_
|
||||
|
||||
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
|
||||
return visited[from_idx]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,155 +1,64 @@
|
||||
from typing import Tuple
|
||||
"""
|
||||
Mutate a genome.
|
||||
The calculation method is the same as the mutation operation in NEAT-python.
|
||||
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
|
||||
"""
|
||||
from typing import Tuple, Dict
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from jax import numpy as jnp
|
||||
from jax import jit, vmap, Array
|
||||
from jax import jit, Array
|
||||
|
||||
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from .utils import fetch_random, fetch_first, I_INT
|
||||
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
|
||||
from .graph import check_cycles
|
||||
|
||||
|
||||
# TODO: Temporally delete single_structural_mutation, for i need to run it as soon as possible.
|
||||
@jit
|
||||
def mutate(rand_key: Array,
|
||||
nodes: Array,
|
||||
connections: Array,
|
||||
new_node_key: int,
|
||||
input_idx: Array,
|
||||
output_idx: Array,
|
||||
bias_mean: float = 0,
|
||||
bias_std: float = 1,
|
||||
bias_mutate_strength: float = 0.5,
|
||||
bias_mutate_rate: float = 0.7,
|
||||
bias_replace_rate: float = 0.1,
|
||||
response_mean: float = 1.,
|
||||
response_std: float = 0.,
|
||||
response_mutate_strength: float = 0.,
|
||||
response_mutate_rate: float = 0.,
|
||||
response_replace_rate: float = 0.,
|
||||
weight_mean: float = 0.,
|
||||
weight_std: float = 1.,
|
||||
weight_mutate_strength: float = 0.5,
|
||||
weight_mutate_rate: float = 0.7,
|
||||
weight_replace_rate: float = 0.1,
|
||||
act_default: int = 0,
|
||||
act_list: Array = None,
|
||||
act_replace_rate: float = 0.1,
|
||||
agg_default: int = 0,
|
||||
agg_list: Array = None,
|
||||
agg_replace_rate: float = 0.1,
|
||||
enabled_reverse_rate: float = 0.1,
|
||||
add_node_rate: float = 0.2,
|
||||
delete_node_rate: float = 0.2,
|
||||
add_connection_rate: float = 0.4,
|
||||
delete_connection_rate: float = 0.4,
|
||||
):
|
||||
def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict):
|
||||
"""
|
||||
:param output_idx:
|
||||
:param input_idx:
|
||||
:param agg_default:
|
||||
:param act_default:
|
||||
:param rand_key:
|
||||
:param nodes: (N, 5)
|
||||
:param connections: (2, N, N)
|
||||
:param new_node_key:
|
||||
:param bias_mean:
|
||||
:param bias_std:
|
||||
:param bias_mutate_strength:
|
||||
:param bias_mutate_rate:
|
||||
:param bias_replace_rate:
|
||||
:param response_mean:
|
||||
:param response_std:
|
||||
:param response_mutate_strength:
|
||||
:param response_mutate_rate:
|
||||
:param response_replace_rate:
|
||||
:param weight_mean:
|
||||
:param weight_std:
|
||||
:param weight_mutate_strength:
|
||||
:param weight_mutate_rate:
|
||||
:param weight_replace_rate:
|
||||
:param act_list:
|
||||
:param act_replace_rate:
|
||||
:param agg_list:
|
||||
:param agg_replace_rate:
|
||||
:param enabled_reverse_rate:
|
||||
:param add_node_rate:
|
||||
:param delete_node_rate:
|
||||
:param add_connection_rate:
|
||||
:param delete_connection_rate:
|
||||
:param jit_config:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def m_add_node(rk, n, c):
|
||||
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
|
||||
|
||||
def m_add_connection(rk, n, c):
|
||||
return mutate_add_connection(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_delete_node(rk, n, c):
|
||||
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_delete_connection(rk, n, c):
|
||||
return mutate_delete_connection(rk, n, c)
|
||||
|
||||
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
|
||||
|
||||
# structural mutations
|
||||
# mutate add node
|
||||
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
|
||||
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
|
||||
r = rand(r1)
|
||||
aux_nodes, aux_connections = mutate_add_node(r1, nodes, connections, new_node_key, jit_config)
|
||||
nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes)
|
||||
connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections)
|
||||
|
||||
# mutate add connection
|
||||
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
|
||||
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
|
||||
r = rand(r2)
|
||||
aux_nodes, aux_connections = mutate_add_connection(r3, nodes, connections, jit_config)
|
||||
nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes)
|
||||
connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections)
|
||||
|
||||
# mutate delete node
|
||||
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
|
||||
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
|
||||
r = rand(r3)
|
||||
aux_nodes, aux_connections = mutate_delete_node(r2, nodes, connections, jit_config)
|
||||
nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes)
|
||||
connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections)
|
||||
|
||||
# mutate delete connection
|
||||
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
|
||||
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
|
||||
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
|
||||
r = rand(r4)
|
||||
aux_nodes, aux_connections = mutate_delete_connection(r4, nodes, connections)
|
||||
nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes)
|
||||
connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections)
|
||||
|
||||
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
|
||||
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
|
||||
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
||||
weight_mean, weight_std, weight_mutate_strength,
|
||||
weight_mutate_rate, weight_replace_rate, act_list, act_replace_rate, agg_list,
|
||||
agg_replace_rate, enabled_reverse_rate)
|
||||
# value mutations
|
||||
nodes, connections = mutate_values(rand_key, nodes, connections, jit_config)
|
||||
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_values(rand_key: Array,
|
||||
nodes: Array,
|
||||
cons: Array,
|
||||
bias_mean: float = 0,
|
||||
bias_std: float = 1,
|
||||
bias_mutate_strength: float = 0.5,
|
||||
bias_mutate_rate: float = 0.7,
|
||||
bias_replace_rate: float = 0.1,
|
||||
response_mean: float = 1.,
|
||||
response_std: float = 0.,
|
||||
response_mutate_strength: float = 0.,
|
||||
response_mutate_rate: float = 0.,
|
||||
response_replace_rate: float = 0.,
|
||||
weight_mean: float = 0.,
|
||||
weight_std: float = 1.,
|
||||
weight_mutate_strength: float = 0.5,
|
||||
weight_mutate_rate: float = 0.7,
|
||||
weight_replace_rate: float = 0.1,
|
||||
act_list: Array = None,
|
||||
act_replace_rate: float = 0.1,
|
||||
agg_list: Array = None,
|
||||
agg_replace_rate: float = 0.1,
|
||||
enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]:
|
||||
def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Mutate values of nodes and connections.
|
||||
|
||||
@@ -157,56 +66,48 @@ def mutate_values(rand_key: Array,
|
||||
rand_key: A random key for generating random values.
|
||||
nodes: A 2D array representing nodes.
|
||||
cons: A 3D array representing connections.
|
||||
bias_mean: Mean of the bias values.
|
||||
bias_std: Standard deviation of the bias values.
|
||||
bias_mutate_strength: Strength of the bias mutation.
|
||||
bias_mutate_rate: Rate of the bias mutation.
|
||||
bias_replace_rate: Rate of the bias replacement.
|
||||
response_mean: Mean of the response values.
|
||||
response_std: Standard deviation of the response values.
|
||||
response_mutate_strength: Strength of the response mutation.
|
||||
response_mutate_rate: Rate of the response mutation.
|
||||
response_replace_rate: Rate of the response replacement.
|
||||
weight_mean: Mean of the weight values.
|
||||
weight_std: Standard deviation of the weight values.
|
||||
weight_mutate_strength: Strength of the weight mutation.
|
||||
weight_mutate_rate: Rate of the weight mutation.
|
||||
weight_replace_rate: Rate of the weight replacement.
|
||||
act_list: List of the activation function values.
|
||||
act_replace_rate: Rate of the activation function replacement.
|
||||
agg_list: List of the aggregation function values.
|
||||
agg_replace_rate: Rate of the aggregation function replacement.
|
||||
enabled_reverse_rate: Rate of reversing enabled state of connections.
|
||||
jit_config: A dict containing configuration for jit-able functions.
|
||||
|
||||
Returns:
|
||||
A tuple containing mutated nodes and connections.
|
||||
"""
|
||||
|
||||
k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6)
|
||||
bias_new = mutate_float_values(k1, nodes[:, 1], bias_mean, bias_std,
|
||||
bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
|
||||
response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
|
||||
response_mutate_strength, response_mutate_rate, response_replace_rate)
|
||||
weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std,
|
||||
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
|
||||
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
|
||||
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
|
||||
|
||||
# mutate enabled
|
||||
# bias
|
||||
bias_new = mutate_float_values(k1, nodes[:, 1], jit_config['bias_init_mean'], jit_config['bias_init_std'],
|
||||
jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'],
|
||||
jit_config['bias_replace_rate'])
|
||||
|
||||
# response
|
||||
response_new = mutate_float_values(k2, nodes[:, 2], jit_config['response_init_mean'],
|
||||
jit_config['response_init_std'], jit_config['response_mutate_power'],
|
||||
jit_config['response_mutate_rate'], jit_config['response_replace_rate'])
|
||||
|
||||
# weight
|
||||
weight_new = mutate_float_values(k3, cons[:, 2], jit_config['weight_init_mean'], jit_config['weight_init_std'],
|
||||
jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'],
|
||||
jit_config['weight_replace_rate'])
|
||||
|
||||
# activation
|
||||
act_new = mutate_int_values(k4, nodes[:, 3], jit_config['activation_options'],
|
||||
jit_config['activation_replace_rate'])
|
||||
|
||||
# aggregation
|
||||
agg_new = mutate_int_values(k5, nodes[:, 4], jit_config['aggregation_options'],
|
||||
jit_config['aggregation_replace_rate'])
|
||||
|
||||
# enabled
|
||||
r = jax.random.uniform(rand_key, cons[:, 3].shape)
|
||||
enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3])
|
||||
enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan)
|
||||
enabled_new = jnp.where(r < jit_config['enable_mutate_rate'], 1 - cons[:, 3], cons[:, 3])
|
||||
|
||||
# merge
|
||||
nodes = jnp.column_stack([nodes[:, 0], bias_new, response_new, act_new, agg_new])
|
||||
cons = jnp.column_stack([cons[:, 0], cons[:, 1], weight_new, enabled_new])
|
||||
|
||||
nodes = nodes.at[:, 1].set(bias_new)
|
||||
nodes = nodes.at[:, 2].set(response_new)
|
||||
nodes = nodes.at[:, 3].set(act_new)
|
||||
nodes = nodes.at[:, 4].set(agg_new)
|
||||
cons = cons.at[:, 2].set(weight_new)
|
||||
cons = cons.at[:, 3].set(enabled_new)
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
|
||||
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
|
||||
"""
|
||||
@@ -227,19 +128,26 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
|
||||
k1, k2, k3, rand_key = jax.random.split(rand_key, num=4)
|
||||
noise = jax.random.normal(k1, old_vals.shape) * mutate_strength
|
||||
replace = jax.random.normal(k2, old_vals.shape) * std + mean
|
||||
|
||||
r = jax.random.uniform(k3, old_vals.shape)
|
||||
|
||||
# default
|
||||
new_vals = old_vals
|
||||
|
||||
# r in [0, mutate_rate), mutate
|
||||
new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals)
|
||||
|
||||
# r in [mutate_rate, mutate_rate + replace_rate), replace
|
||||
new_vals = jnp.where(
|
||||
jnp.logical_and(mutate_rate < r, r < mutate_rate + replace_rate),
|
||||
replace,
|
||||
(mutate_rate < r) & (r < mutate_rate + replace_rate),
|
||||
replace + new_vals * 0.0, # in case of nan replace to values
|
||||
new_vals
|
||||
)
|
||||
|
||||
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
|
||||
return new_vals
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
|
||||
"""
|
||||
Mutate integer values (act, agg) of a given array.
|
||||
@@ -256,26 +164,20 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
|
||||
k1, k2, rand_key = jax.random.split(rand_key, num=3)
|
||||
replace_val = jax.random.choice(k1, val_list, old_vals.shape)
|
||||
r = jax.random.uniform(k2, old_vals.shape)
|
||||
new_vals = old_vals
|
||||
new_vals = jnp.where(r < replace_rate, replace_val, new_vals)
|
||||
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
|
||||
new_vals = jnp.where(r < replace_rate, replace_val + old_vals * 0.0, old_vals) # in case of nan replace to values
|
||||
|
||||
return new_vals
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
|
||||
default_bias: float = 0, default_response: float = 1,
|
||||
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
|
||||
jit_config: Dict) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly add a new node from splitting a connection.
|
||||
:param rand_key:
|
||||
:param new_node_key:
|
||||
:param nodes:
|
||||
:param cons:
|
||||
:param default_bias:
|
||||
:param default_response:
|
||||
:param default_act:
|
||||
:param default_agg:
|
||||
:param jit_config:
|
||||
:return:
|
||||
"""
|
||||
# randomly choose a connection
|
||||
@@ -287,12 +189,13 @@ def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: in
|
||||
def successful_add_node():
|
||||
# disable the connection
|
||||
new_nodes, new_cons = nodes, cons
|
||||
|
||||
# set enable to false
|
||||
new_cons = new_cons.at[idx, 3].set(False)
|
||||
|
||||
# add a new node
|
||||
new_nodes, new_cons = \
|
||||
add_node(new_nodes, new_cons, new_node_key,
|
||||
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
|
||||
new_nodes, new_cons = add_node(new_nodes, new_cons, new_node_key, bias=0, response=1,
|
||||
act=jit_config['activation_default'], agg=jit_config['aggregation_default'])
|
||||
|
||||
# add two new connections
|
||||
w = new_cons[idx, 2]
|
||||
@@ -306,21 +209,18 @@ def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: in
|
||||
return nodes, cons
|
||||
|
||||
|
||||
# TODO: Need we really need to delete a node?
|
||||
@jit
|
||||
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array,
|
||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||
# TODO: Do we really need to delete a node?
|
||||
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly delete a node. Input and output nodes are not allowed to be deleted.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param cons:
|
||||
:param input_keys:
|
||||
:param output_keys:
|
||||
:param jit_config:
|
||||
:return:
|
||||
"""
|
||||
# randomly choose a node
|
||||
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
|
||||
key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'],
|
||||
allow_input_keys=False, allow_output_keys=False)
|
||||
|
||||
def nothing():
|
||||
@@ -328,37 +228,34 @@ def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array,
|
||||
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx)
|
||||
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx)
|
||||
|
||||
# delete all connections
|
||||
aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis],
|
||||
aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None],
|
||||
jnp.nan, aux_cons)
|
||||
|
||||
return aux_nodes, aux_cons
|
||||
|
||||
nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
|
||||
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
|
||||
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
|
||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
|
||||
cycles are not allowed.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param cons:
|
||||
:param input_keys:
|
||||
:param output_keys:
|
||||
:param jit_config:
|
||||
:return:
|
||||
"""
|
||||
# randomly choose two nodes
|
||||
k1, k2 = jax.random.split(rand_key, num=2)
|
||||
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
|
||||
i_key, from_idx = choice_node_key(k1, nodes, jit_config['input_idx'], jit_config['output_idx'],
|
||||
allow_input_keys=True, allow_output_keys=True)
|
||||
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
|
||||
o_key, to_idx = choice_node_key(k2, nodes, jit_config['input_idx'], jit_config['output_idx'],
|
||||
allow_input_keys=False, allow_output_keys=True)
|
||||
|
||||
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
||||
@@ -375,15 +272,14 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
|
||||
return nodes, cons
|
||||
|
||||
is_already_exist = con_idx != I_INT
|
||||
unflattened = unflatten_connections(nodes, cons)
|
||||
is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx)
|
||||
|
||||
is_cycle = check_cycles(nodes, cons, from_idx, to_idx)
|
||||
|
||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
|
||||
"""
|
||||
Randomly delete a connection.
|
||||
@@ -406,7 +302,6 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
||||
def choice_node_key(rand_key: Array, nodes: Array,
|
||||
input_keys: Array, output_keys: Array,
|
||||
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
|
||||
@@ -435,7 +330,6 @@ def choice_node_key(rand_key: Array, nodes: Array,
|
||||
return key, idx
|
||||
|
||||
|
||||
@jit
|
||||
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
|
||||
"""
|
||||
Randomly choose a connection key from the given connections.
|
||||
@@ -452,6 +346,5 @@ def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[A
|
||||
return i_key, o_key, idx
|
||||
|
||||
|
||||
@jit
|
||||
def rand(rand_key):
|
||||
return jax.random.uniform(rand_key, ())
|
||||
|
||||
@@ -1,355 +0,0 @@
|
||||
"""
|
||||
Mutate a genome.
|
||||
The calculation method is the same as the mutation operation in NEAT-python.
|
||||
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
|
||||
"""
|
||||
from typing import Tuple, Dict
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import jit, Array
|
||||
|
||||
from .utils import fetch_random, fetch_first, I_INT
|
||||
from .genome_ import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
|
||||
from .graph import check_cycles
|
||||
|
||||
|
||||
@jit
|
||||
def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict):
|
||||
"""
|
||||
:param rand_key:
|
||||
:param nodes: (N, 5)
|
||||
:param connections: (2, N, N)
|
||||
:param new_node_key:
|
||||
:param jit_config:
|
||||
:return:
|
||||
"""
|
||||
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
|
||||
|
||||
# structural mutations
|
||||
# mutate add node
|
||||
r = rand(r1)
|
||||
aux_nodes, aux_connections = mutate_add_node(r1, nodes, connections, new_node_key, jit_config)
|
||||
nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes)
|
||||
connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections)
|
||||
|
||||
# mutate add connection
|
||||
r = rand(r2)
|
||||
aux_nodes, aux_connections = mutate_add_connection(r3, nodes, connections, jit_config)
|
||||
nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes)
|
||||
connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections)
|
||||
|
||||
# mutate delete node
|
||||
r = rand(r3)
|
||||
aux_nodes, aux_connections = mutate_delete_node(r2, nodes, connections, jit_config)
|
||||
nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes)
|
||||
connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections)
|
||||
|
||||
# mutate delete connection
|
||||
r = rand(r4)
|
||||
aux_nodes, aux_connections = mutate_delete_connection(r4, nodes, connections)
|
||||
nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes)
|
||||
connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections)
|
||||
|
||||
# value mutations
|
||||
nodes, connections = mutate_values(rand_key, nodes, connections, jit_config)
|
||||
|
||||
return nodes, connections
|
||||
|
||||
|
||||
def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Mutate values of nodes and connections.
|
||||
|
||||
Args:
|
||||
rand_key: A random key for generating random values.
|
||||
nodes: A 2D array representing nodes.
|
||||
cons: A 3D array representing connections.
|
||||
jit_config: A dict containing configuration for jit-able functions.
|
||||
|
||||
Returns:
|
||||
A tuple containing mutated nodes and connections.
|
||||
"""
|
||||
|
||||
k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6)
|
||||
|
||||
# bias
|
||||
bias_new = mutate_float_values(k1, nodes[:, 1], jit_config['bias_init_mean'], jit_config['bias_init_std'],
|
||||
jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'],
|
||||
jit_config['bias_replace_rate'])
|
||||
|
||||
# response
|
||||
response_new = mutate_float_values(k2, nodes[:, 2], jit_config['response_init_mean'],
|
||||
jit_config['response_init_std'], jit_config['response_mutate_power'],
|
||||
jit_config['response_mutate_rate'], jit_config['response_replace_rate'])
|
||||
|
||||
# weight
|
||||
weight_new = mutate_float_values(k3, cons[:, 2], jit_config['weight_init_mean'], jit_config['weight_init_std'],
|
||||
jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'],
|
||||
jit_config['weight_replace_rate'])
|
||||
|
||||
# activation
|
||||
act_new = mutate_int_values(k4, nodes[:, 3], jit_config['activation_options'],
|
||||
jit_config['activation_replace_rate'])
|
||||
|
||||
# aggregation
|
||||
agg_new = mutate_int_values(k5, nodes[:, 4], jit_config['aggregation_options'],
|
||||
jit_config['aggregation_replace_rate'])
|
||||
|
||||
# enabled
|
||||
r = jax.random.uniform(rand_key, cons[:, 3].shape)
|
||||
enabled_new = jnp.where(r < jit_config['enable_mutate_rate'], 1 - cons[:, 3], cons[:, 3])
|
||||
|
||||
# merge
|
||||
nodes = jnp.column_stack([nodes[:, 0], bias_new, response_new, act_new, agg_new])
|
||||
cons = jnp.column_stack([cons[:, 0], cons[:, 1], weight_new, enabled_new])
|
||||
|
||||
return nodes, cons
|
||||
|
||||
|
||||
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
|
||||
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
|
||||
"""
|
||||
Mutate float values of a given array.
|
||||
|
||||
Args:
|
||||
rand_key: A random key for generating random values.
|
||||
old_vals: A 1D array of float values to be mutated.
|
||||
mean: Mean of the values.
|
||||
std: Standard deviation of the values.
|
||||
mutate_strength: Strength of the mutation.
|
||||
mutate_rate: Rate of the mutation.
|
||||
replace_rate: Rate of the replacement.
|
||||
|
||||
Returns:
|
||||
A mutated 1D array of float values.
|
||||
"""
|
||||
k1, k2, k3, rand_key = jax.random.split(rand_key, num=4)
|
||||
noise = jax.random.normal(k1, old_vals.shape) * mutate_strength
|
||||
replace = jax.random.normal(k2, old_vals.shape) * std + mean
|
||||
|
||||
r = jax.random.uniform(k3, old_vals.shape)
|
||||
|
||||
# default
|
||||
new_vals = old_vals
|
||||
|
||||
# r in [0, mutate_rate), mutate
|
||||
new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals)
|
||||
|
||||
# r in [mutate_rate, mutate_rate + replace_rate), replace
|
||||
new_vals = jnp.where(
|
||||
(mutate_rate < r) & (r < mutate_rate + replace_rate),
|
||||
replace + new_vals * 0.0, # in case of nan replace to values
|
||||
new_vals
|
||||
)
|
||||
|
||||
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
|
||||
return new_vals
|
||||
|
||||
|
||||
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
|
||||
"""
|
||||
Mutate integer values (act, agg) of a given array.
|
||||
|
||||
Args:
|
||||
rand_key: A random key for generating random values.
|
||||
old_vals: A 1D array of integer values to be mutated.
|
||||
val_list: List of the integer values.
|
||||
replace_rate: Rate of the replacement.
|
||||
|
||||
Returns:
|
||||
A mutated 1D array of integer values.
|
||||
"""
|
||||
k1, k2, rand_key = jax.random.split(rand_key, num=3)
|
||||
replace_val = jax.random.choice(k1, val_list, old_vals.shape)
|
||||
r = jax.random.uniform(k2, old_vals.shape)
|
||||
new_vals = jnp.where(r < replace_rate, replace_val + old_vals * 0.0, old_vals) # in case of nan replace to values
|
||||
|
||||
return new_vals
|
||||
|
||||
|
||||
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
|
||||
jit_config: Dict) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly add a new node from splitting a connection.
|
||||
:param rand_key:
|
||||
:param new_node_key:
|
||||
:param nodes:
|
||||
:param cons:
|
||||
:param jit_config:
|
||||
:return:
|
||||
"""
|
||||
# randomly choose a connection
|
||||
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
|
||||
|
||||
def nothing(): # there is no connection to split
|
||||
return nodes, cons
|
||||
|
||||
def successful_add_node():
|
||||
# disable the connection
|
||||
new_nodes, new_cons = nodes, cons
|
||||
|
||||
# set enable to false
|
||||
new_cons = new_cons.at[idx, 3].set(False)
|
||||
|
||||
# add a new node
|
||||
new_nodes, new_cons = add_node(new_nodes, new_cons, new_node_key, bias=0, response=1,
|
||||
act=jit_config['activation_default'], agg=jit_config['aggregation_default'])
|
||||
|
||||
# add two new connections
|
||||
w = new_cons[idx, 2]
|
||||
new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True)
|
||||
new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True)
|
||||
return new_nodes, new_cons
|
||||
|
||||
# if from_idx == I_INT, that means no connection exist, do nothing
|
||||
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node)
|
||||
|
||||
return nodes, cons
|
||||
|
||||
|
||||
# TODO: Do we really need to delete a node?
|
||||
@jit
|
||||
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly delete a node. Input and output nodes are not allowed to be deleted.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param cons:
|
||||
:param jit_config:
|
||||
:return:
|
||||
"""
|
||||
# randomly choose a node
|
||||
key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'],
|
||||
allow_input_keys=False, allow_output_keys=False)
|
||||
|
||||
def nothing():
|
||||
return nodes, cons
|
||||
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx)
|
||||
|
||||
# delete all connections
|
||||
aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None],
|
||||
jnp.nan, aux_cons)
|
||||
|
||||
return aux_nodes, aux_cons
|
||||
|
||||
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
|
||||
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
|
||||
cycles are not allowed.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param cons:
|
||||
:param jit_config:
|
||||
:return:
|
||||
"""
|
||||
# randomly choose two nodes
|
||||
k1, k2 = jax.random.split(rand_key, num=2)
|
||||
i_key, from_idx = choice_node_key(k1, nodes, jit_config['input_idx'], jit_config['output_idx'],
|
||||
allow_input_keys=True, allow_output_keys=True)
|
||||
o_key, to_idx = choice_node_key(k2, nodes, jit_config['input_idx'], jit_config['output_idx'],
|
||||
allow_input_keys=False, allow_output_keys=True)
|
||||
|
||||
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
||||
|
||||
def successful():
|
||||
new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)
|
||||
return new_nodes, new_cons
|
||||
|
||||
def already_exist():
|
||||
new_cons = cons.at[con_idx, 3].set(True)
|
||||
return nodes, new_cons
|
||||
|
||||
def cycle():
|
||||
return nodes, cons
|
||||
|
||||
is_already_exist = con_idx != I_INT
|
||||
|
||||
is_cycle = check_cycles(nodes, cons, from_idx, to_idx)
|
||||
|
||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
|
||||
"""
|
||||
Randomly delete a connection.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param cons:
|
||||
:return:
|
||||
"""
|
||||
# randomly choose a connection
|
||||
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
|
||||
|
||||
def nothing():
|
||||
return nodes, cons
|
||||
|
||||
def successfully_delete_connection():
|
||||
return delete_connection_by_idx(nodes, cons, idx)
|
||||
|
||||
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
|
||||
|
||||
return nodes, cons
|
||||
|
||||
|
||||
def choice_node_key(rand_key: Array, nodes: Array,
|
||||
input_keys: Array, output_keys: Array,
|
||||
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param input_keys:
|
||||
:param output_keys:
|
||||
:param allow_input_keys:
|
||||
:param allow_output_keys:
|
||||
:return: return its key and position(idx)
|
||||
"""
|
||||
|
||||
node_keys = nodes[:, 0]
|
||||
mask = ~jnp.isnan(node_keys)
|
||||
|
||||
if not allow_input_keys:
|
||||
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys))
|
||||
|
||||
if not allow_output_keys:
|
||||
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys))
|
||||
|
||||
idx = fetch_random(rand_key, mask)
|
||||
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
|
||||
return key, idx
|
||||
|
||||
|
||||
@jit
|
||||
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
|
||||
"""
|
||||
Randomly choose a connection key from the given connections.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param cons:
|
||||
:return: i_key, o_key, idx
|
||||
"""
|
||||
|
||||
idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0]))
|
||||
i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan)
|
||||
o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan)
|
||||
|
||||
return i_key, o_key, idx
|
||||
|
||||
|
||||
@jit
|
||||
def rand(rand_key):
|
||||
return jax.random.uniform(rand_key, ())
|
||||
@@ -1,5 +1,4 @@
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, Array
|
||||
@@ -11,20 +10,18 @@ EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
||||
|
||||
|
||||
@jit
|
||||
def unflatten_connections(nodes, cons):
|
||||
def unflatten_connections(nodes: Array, cons: Array):
|
||||
"""
|
||||
transform the (C, 4) connections to (2, N, N)
|
||||
this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled
|
||||
connections to nan, that means we dont consider such connection when forward;
|
||||
:param cons:
|
||||
:param nodes:
|
||||
:param nodes: (N, 5)
|
||||
:param cons: (C, 4)
|
||||
:return:
|
||||
"""
|
||||
N = nodes.shape[0]
|
||||
node_keys = nodes[:, 0]
|
||||
i_keys, o_keys = cons[:, 0], cons[:, 1]
|
||||
i_idxs = key_to_indices(i_keys, node_keys)
|
||||
o_idxs = key_to_indices(o_keys, node_keys)
|
||||
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
|
||||
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
|
||||
res = jnp.full((2, N, N), jnp.nan)
|
||||
|
||||
# Is interesting that jax use clip when attach data in array
|
||||
@@ -34,8 +31,6 @@ def unflatten_connections(nodes, cons):
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@partial(vmap, in_axes=(0, None))
|
||||
def key_to_indices(key, keys):
|
||||
return fetch_first(key == keys)
|
||||
|
||||
@@ -46,27 +41,12 @@ def fetch_first(mask, default=I_INT) -> Array:
|
||||
fetch the first True index
|
||||
:param mask: array of bool
|
||||
:param default: the default value if no element satisfying the condition
|
||||
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT
|
||||
example:
|
||||
>>> a = jnp.array([1, 2, 3, 4, 5])
|
||||
>>> fetch_first(a > 3)
|
||||
3
|
||||
>>> fetch_first(a > 30)
|
||||
I_INT
|
||||
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value
|
||||
"""
|
||||
idx = jnp.argmax(mask)
|
||||
return jnp.where(mask[idx], idx, default)
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_last(mask, default=I_INT) -> Array:
|
||||
"""
|
||||
similar to fetch_first, but fetch the last True index
|
||||
"""
|
||||
reversed_idx = fetch_first(mask[::-1], default)
|
||||
return jnp.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1)
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
"""
|
||||
@@ -78,27 +58,8 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
a = jnp.array([1, 2, 3, 4, 5])
|
||||
print(fetch_first(a > 3))
|
||||
print(fetch_first(a > 30))
|
||||
|
||||
print(fetch_last(a > 3))
|
||||
print(fetch_last(a > 30))
|
||||
|
||||
rand_key = jax.random.PRNGKey(0)
|
||||
|
||||
for t in [-1, 0, 1, 2, 3, 4, 5]:
|
||||
for _ in range(10):
|
||||
rand_key, _ = jax.random.split(rand_key)
|
||||
print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
|
||||
print(t, fetch_random(rand_key, a > t))
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, Array
|
||||
from jax import jit, vmap
|
||||
|
||||
I_INT = jnp.iinfo(jnp.int32).max # infinite int
|
||||
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
|
||||
EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
||||
|
||||
|
||||
@jit
|
||||
def unflatten_connections(nodes: Array, cons: Array):
|
||||
"""
|
||||
transform the (C, 4) connections to (2, N, N)
|
||||
:param nodes: (N, 5)
|
||||
:param cons: (C, 4)
|
||||
:return:
|
||||
"""
|
||||
N = nodes.shape[0]
|
||||
node_keys = nodes[:, 0]
|
||||
i_keys, o_keys = cons[:, 0], cons[:, 1]
|
||||
i_idxs = vmap(fetch_first, in_axes=(0, None))(i_keys, node_keys)
|
||||
i_idxs = key_to_indices(i_keys, node_keys)
|
||||
o_idxs = key_to_indices(o_keys, node_keys)
|
||||
res = jnp.full((2, N, N), jnp.nan)
|
||||
|
||||
# Is interesting that jax use clip when attach data in array
|
||||
# however, it will do nothing set values in an array
|
||||
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
||||
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@partial(vmap, in_axes=(0, None))
|
||||
def key_to_indices(key, keys):
|
||||
return fetch_first(key == keys)
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_first(mask, default=I_INT) -> Array:
|
||||
"""
|
||||
fetch the first True index
|
||||
:param mask: array of bool
|
||||
:param default: the default value if no element satisfying the condition
|
||||
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT
|
||||
example:
|
||||
>>> a = jnp.array([1, 2, 3, 4, 5])
|
||||
>>> fetch_first(a > 3)
|
||||
3
|
||||
>>> fetch_first(a > 30)
|
||||
I_INT
|
||||
"""
|
||||
idx = jnp.argmax(mask)
|
||||
return jnp.where(mask[idx], idx, default)
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_last(mask, default=I_INT) -> Array:
|
||||
"""
|
||||
similar to fetch_first, but fetch the last True index
|
||||
"""
|
||||
reversed_idx = fetch_first(mask[::-1], default)
|
||||
return jnp.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1)
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
"""
|
||||
similar to fetch_first, but fetch a random True index
|
||||
"""
|
||||
true_cnt = jnp.sum(mask)
|
||||
cumsum = jnp.cumsum(mask)
|
||||
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
a = jnp.array([1, 2, 3, 4, 5])
|
||||
print(fetch_first(a > 3))
|
||||
print(fetch_first(a > 30))
|
||||
|
||||
print(fetch_last(a > 3))
|
||||
print(fetch_last(a > 30))
|
||||
|
||||
rand_key = jax.random.PRNGKey(0)
|
||||
|
||||
for t in [-1, 0, 1, 2, 3, 4, 5]:
|
||||
for _ in range(10):
|
||||
rand_key, _ = jax.random.split(rand_key)
|
||||
print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
|
||||
print(t, fetch_random(rand_key, a > t))
|
||||
180
neat/pipeline.py
180
neat/pipeline.py
@@ -1,158 +1,78 @@
|
||||
from typing import List, Union, Tuple, Callable
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import jax
|
||||
|
||||
from .species import SpeciesController
|
||||
from .genome import expand, expand_single
|
||||
from configs.configer import Configer
|
||||
from .genome.genome import initialize_genomes
|
||||
from .function_factory import FunctionFactory
|
||||
|
||||
from .population import *
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Neat algorithm pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config, function_factory, seed=42):
|
||||
self.time_dict = {}
|
||||
self.function_factory = function_factory
|
||||
|
||||
def __init__(self, config, function_factory=None, seed=42):
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
self.config = config
|
||||
self.N = config.basic.init_maximum_nodes
|
||||
self.C = config.basic.init_maximum_connections
|
||||
self.S = config.basic.init_maximum_species
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
self.config = config # global config
|
||||
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||
self.function_factory = function_factory or FunctionFactory(self.config)
|
||||
|
||||
self.species_controller = SpeciesController(config)
|
||||
self.initialize_func = self.function_factory.create_initialize(self.N, self.C)
|
||||
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func()
|
||||
|
||||
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
|
||||
self.symbols = {
|
||||
'P': self.config['pop_size'],
|
||||
'N': self.config['init_maximum_nodes'],
|
||||
'C': self.config['init_maximum_connections'],
|
||||
'S': self.config['init_maximum_species'],
|
||||
}
|
||||
|
||||
self.generation = 0
|
||||
self.generation_time_list = []
|
||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
|
||||
|
||||
self.best_fitness = float('-inf')
|
||||
self.best_genome = None
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
self.evaluate_time = 0
|
||||
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
|
||||
|
||||
|
||||
def ask(self):
|
||||
"""
|
||||
Create a forward function for the population.
|
||||
:return:
|
||||
Algorithm gives the population a forward function, then environment gives back the fitnesses.
|
||||
Creates a function that receives a genome and returns a forward function.
|
||||
There are 3 types of config['forward_way']: {'single', 'pop', 'common'}
|
||||
|
||||
single:
|
||||
Create pop_size number of forward functions.
|
||||
Each function receive (batch_size, input_size) and returns (batch_size, output_size)
|
||||
e.g. RL task
|
||||
|
||||
pop:
|
||||
Create a single forward function, which use only once calculation for the population.
|
||||
The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
||||
|
||||
common:
|
||||
Special case of pop. The population has the same inputs.
|
||||
The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
||||
e.g. numerical regression; Hyper-NEAT
|
||||
|
||||
"""
|
||||
return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_cons)
|
||||
u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons)
|
||||
pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons)
|
||||
|
||||
def tell(self, fitnesses):
|
||||
if self.config['forward_way'] == 'single':
|
||||
forward_funcs = []
|
||||
for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons):
|
||||
func = lambda x: self.get_func('forward')(x, seq, nodes, cons)
|
||||
forward_funcs.append(func)
|
||||
return forward_funcs
|
||||
|
||||
self.generation += 1
|
||||
elif self.config['forward_way'] == 'pop':
|
||||
func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
return func
|
||||
|
||||
winner_part, loser_part, elite_mask, pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start = self.species_controller.ask(
|
||||
fitnesses,
|
||||
self.generation,
|
||||
self.S, self.N, self.C)
|
||||
elif self.config['forward_way'] == 'common':
|
||||
func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
return func
|
||||
|
||||
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
|
||||
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = self.create_and_speciate(
|
||||
self.randkey, self.pop_nodes, self.pop_cons, winner_part, loser_part, elite_mask,
|
||||
new_node_keys,
|
||||
pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start)
|
||||
|
||||
|
||||
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = \
|
||||
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys])
|
||||
|
||||
self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation)
|
||||
|
||||
self.expand()
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config.neat.population.generation_limit):
|
||||
forward_func = self.ask()
|
||||
|
||||
tic = time.time()
|
||||
fitnesses = fitness_func(forward_func)
|
||||
self.evaluate_time += time.time() - tic
|
||||
|
||||
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||
|
||||
if analysis is not None:
|
||||
if analysis == "default":
|
||||
self.default_analysis(fitnesses)
|
||||
else:
|
||||
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
||||
analysis(fitnesses)
|
||||
|
||||
if max(fitnesses) >= self.config.neat.population.fitness_threshold:
|
||||
print("Fitness limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
self.tell(fitnesses)
|
||||
print("Generation limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
def expand(self):
|
||||
"""
|
||||
Expand the population if needed.
|
||||
:return:
|
||||
when the maximum node number of the population >= N
|
||||
the population will expand
|
||||
"""
|
||||
pop_node_keys = self.pop_nodes[:, :, 0]
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
|
||||
max_node_size = np.max(pop_node_sizes)
|
||||
if max_node_size >= self.N:
|
||||
self.N = int(self.N * self.expand_coe)
|
||||
# self.C = int(self.C * self.expand_coe)
|
||||
print(f"node expand to {self.N}!")
|
||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
|
||||
|
||||
# don't forget to expand representation genome in species
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.N, self.C)
|
||||
|
||||
|
||||
pop_con_keys = self.pop_cons[:, :, 0]
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||
max_con_size = np.max(pop_node_sizes)
|
||||
if max_con_size >= self.C:
|
||||
# self.N = int(self.N * self.expand_coe)
|
||||
self.C = int(self.C * self.expand_coe)
|
||||
print(f"connections expand to {self.C}!")
|
||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
|
||||
|
||||
# don't forget to expand representation genome in species
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.N, self.C)
|
||||
|
||||
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
|
||||
|
||||
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
|
||||
|
||||
new_timestamp = time.time()
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
self.generation_time_list.append(cost_time)
|
||||
self.generation_timestamp = new_timestamp
|
||||
|
||||
max_idx = np.argmax(fitnesses)
|
||||
if fitnesses[max_idx] > self.best_fitness:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
|
||||
|
||||
print(f"Generation: {self.generation}",
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||
raise NotImplementedError
|
||||
def get_func(self, name):
|
||||
return self.function_factory.get(name, self.symbols)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import jax
|
||||
|
||||
from configs.configer import Configer
|
||||
from .genome.genome_ import initialize_genomes
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Neat algorithm pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config, seed=42):
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
|
||||
self.config = config # global config
|
||||
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||
self.N = self.config["init_maximum_nodes"]
|
||||
self.C = self.config["init_maximum_connections"]
|
||||
self.S = self.config["init_maximum_species"]
|
||||
|
||||
self.generation = 0
|
||||
self.best_genome = None
|
||||
|
||||
self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config)
|
||||
|
||||
print(self.pop_nodes, self.pop_cons, sep='\n')
|
||||
print(self.jit_config)
|
||||
Reference in New Issue
Block a user