finish ask part of the algorithm;

use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
This commit is contained in:
wls2002
2023-06-25 00:26:52 +08:00
parent 86820db5a6
commit 0cb2f9473d
24 changed files with 485 additions and 1623 deletions

View File

@@ -1,32 +0,0 @@
from neat.genome.activations import *
# All activation functions known to the configuration layer.
# The position of each function in this list equals the integer stored for its
# name in act_name2key (sigmoid_act -> 0, ..., cube_act -> 17).
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
                  identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
# Maps an activation name to its integer key. Keys are consecutive integers
# assigned in the order the names are listed below, starting at 0.
act_name2key = {
    name: key
    for key, name in enumerate((
        'sigmoid',
        'tanh',
        'sin',
        'gauss',
        'relu',
        'elu',
        'lelu',
        'selu',
        'softplus',
        'identity',
        'clamped',
        'inv',
        'log',
        'exp',
        'abs',
        'hat',
        'square',
        'cube',
    ))
}
def refactor_act(config):
    """Replace activation names in ``config`` with their integer keys, in place.

    ``config['activation_default']`` (a name) becomes a single key and
    ``config['activation_options']`` (an iterable of names) becomes a list of
    keys, both looked up in ``act_name2key``.

    Raises:
        KeyError: if a name is not present in ``act_name2key``.
    """
    # The default is rewritten first, then the options, one name at a time.
    config['activation_default'] = act_name2key[config['activation_default']]
    option_keys = []
    for name in config['activation_options']:
        option_keys.append(act_name2key[name])
    config['activation_options'] = option_keys

View File

@@ -1,20 +0,0 @@
from neat.genome.aggregations import *
# All aggregation functions known to the configuration layer.
# The position of each function in this list equals the integer stored for its
# name in agg_name2key (sum_agg -> 0, ..., mean_agg -> 6).
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
# Maps an aggregation name to its integer key. Keys are consecutive integers
# assigned in the order the names are listed below, starting at 0.
agg_name2key = {
    name: key
    for key, name in enumerate((
        'sum',
        'product',
        'max',
        'min',
        'maxabs',
        'median',
        'mean',
    ))
}
def refactor_agg(config):
    """Replace aggregation names in ``config`` with their integer keys, in place.

    ``config['aggregation_default']`` (a name) becomes a single key and
    ``config['aggregation_options']`` (an iterable of names) becomes a list of
    keys, both looked up in ``agg_name2key``.

    Raises:
        KeyError: if a name is not present in ``agg_name2key``.
    """
    # The default is rewritten first, then the options, one name at a time.
    config['aggregation_default'] = agg_name2key[config['aggregation_default']]
    option_keys = []
    for name in config['aggregation_options']:
        option_keys.append(agg_name2key[name])
    config['aggregation_options'] = option_keys

View File

@@ -4,8 +4,8 @@ import configparser
import numpy as np
from .activations import refactor_act
from .aggregations import refactor_agg
from neat.genome.activations import act_name2func
from neat.genome.aggregations import agg_name2func
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
jit_config_keys = [
@@ -20,12 +20,12 @@ jit_config_keys = [
"node_delete_prob",
"compatibility_threshold",
"bias_init_mean",
"bias_init_stdev",
"bias_init_std",
"bias_mutate_power",
"bias_mutate_rate",
"bias_replace_rate",
"response_init_mean",
"response_init_stdev",
"response_init_std",
"response_mutate_power",
"response_mutate_rate",
"response_replace_rate",
@@ -36,7 +36,7 @@ jit_config_keys = [
"aggregation_options",
"aggregation_replace_rate",
"weight_init_mean",
"weight_init_stdev",
"weight_init_std",
"weight_mutate_power",
"weight_mutate_rate",
"weight_replace_rate",
@@ -90,14 +90,26 @@ class Configer:
cls.__check_redundant_config(default_config, config)
cls.__complete_config(default_config, config)
refactor_act(config)
refactor_agg(config)
input_idx = np.arange(config['num_inputs'])
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
config['input_idx'] = input_idx
config['output_idx'] = output_idx
cls.refactor_activation(config)
cls.refactor_aggregation(config)
config['input_idx'] = np.arange(config['num_inputs'])
config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
return config
@classmethod
def refactor_activation(cls, config):
config['activation_default'] = 0
config['activation_options'] = np.arange(len(config['activation_option_names']))
config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
@classmethod
def refactor_aggregation(cls, config):
config['aggregation_default'] = 0
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]
@classmethod
def create_jit_config(cls, config):
jit_config = {k: config[k] for k in jit_config_keys}

View File

@@ -5,7 +5,8 @@ init_maximum_nodes = 20
init_maximum_connections = 20
init_maximum_species = 10
expands_coe = 2.0
forward_way = "pop_batch"
forward_way = "pop"
batch_size = 4
[population]
fitness_threshold = 100000
@@ -46,12 +47,12 @@ response_replace_rate = 0.0
[gene-activation]
activation_default = "sigmoid"
activation_options = ["sigmoid"]
activation_option_names = ["sigmoid"]
activation_replace_rate = 0.0
[gene-aggregation]
aggregation_default = "sum"
aggregation_options = ["sum"]
aggregation_option_names = ["sum"]
aggregation_replace_rate = 0.0
[gene-weight]

View File

@@ -3,6 +3,9 @@ import numpy as np
import jax.numpy as jnp
import jax
a = {1:2, 2:3, 4:5}
print(a.values())
a = jnp.array([1, 0, 1, 0, np.nan])
b = jnp.array([1, 1, 1, 1, 1])
c = jnp.array([1, 1, 1, 1, 1])
@@ -44,5 +47,9 @@ def func(x):
else:
return 2
a = jnp.zeros((3, 3))
print(a.dtype)
print(main())
c = None
b = 1 or c
print(b)

View File

@@ -1,16 +1,25 @@
from functools import partial
import numpy as np
import jax
from jax import jit
from configs import Configer
from neat.pipeline_ import Pipeline
from neat.pipeline import Pipeline
from neat.function_factory import FunctionFactory
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def main():
config = Configer.load_config("xor.ini")
print(config)
pipeline = Pipeline(config)
function_factory = FunctionFactory(config)
pipeline = Pipeline(config, function_factory)
forward_func = pipeline.ask()
# inputs = np.tile(xor_inputs, (150, 1, 1))
outputs = forward_func(xor_inputs)
print(outputs)
@jit

View File

@@ -1,28 +0,0 @@
import cProfile
import functools
import pstats
from io import StringIO
def using_cprofile(func, root_abs_path=None, replace_pattern=None, save_path=None):
    """Wrap ``func`` so that every call is profiled with ``cProfile``.

    Args:
        func: the callable to profile.
        root_abs_path: if not None, passed to ``pstats.Stats.print_stats`` as a
            restriction on which entries are listed.
        replace_pattern: if not None, every occurrence of this substring is
            removed from the report text.
        save_path: if None the report is printed to stdout, otherwise it is
            written to this file path (overwriting it).

    Returns:
        A wrapper with the same call signature as ``func``. It returns whatever
        ``func`` returns; the report is sorted by cumulative time.

    If ``func`` raises, the exception propagates and no report is produced.
    """

    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def inner(*args, **kwargs):
        pr = cProfile.Profile()
        pr.enable()
        try:
            ret = func(*args, **kwargs)
        finally:
            # Always release the profiler, even when func raises; a profiler
            # left enabled would affect every later profiled call.
            pr.disable()

        profile_stats = StringIO()
        stats = pstats.Stats(pr, stream=profile_stats).sort_stats('cumulative')
        if root_abs_path is not None:
            stats.print_stats(root_abs_path)
        else:
            stats.print_stats()

        output = profile_stats.getvalue()
        if replace_pattern is not None:
            output = output.replace(replace_pattern, "")

        if save_path is None:
            print(output)
        else:
            with open(save_path, "w") as f:
                f.write(output)
        return ret

    return inner

View File

@@ -1,2 +1,5 @@
[basic]
forward_way = "common"
[population]
fitness_threshold = -1e-2

View File

@@ -1,323 +1,108 @@
"""
Lowers, compiles, and creates functions used in the NEAT pipeline.
"""
from functools import partial
import time
import numpy as np
from jax import jit, vmap
from .genome import act_name2key, agg_name2key, initialize_genomes
from .genome import topological_sort, forward_single, unflatten_connections
from .population import create_next_generation_then_speciate
from .genome.forward import create_forward
from .genome.utils import unflatten_connections
from .genome.graph import topological_sort
def hash_symbols(symbols):
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
class FunctionFactory:
"""
Creates and compiles functions used in the NEAT pipeline.
"""
def __init__(self, config):
self.config = config
self.func_dict = {}
self.function_info = {}
self.expand_coe = config.basic.expands_coe
self.precompile_times = config.basic.pre_compile_times
self.compiled_function = {}
self.compile_time = 0
# (inputs_nums, ) -> (outputs_nums, )
forward = create_forward(config) # input size (inputs_nums, )
self.load_config_vals(config)
# (batch_size, inputs_nums) -> (batch_size, outputs_nums)
batch_forward = vmap(forward, in_axes=(0, None, None, None))
self.create_topological_sort_with_args()
self.create_single_forward_with_args()
self.create_update_speciate_with_args()
# (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
def load_config_vals(self, config):
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
self.problem_batch = config.basic.problem_batch
self.pop_size = config.neat.population.pop_size
self.function_info = {
"pop_unflatten_connections": {
'func': vmap(unflatten_connections),
'lowers': [
{'shape': ('P', 'N', 5), 'type': np.float32},
{'shape': ('P', 'C', 4), 'type': np.float32}
]
},
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
"pop_topological_sort": {
'func': vmap(topological_sort),
'lowers': [
{'shape': ('P', 'N', 5), 'type': np.float32},
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32},
]
},
self.num_inputs = config.basic.num_inputs
self.num_outputs = config.basic.num_outputs
self.input_idx = np.arange(self.num_inputs)
self.output_idx = np.arange(self.num_inputs, self.num_inputs + self.num_outputs)
"batch_forward": {
'func': batch_forward,
'lowers': [
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
{'shape': ('N', ), 'type': np.int32},
{'shape': ('N', 5), 'type': np.float32},
{'shape': (2, 'N', 'N'), 'type': np.float32}
]
},
bias = config.neat.gene.bias
self.bias_mean = bias.init_mean
self.bias_std = bias.init_stdev
self.bias_mutate_strength = bias.mutate_power
self.bias_mutate_rate = bias.mutate_rate
self.bias_replace_rate = bias.replace_rate
"pop_batch_forward": {
'func': pop_batch_forward,
'lowers': [
{'shape': ('P', config['batch_size'], config['num_inputs']), 'type': np.float32},
{'shape': ('P', 'N'), 'type': np.int32},
{'shape': ('P', 'N', 5), 'type': np.float32},
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
]
},
response = config.neat.gene.response
self.response_mean = response.init_mean
self.response_std = response.init_stdev
self.response_mutate_strength = response.mutate_power
self.response_mutate_rate = response.mutate_rate
self.response_replace_rate = response.replace_rate
weight = config.neat.gene.weight
self.weight_mean = weight.init_mean
self.weight_std = weight.init_stdev
self.weight_mutate_strength = weight.mutate_power
self.weight_mutate_rate = weight.mutate_rate
self.weight_replace_rate = weight.replace_rate
activation = config.neat.gene.activation
self.act_default = act_name2key[activation.default]
self.act_list = np.array([act_name2key[name] for name in activation.options])
self.act_replace_rate = activation.mutate_rate
aggregation = config.neat.gene.aggregation
self.agg_default = agg_name2key[aggregation.default]
self.agg_list = np.array([agg_name2key[name] for name in aggregation.options])
self.agg_replace_rate = aggregation.mutate_rate
enabled = config.neat.gene.enabled
self.enabled_reverse_rate = enabled.mutate_rate
genome = config.neat.genome
self.add_node_rate = genome.node_add_prob
self.delete_node_rate = genome.node_delete_prob
self.add_connection_rate = genome.conn_add_prob
self.delete_connection_rate = genome.conn_delete_prob
self.single_structure_mutate = genome.single_structural_mutation
def create_initialize(self, N, C):
func = partial(
initialize_genomes,
pop_size=self.pop_size,
N=N,
C=C,
num_inputs=self.num_inputs,
num_outputs=self.num_outputs,
default_bias=self.bias_mean,
default_response=self.response_mean,
default_act=self.act_default,
default_agg=self.agg_default,
default_weight=self.weight_mean
)
return func
def create_update_speciate_with_args(self):
species_kwargs = {
"disjoint_coe": self.disjoint_coe,
"compatibility_coe": self.compatibility_coe,
"compatibility_threshold": self.compatibility_threshold
'common_forward': {
'func': common_forward,
'lowers': [
{'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
{'shape': ('P', 'N'), 'type': np.int32},
{'shape': ('P', 'N', 5), 'type': np.float32},
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
]
}
}
mutate_kwargs = {
"input_idx": self.input_idx,
"output_idx": self.output_idx,
"bias_mean": self.bias_mean,
"bias_std": self.bias_std,
"bias_mutate_strength": self.bias_mutate_strength,
"bias_mutate_rate": self.bias_mutate_rate,
"bias_replace_rate": self.bias_replace_rate,
"response_mean": self.response_mean,
"response_std": self.response_std,
"response_mutate_strength": self.response_mutate_strength,
"response_mutate_rate": self.response_mutate_rate,
"response_replace_rate": self.response_replace_rate,
"weight_mean": self.weight_mean,
"weight_std": self.weight_std,
"weight_mutate_strength": self.weight_mutate_strength,
"weight_mutate_rate": self.weight_mutate_rate,
"weight_replace_rate": self.weight_replace_rate,
"act_default": self.act_default,
"act_list": self.act_list,
"act_replace_rate": self.act_replace_rate,
"agg_default": self.agg_default,
"agg_list": self.agg_list,
"agg_replace_rate": self.agg_replace_rate,
"enabled_reverse_rate": self.enabled_reverse_rate,
"add_node_rate": self.add_node_rate,
"delete_node_rate": self.delete_node_rate,
"add_connection_rate": self.add_connection_rate,
"delete_connection_rate": self.delete_connection_rate,
}
self.update_speciate_with_args = partial(
create_next_generation_then_speciate,
species_kwargs=species_kwargs,
mutate_kwargs=mutate_kwargs
)
def get(self, name, symbols):
if (name, hash_symbols(symbols)) not in self.func_dict:
self.compile(name, symbols)
return self.func_dict[name, hash_symbols(symbols)]
def create_update_speciate(self, N, C, S):
key = ("update_speciate", N, C, S)
if key not in self.compiled_function:
self.compile_update_speciate(N, C, S)
return self.compiled_function[key]
def compile(self, name, symbols):
# prepare function prototype
func = self.function_info[name]['func']
def compile_update_speciate(self, N, C, S):
s = time.time()
# prepare lower operands
lowers_operands = []
for lower in self.function_info[name]['lowers']:
shape = list(lower['shape'])
for i, s in enumerate(shape):
if s in symbols:
shape[i] = symbols[s]
assert isinstance(shape[i], int)
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
func = self.update_speciate_with_args
randkey_lower = np.zeros((2,), dtype=np.uint32)
pop_nodes_lower = np.zeros((self.pop_size, N, 5))
pop_cons_lower = np.zeros((self.pop_size, C, 4))
winner_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
loser_part_lower = np.zeros((self.pop_size,), dtype=np.int32)
elite_mask_lower = np.zeros((self.pop_size,), dtype=bool)
new_node_keys_start_lower = np.zeros((self.pop_size,), dtype=np.int32)
pre_spe_center_nodes_lower = np.zeros((S, N, 5))
pre_spe_center_cons_lower = np.zeros((S, C, 4))
species_keys = np.zeros((S,), dtype=np.int32)
new_species_keys_lower = 0
compiled_func = jit(func).lower(
randkey_lower,
pop_nodes_lower,
pop_cons_lower,
winner_part_lower,
loser_part_lower,
elite_mask_lower,
new_node_keys_start_lower,
pre_spe_center_nodes_lower,
pre_spe_center_cons_lower,
species_keys,
new_species_keys_lower,
).compile()
self.compiled_function[("update_speciate", N, C, S)] = compiled_func
# compile
compiled_func = jit(func).lower(*lowers_operands).compile()
self.compile_time += time.time() - s
def create_topological_sort_with_args(self):
self.topological_sort_with_args = topological_sort
def compile_topological_sort(self, n):
s = time.time()
func = self.topological_sort_with_args
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(nodes_lower, connections_lower).compile()
self.compiled_function[('topological_sort', n)] = func
self.compile_time += time.time() - s
def create_topological_sort(self, n):
key = ('topological_sort', n)
if key not in self.compiled_function:
self.compile_topological_sort(n)
return self.compiled_function[key]
def compile_topological_sort_batch(self, n):
s = time.time()
func = self.topological_sort_with_args
func = vmap(func)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(nodes_lower, connections_lower).compile()
self.compiled_function[('topological_sort_batch', n)] = func
self.compile_time += time.time() - s
def create_topological_sort_batch(self, n):
key = ('topological_sort_batch', n)
if key not in self.compiled_function:
self.compile_topological_sort_batch(n)
return self.compiled_function[key]
def create_single_forward_with_args(self):
func = partial(
forward_single,
input_idx=self.input_idx,
output_idx=self.output_idx
)
self.single_forward_with_args = func
def compile_batch_forward(self, n):
s = time.time()
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None))
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('batch_forward', n)] = func
self.compile_time += time.time() - s
def create_batch_forward(self, n):
key = ('batch_forward', n)
if key not in self.compiled_function:
self.compile_batch_forward(n)
return self.compiled_function[key]
def compile_pop_batch_forward(self, n):
s = time.time()
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('pop_batch_forward', n)] = func
self.compile_time += time.time() - s
def create_pop_batch_forward(self, n):
key = ('pop_batch_forward', n)
if key not in self.compiled_function:
self.compile_pop_batch_forward(n)
return self.compiled_function[key]
def ask_pop_batch_forward(self, pop_nodes, pop_cons):
n, c = pop_nodes.shape[1], pop_cons.shape[1]
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
ts = self.create_topological_sort_batch(n)
# for connections with enabled is false, set weight to 0)
pop_cal_seqs = ts(pop_nodes, pop_cons)
# print(pop_cal_seqs)
forward_func = self.create_pop_batch_forward(n)
def debug_forward(inputs):
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_cons)
return debug_forward
def ask_batch_forward(self, nodes, connections):
n = nodes.shape[0]
ts = self.create_topological_sort(n)
cal_seqs = ts(nodes, connections)
forward_func = self.create_batch_forward(n)
def debug_forward(inputs):
return forward_func(inputs, cal_seqs, nodes, connections)
return debug_forward
def compile_batch_unflatten_connections(self, n, c):
s = time.time()
func = unflatten_connections
func = vmap(func)
pop_nodes_lower = np.zeros((self.pop_size, n, 5))
pop_connections_lower = np.zeros((self.pop_size, c, 4))
func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile()
self.compiled_function[('batch_unflatten_connections', n, c)] = func
self.compile_time += time.time() - s
def create_batch_unflatten_connections(self, n, c):
key = ('batch_unflatten_connections', n, c)
if key not in self.compiled_function:
self.compile_batch_unflatten_connections(n, c)
return self.compiled_function[key]
# save for reuse
self.func_dict[name, hash_symbols(symbols)] = compiled_func

View File

@@ -104,11 +104,23 @@ def cube_act(z):
return z ** 3
@jit
def act(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
return jnp.where(jnp.isnan(res), jnp.nan, res)
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
act_name2func = {
'sigmoid': sigmoid_act,
'tanh': tanh_act,
'sin': sin_act,
'gauss': gauss_act,
'relu': relu_act,
'elu': elu_act,
'lelu': lelu_act,
'selu': selu_act,
'softplus': softplus_act,
'identity': identity_act,
'clamped': clamped_act,
'inv': inv_act,
'log': log_act,
'exp': exp_act,
'abs': abs_act,
'hat': hat_act,
'square': square_act,
'cube': cube_act,
}

View File

@@ -1,9 +1,3 @@
"""
aggregations, two special case need to consider:
1. extra 0s
2. full of 0s
"""
import jax
import jax.numpy as jnp
import numpy as np
@@ -44,19 +38,13 @@ def maxabs_agg(z):
@jit
def median_agg(z):
non_zero_mask = ~jnp.isnan(z)
n = jnp.sum(non_zero_mask, axis=0)
non_nan_mask = ~jnp.isnan(z)
n = jnp.sum(non_nan_mask, axis=0)
z = jnp.where(jnp.isnan(z), jnp.inf, z)
sorted_valid_values = jnp.sort(z)
z = jnp.sort(z) # sort
def _even_case():
return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2
def _odd_case():
return sorted_valid_values[n // 2]
median = jax.lax.cond(n % 2 == 0, _even_case, _odd_case)
idx1, idx2 = (n - 1) // 2, n // 2
median = (z[idx1] + z[idx2]) / 2
return median
@@ -70,25 +58,12 @@ def mean_agg(z):
return mean_without_zeros
@jit
def agg(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
def full_nan():
return 0.
def not_full_nan():
return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
if __name__ == '__main__':
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array))
array2 = jnp.asarray([0, 0, 0, 0], dtype=jnp.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array2))
agg_name2func = {
'sum': sum_agg,
'product': product_agg,
'max': max_agg,
'min': min_agg,
'maxabs': maxabs_agg,
'median': median_agg,
'mean': mean_agg,
}

View File

@@ -1,14 +1,17 @@
from functools import partial
"""
Crossover two genomes to generate a new genome.
The calculation method is the same as the crossover operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
"""
from typing import Tuple
import jax
from jax import jit, vmap, Array
from jax import jit, Array
from jax import numpy as jnp
@jit
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
-> Tuple[Array, Array]:
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
@@ -23,7 +26,11 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2:
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = align_array(keys1, keys2, nodes2, 'node')
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
@@ -34,7 +41,6 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2:
return new_nodes, new_cons
# @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
@@ -62,7 +68,6 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
return refactor_ar2
# @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
"""
crossover two genes

View File

@@ -1,81 +0,0 @@
"""
Crossover two genomes to generate a new genome.
The calculation method is the same as the crossover operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
"""
from typing import Tuple
import jax
from jax import jit, Array
from jax import numpy as jnp
@jit
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
    """
    Build a child genome from two parent genomes.
    genome1 is expected to be the fitter parent (the winner).
    :param randkey: random key, split once for nodes and once for connections
    :param nodes1: nodes of the winner
    :param cons1: connections of the winner
    :param nodes2: nodes of the loser
    :param cons2: connections of the loser
    :return: (new_nodes, new_cons)
    """
    node_randkey, con_randkey = jax.random.split(randkey)

    # --- nodes ---
    # Reorder the loser's nodes so homologous genes sit on the same row as in nodes1.
    aligned_nodes2 = align_array(nodes1[:, 0], nodes2[:, 0], nodes2, 'node')
    # Where either side is NaN the gene is not homologous: keep the winner's value.
    # Elsewhere, take the per-element crossover of both parents.
    node_unmatched = jnp.isnan(nodes1) | jnp.isnan(aligned_nodes2)
    mixed_nodes = crossover_gene(node_randkey, nodes1, aligned_nodes2)
    new_nodes = jnp.where(node_unmatched, nodes1, mixed_nodes)

    # --- connections --- (same procedure, keyed by the first two columns)
    aligned_cons2 = align_array(cons1[:, :2], cons2[:, :2], cons2, 'connection')
    con_unmatched = jnp.isnan(cons1) | jnp.isnan(aligned_cons2)
    mixed_cons = crossover_gene(con_randkey, cons1, aligned_cons2)
    new_cons = jnp.where(con_unmatched, cons1, mixed_cons)

    return new_nodes, new_cons
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
    """
    Reorder the rows of ar2 so that they line up with the keys in seq1.

    The original author marks this as the most difficult part of the code and
    asks that it never be changed.

    :param seq1: keys of the reference genome (one key per row)
    :param seq2: keys of the genome being aligned (one key per row of ar2)
    :param ar2: the array whose rows are reordered
    :param gene_type: 'connection' compares whole key rows (all entries along the
        last axis must match); any other value compares keys element by element
    :return: an array in which row i holds the row of ar2 whose key equals seq1[i];
        rows of seq1 with no matching key in seq2 are filled with NaN
    """
    # Broadcast to compare every key of seq1 against every key of seq2.
    seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
    # mask[i, j] is True when key i of seq1 equals key j of seq2 and is not NaN.
    mask = (seq1 == seq2) & (~jnp.isnan(seq1))

    if gene_type == 'connection':
        # A connection key has several entries; all of them must match.
        mask = jnp.all(mask, axis=2)

    # True for every row of seq1 that has a match in seq2.
    intersect_mask = mask.any(axis=1)
    # NOTE(review): idx is sized by len(seq1) but is dotted against the seq2 axis
    # of mask, so this presumably relies on seq1 and seq2 having the same
    # length - confirm against callers.
    idx = jnp.arange(0, len(seq1))
    # Weighted sum over each mask row gives the index of the matching row in ar2.
    # Presumably keys are unique within seq2 (otherwise indices would add up) - verify.
    idx_fixed = jnp.dot(mask, idx)

    # Take the matched row of ar2 where a match exists, NaN everywhere else.
    refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)

    return refactor_ar2
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
    """
    Mix two gene arrays element by element.
    :param rand_key: random key used to draw one uniform sample per element
    :param g1: first gene array
    :param g2: second gene array
    :return: array taking the element of g1 where the sample is above 0.5, else g2
    Only genes with the same key are crossed, so keys never need to change.
    """
    take_first = jax.random.uniform(rand_key, shape=g1.shape) > 0.5
    return jnp.where(take_first, g1, g2)

View File

@@ -1,6 +1,9 @@
"""
Calculate the distance between two genomes.
The calculation method is the same as the distance calculation in NEAT-python.
See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
"""
from typing import Dict
from jax import jit, vmap, Array
from jax import numpy as jnp
@@ -9,26 +12,34 @@ from .utils import EMPTY_NODE, EMPTY_CON
@jit
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, disjoint_coe: float = 1.,
compatibility_coe: float = 0.5) -> Array:
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
"""
Calculate the distance between two genomes.
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
args:
nodes1: Array(N, 5)
cons1: Array(C, 4)
nodes2: Array(N, 5)
cons2: Array(C, 4)
returns:
distance: Array(, )
"""
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
nd = node_distance(nodes1, nodes2, jit_config) # node distance
cd = connection_distance(cons1, cons2, jit_config) # connection distance
return nd + cd
@jit
def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
"""
Calculate the distance between nodes of two genomes.
"""
# statistics nodes count of two genomes
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
@@ -36,19 +47,29 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
nd = batch_homologous_node_distance(fr, sr)
nd = jnp.where(jnp.isnan(nd), 0, nd)
homologous_distance = jnp.sum(nd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return jnp.where(max_cnt == 0, 0, val / max_cnt)
# calculate the distance of homologous nodes
hnd = vmap(homologous_node_distance)(fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
'compatibility_weight']
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
@jit
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
def connection_distance(cons1: Array, cons2: Array, jit_config: Dict):
"""
Calculate the distance between connections of two genomes.
Similar process as node_distance.
"""
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
@@ -64,37 +85,34 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
cd = batch_homologous_connection_distance(fr, sr)
cd = jnp.where(jnp.isnan(cd), 0, cd)
homologous_distance = jnp.sum(cd * intersect_mask)
hcd = vmap(homologous_connection_distance)(fr, sr)
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
'compatibility_weight']
return jnp.where(max_cnt == 0, 0, val / max_cnt)
@vmap
def batch_homologous_node_distance(b_n1, b_n2):
return homologous_node_distance(b_n1, b_n2)
@vmap
def batch_homologous_connection_distance(b_c1, b_c2):
return homologous_connection_distance(b_c1, b_c2)
@jit
def homologous_node_distance(n1, n2):
def homologous_node_distance(n1: Array, n2: Array):
"""
Calculate the distance between two homologous nodes.
"""
d = 0
d += jnp.abs(n1[1] - n2[1]) # bias
d += jnp.abs(n1[2] - n2[2]) # response
d += n1[3] != n2[3] # activation
d += n1[4] != n2[4]
d += n1[4] != n2[4] # aggregation
return d
@jit
def homologous_connection_distance(c1, c2):
def homologous_connection_distance(c1: Array, c2: Array):
"""
Calculate the distance between two homologous connections.
"""
d = 0
d += jnp.abs(c1[2] - c2[2]) # weight
d += c1[3] != c2[3] # enable

View File

@@ -1,119 +0,0 @@
"""
Calculate the distance between two genomes.
The calculation method is the same as the distance calculation in NEAT-python.
See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
"""
from typing import Dict
from jax import jit, vmap, Array
from jax import numpy as jnp
from .utils import EMPTY_NODE, EMPTY_CON
@jit
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
    """
    Distance between two genomes: node distance plus connection distance.

    args:
        nodes1: Array(N, 5)
        cons1: Array(C, 4)
        nodes2: Array(N, 5)
        cons2: Array(C, 4)
        jit_config: configuration dict forwarded to both partial distances
    returns:
        distance: Array(, )
    """
    return node_distance(nodes1, nodes2, jit_config) + connection_distance(cons1, cons2, jit_config)
@jit
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
    """
    Distance between the node sets of two genomes.

    Reads 'compatibility_disjoint' and 'compatibility_weight' from jit_config.
    """
    # Number of real (non-NaN key) nodes on each side.
    cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
    cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
    larger_cnt = jnp.maximum(cnt1, cnt2)

    # Stack both genomes and sort by key so homologous nodes become neighbours
    # (similar to np.intersect1d).
    stacked = jnp.concatenate((nodes1, nodes2), axis=0)
    stacked = stacked[jnp.argsort(stacked[:, 0], axis=0)]
    # Append a NaN row so every row has a successor to compare with.
    padded = jnp.concatenate([stacked, EMPTY_NODE], axis=0)
    upper, lower = padded[:-1], padded[1:]

    # A row is homologous when its key equals the next row's key and is not NaN.
    is_homologous = (upper[:, 0] == lower[:, 0]) & ~jnp.isnan(padded[:-1, 0])

    # Every homologous pair accounts for one node in each genome.
    disjoint_cnt = cnt1 + cnt2 - 2 * jnp.sum(is_homologous)

    # Per-pair distance; NaN results are treated as 0, unpaired rows are masked out.
    pair_dist = vmap(homologous_node_distance)(upper, lower)
    pair_dist = jnp.where(jnp.isnan(pair_dist), 0, pair_dist)
    matched_dist = jnp.sum(pair_dist * is_homologous)

    total = disjoint_cnt * jit_config['compatibility_disjoint'] + matched_dist * jit_config[
        'compatibility_weight']

    # Two empty genomes have distance 0 (avoids division by zero).
    return jnp.where(larger_cnt == 0, 0, total / larger_cnt)
@jit
def connection_distance(cons1: Array, cons2: Array, jit_config: Dict):
    """
    Compute the connection part of the genome distance.

    Works like node_distance, except that a connection is identified by its first two columns
    taken together instead of by a single key column.
    """
    # number of real (non-NaN first column) connections in each genome
    cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
    cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
    larger_cnt = jnp.maximum(cnt1, cnt2)

    # Stack both genomes and sort lexicographically by (column 0, column 1) so that
    # connections present in both genomes become adjacent rows.
    stacked = jnp.concatenate((cons1, cons2), axis=0)
    pair_keys = stacked[:, :2]
    order = jnp.lexsort(pair_keys.T[::-1])
    stacked = stacked[order]
    # append a NaN row so that the two shifted views below have one entry per stacked row
    stacked = jnp.concatenate([stacked, EMPTY_CON], axis=0)
    upper, lower = stacked[:-1], stacked[1:]

    # both genomes own the connection when the two identifying columns match and are real
    key_matches = jnp.all(upper[:, :2] == lower[:, :2], axis=1)
    homologous_mask = key_matches & ~jnp.isnan(upper[:, 0])

    # every homologous pair accounts for two connections, one from each genome
    non_homologous_cnt = cnt1 + cnt2 - 2 * jnp.sum(homologous_mask)

    # NaN results are zeroed first because NaN * 0 would still be NaN and poison the sum
    pair_dist = vmap(homologous_connection_distance)(upper, lower)
    pair_dist = jnp.where(jnp.isnan(pair_dist), 0, pair_dist)
    homologous_total = jnp.sum(pair_dist * homologous_mask)

    disjoint_term = non_homologous_cnt * jit_config['compatibility_disjoint']
    weight_term = homologous_total * jit_config['compatibility_weight']
    total = disjoint_term + weight_term

    # two empty genomes have distance 0 instead of 0 / 0
    return jnp.where(larger_cnt == 0, 0, total / larger_cnt)
@jit
def homologous_node_distance(n1: Array, n2: Array):
    """
    Compute the attribute distance between two nodes that share the same key.

    Bias and response contribute their absolute difference; activation and aggregation
    each contribute 1 when they differ and 0 otherwise.
    """
    bias_gap = jnp.abs(n1[1] - n2[1])
    response_gap = jnp.abs(n1[2] - n2[2])
    activation_differs = n1[3] != n2[3]
    aggregation_differs = n1[4] != n2[4]
    return bias_gap + response_gap + activation_differs + aggregation_differs
@jit
def homologous_connection_distance(c1: Array, c2: Array):
    """
    Compute the attribute distance between two connections that share the same endpoints.

    The weight contributes its absolute difference; the enable flag contributes 1 when it
    differs and 0 otherwise.
    """
    weight_gap = jnp.abs(c1[2] - c2[2])
    enable_differs = c1[3] != c2[3]
    return weight_gap + enable_differs

View File

@@ -2,47 +2,82 @@ import jax
from jax import Array, numpy as jnp
from jax import jit, vmap
from .aggregations import agg
from .activations import act
from .utils import I_INT
# TODO: enabled information doesn't influence forward. That is wrong!
@jit
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
input_idx: Array, output_idx: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are single genome
def create_forward(config):
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, config['activation_funcs'], z)
return res
:argument inputs: (input_num, )
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
def agg(idx, z):
"""
calculate aggregation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
:return (output_num, )
"""
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
def all_nan():
return 0.
def scan_body(carry, i):
def hit():
ins = carry * connections[0, :, i]
z = agg(nodes[i, 4], ins)
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
def not_all_nan():
return jax.lax.switch(idx, config['aggregation_funcs'], z)
new_vals = carry.at[i].set(z)
return new_vals
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
def miss():
return carry
def forward(inputs: Array, cal_seqs: Array, nodes: Array, cons: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are a single genome
return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
:argument inputs: (input_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
:return (output_num, )
"""
return vals[output_idx]
input_idx = config['input_idx']
output_idx = config['output_idx']
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
weights = jnp.where(jnp.isnan(cons[1, :, :]), jnp.nan, cons[0, :, :]) # enabled
def cond_fun(carry):
values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def hit():
ins = values * weights[:, i]
z = agg(nodes[i, 4], ins) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z) # z = act(z)
new_values = values.at[i].set(z)
return new_values
def miss():
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
return values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[output_idx]
return forward

View File

@@ -44,10 +44,13 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
pop_nodes[:, input_idx, 0] = input_idx
pop_nodes[:, output_idx, 0] = output_idx
pop_nodes[:, output_idx, 1] = config['bias_init_mean']
pop_nodes[:, output_idx, 2] = config['response_init_mean']
pop_nodes[:, output_idx, 3] = config['activation_default']
pop_nodes[:, output_idx, 4] = config['aggregation_default']
# pop_nodes[:, output_idx, 1] = config['bias_init_mean']
pop_nodes[:, output_idx, 1] = np.random.normal(loc=config['bias_init_mean'], scale=config['bias_init_std'],
size=(config['pop_size'], 1))
pop_nodes[:, output_idx, 2] = np.random.normal(loc=config['response_init_mean'], scale=config['response_init_std'],
size=(config['pop_size'], 1))
pop_nodes[:, output_idx, 3] = np.random.choice(config['activation_options'], size=(config['pop_size'], 1))
pop_nodes[:, output_idx, 4] = np.random.choice(config['aggregation_options'], size=(config['pop_size'], 1))
grid_a, grid_b = np.meshgrid(input_idx, output_idx)
grid_a, grid_b = grid_a.flatten(), grid_b.flatten()
@@ -55,7 +58,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
p = config['num_inputs'] * config['num_outputs']
pop_cons[:, :p, 0] = grid_a
pop_cons[:, :p, 1] = grid_b
pop_cons[:, :p, 2] = config['weight_init_mean']
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
size=(config['pop_size'], p))
pop_cons[:, :p, 3] = 1
return pop_nodes, pop_cons

View File

@@ -8,8 +8,7 @@ from jax import jit, vmap, Array
from jax import numpy as jnp
# from .configs import fetch_first, I_INT
from neat.genome.utils import fetch_first, I_INT
from .utils import unflatten_connections
from neat.genome.utils import fetch_first, I_INT, unflatten_connections
@jit
@@ -44,49 +43,32 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
topological_sort(nodes, connections) -> [0, 1, 2, 3]
"""
connections_enable = connections[1, :, :] == 1
connections_enable = connections[1, :, :] == 1 # forward function. thus use enable
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
res = jnp.full(in_degree.shape, I_INT)
idx = 0
def scan_body(carry, _):
def cond_fun(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
return i != I_INT
def body_func(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
def hit():
# add to res and flag it is already in it
new_res = res_.at[idx_].set(i)
new_idx = idx_ + 1
new_in_degree = in_degree_.at[i].set(-1)
# add to res and flag it is already in it
res_ = res_.at[idx_].set(i)
in_degree_ = in_degree_.at[i].set(-1)
# decrease in_degree of all its children
children = connections_enable[i, :]
new_in_degree = jnp.where(children, new_in_degree - 1, new_in_degree)
return new_res, new_idx, new_in_degree
def miss():
return res_, idx_, in_degree_
return jax.lax.cond(i == I_INT, miss, hit), None
scan_res, _ = jax.lax.scan(scan_body, (res, idx, in_degree), None, length=in_degree.shape[0])
res, _, _ = scan_res
# decrease in_degree of all its children
children = connections_enable[i, :]
in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
return res_, idx_ + 1, in_degree_
res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
return res
@jit
@vmap
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
"""
batch version of topological_sort
:param pop_nodes:
:param pop_connections:
:return:
"""
return topological_sort(pop_nodes, pop_connections)
@jit
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
"""
@@ -131,22 +113,26 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
check_cycles(nodes, connections, 1, 0) -> False
"""
connections = unflatten_connections(nodes, connections)
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
nodes_visited = jnp.full(nodes.shape[0], False)
nodes_visited = nodes_visited.at[to_idx].set(True)
def scan_body(visited, _):
new_visited = jnp.dot(visited, connections_enable)
new_visited = jnp.logical_or(visited, new_visited)
return new_visited, None
visited = jnp.full(nodes.shape[0], False)
new_visited = visited.at[to_idx].set(True)
nodes_visited, _ = jax.lax.scan(scan_body, nodes_visited, None, length=nodes_visited.shape[0])
def cond_func(carry):
visited_, new_visited_ = carry
end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited
end_cond2 = new_visited_[from_idx] # the starting node has been visited
return jnp.logical_not(end_cond1 | end_cond2)
return nodes_visited[from_idx]
def body_func(carry):
_, visited_ = carry
new_visited_ = jnp.dot(visited_, connections_enable)
new_visited_ = jnp.logical_or(visited_, new_visited_)
return visited_, new_visited_
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx]
if __name__ == '__main__':

View File

@@ -1,155 +1,64 @@
from typing import Tuple
"""
Mutate a genome.
The calculation method is the same as the mutation operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
"""
from typing import Tuple, Dict
from functools import partial
import jax
import numpy as np
from jax import numpy as jnp
from jax import jit, vmap, Array
from jax import jit, Array
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .utils import fetch_random, fetch_first, I_INT
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
from .graph import check_cycles
# TODO: single_structural_mutation is temporarily removed because I need to get this running as soon as possible.
@jit
def mutate(rand_key: Array,
nodes: Array,
connections: Array,
new_node_key: int,
input_idx: Array,
output_idx: Array,
bias_mean: float = 0,
bias_std: float = 1,
bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_mean: float = 1.,
response_std: float = 0.,
response_mutate_strength: float = 0.,
response_mutate_rate: float = 0.,
response_replace_rate: float = 0.,
weight_mean: float = 0.,
weight_std: float = 1.,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_default: int = 0,
act_list: Array = None,
act_replace_rate: float = 0.1,
agg_default: int = 0,
agg_list: Array = None,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1,
add_node_rate: float = 0.2,
delete_node_rate: float = 0.2,
add_connection_rate: float = 0.4,
delete_connection_rate: float = 0.4,
):
def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict):
"""
:param output_idx:
:param input_idx:
:param agg_default:
:param act_default:
:param rand_key:
:param nodes: (N, 5)
:param connections: (2, N, N)
:param new_node_key:
:param bias_mean:
:param bias_std:
:param bias_mutate_strength:
:param bias_mutate_rate:
:param bias_replace_rate:
:param response_mean:
:param response_std:
:param response_mutate_strength:
:param response_mutate_rate:
:param response_replace_rate:
:param weight_mean:
:param weight_std:
:param weight_mutate_strength:
:param weight_mutate_rate:
:param weight_replace_rate:
:param act_list:
:param act_replace_rate:
:param agg_list:
:param agg_replace_rate:
:param enabled_reverse_rate:
:param add_node_rate:
:param delete_node_rate:
:param add_connection_rate:
:param delete_connection_rate:
:param jit_config:
:return:
"""
def m_add_node(rk, n, c):
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_idx, output_idx)
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_idx, output_idx)
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# structural mutations
# mutate add node
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
r = rand(r1)
aux_nodes, aux_connections = mutate_add_node(r1, nodes, connections, new_node_key, jit_config)
nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections)
# mutate add connection
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
r = rand(r2)
aux_nodes, aux_connections = mutate_add_connection(r3, nodes, connections, jit_config)
nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections)
# mutate delete node
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
r = rand(r3)
aux_nodes, aux_connections = mutate_delete_node(r2, nodes, connections, jit_config)
nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections)
# mutate delete connection
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
r = rand(r4)
aux_nodes, aux_connections = mutate_delete_connection(r4, nodes, connections)
nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections)
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength,
weight_mutate_rate, weight_replace_rate, act_list, act_replace_rate, agg_list,
agg_replace_rate, enabled_reverse_rate)
# value mutations
nodes, connections = mutate_values(rand_key, nodes, connections, jit_config)
return nodes, connections
@jit
def mutate_values(rand_key: Array,
nodes: Array,
cons: Array,
bias_mean: float = 0,
bias_std: float = 1,
bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_mean: float = 1.,
response_std: float = 0.,
response_mutate_strength: float = 0.,
response_mutate_rate: float = 0.,
response_replace_rate: float = 0.,
weight_mean: float = 0.,
weight_std: float = 1.,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_list: Array = None,
act_replace_rate: float = 0.1,
agg_list: Array = None,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]:
def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
"""
Mutate values of nodes and connections.
@@ -157,56 +66,48 @@ def mutate_values(rand_key: Array,
rand_key: A random key for generating random values.
nodes: A 2D array representing nodes.
cons: A 3D array representing connections.
bias_mean: Mean of the bias values.
bias_std: Standard deviation of the bias values.
bias_mutate_strength: Strength of the bias mutation.
bias_mutate_rate: Rate of the bias mutation.
bias_replace_rate: Rate of the bias replacement.
response_mean: Mean of the response values.
response_std: Standard deviation of the response values.
response_mutate_strength: Strength of the response mutation.
response_mutate_rate: Rate of the response mutation.
response_replace_rate: Rate of the response replacement.
weight_mean: Mean of the weight values.
weight_std: Standard deviation of the weight values.
weight_mutate_strength: Strength of the weight mutation.
weight_mutate_rate: Rate of the weight mutation.
weight_replace_rate: Rate of the weight replacement.
act_list: List of the activation function values.
act_replace_rate: Rate of the activation function replacement.
agg_list: List of the aggregation function values.
agg_replace_rate: Rate of the aggregation function replacement.
enabled_reverse_rate: Rate of reversing enabled state of connections.
jit_config: A dict containing configuration for jit-able functions.
Returns:
A tuple containing mutated nodes and connections.
"""
k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6)
bias_new = mutate_float_values(k1, nodes[:, 1], bias_mean, bias_std,
bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate)
weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std,
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
# mutate enabled
# bias
bias_new = mutate_float_values(k1, nodes[:, 1], jit_config['bias_init_mean'], jit_config['bias_init_std'],
jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'],
jit_config['bias_replace_rate'])
# response
response_new = mutate_float_values(k2, nodes[:, 2], jit_config['response_init_mean'],
jit_config['response_init_std'], jit_config['response_mutate_power'],
jit_config['response_mutate_rate'], jit_config['response_replace_rate'])
# weight
weight_new = mutate_float_values(k3, cons[:, 2], jit_config['weight_init_mean'], jit_config['weight_init_std'],
jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'],
jit_config['weight_replace_rate'])
# activation
act_new = mutate_int_values(k4, nodes[:, 3], jit_config['activation_options'],
jit_config['activation_replace_rate'])
# aggregation
agg_new = mutate_int_values(k5, nodes[:, 4], jit_config['aggregation_options'],
jit_config['aggregation_replace_rate'])
# enabled
r = jax.random.uniform(rand_key, cons[:, 3].shape)
enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3])
enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan)
enabled_new = jnp.where(r < jit_config['enable_mutate_rate'], 1 - cons[:, 3], cons[:, 3])
# merge
nodes = jnp.column_stack([nodes[:, 0], bias_new, response_new, act_new, agg_new])
cons = jnp.column_stack([cons[:, 0], cons[:, 1], weight_new, enabled_new])
nodes = nodes.at[:, 1].set(bias_new)
nodes = nodes.at[:, 2].set(response_new)
nodes = nodes.at[:, 3].set(act_new)
nodes = nodes.at[:, 4].set(agg_new)
cons = cons.at[:, 2].set(weight_new)
cons = cons.at[:, 3].set(enabled_new)
return nodes, cons
@jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
"""
@@ -227,19 +128,26 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
k1, k2, k3, rand_key = jax.random.split(rand_key, num=4)
noise = jax.random.normal(k1, old_vals.shape) * mutate_strength
replace = jax.random.normal(k2, old_vals.shape) * std + mean
r = jax.random.uniform(k3, old_vals.shape)
# default
new_vals = old_vals
# r in [0, mutate_rate), mutate
new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals)
# r in [mutate_rate, mutate_rate + replace_rate), replace
new_vals = jnp.where(
jnp.logical_and(mutate_rate < r, r < mutate_rate + replace_rate),
replace,
(mutate_rate < r) & (r < mutate_rate + replace_rate),
replace + new_vals * 0.0, # in case of nan replace to values
new_vals
)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
return new_vals
@jit
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
"""
Mutate integer values (act, agg) of a given array.
@@ -256,26 +164,20 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
k1, k2, rand_key = jax.random.split(rand_key, num=3)
replace_val = jax.random.choice(k1, val_list, old_vals.shape)
r = jax.random.uniform(k2, old_vals.shape)
new_vals = old_vals
new_vals = jnp.where(r < replace_rate, replace_val, new_vals)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
new_vals = jnp.where(r < replace_rate, replace_val + old_vals * 0.0, old_vals) # in case of nan replace to values
return new_vals
@jit
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
jit_config: Dict) -> Tuple[Array, Array]:
"""
Randomly add a new node from splitting a connection.
:param rand_key:
:param new_node_key:
:param nodes:
:param cons:
:param default_bias:
:param default_response:
:param default_act:
:param default_agg:
:param jit_config:
:return:
"""
# randomly choose a connection
@@ -287,12 +189,13 @@ def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: in
def successful_add_node():
# disable the connection
new_nodes, new_cons = nodes, cons
# set enable to false
new_cons = new_cons.at[idx, 3].set(False)
# add a new node
new_nodes, new_cons = \
add_node(new_nodes, new_cons, new_node_key,
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
new_nodes, new_cons = add_node(new_nodes, new_cons, new_node_key, bias=0, response=1,
act=jit_config['activation_default'], agg=jit_config['aggregation_default'])
# add two new connections
w = new_cons[idx, 2]
@@ -306,59 +209,53 @@ def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: in
return nodes, cons
# TODO: Need we really need to delete a node?
@jit
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
# TODO: Do we really need to delete a node?
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
"""
Randomly delete a node. Input and output nodes are not allowed to be deleted.
:param rand_key:
:param nodes:
:param cons:
:param input_keys:
:param output_keys:
:param jit_config:
:return:
"""
# randomly choose a node
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'],
allow_input_keys=False, allow_output_keys=False)
def nothing():
return nodes, cons
def successful_delete_node():
# delete the node
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx)
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx)
# delete all connections
aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis],
aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None],
jnp.nan, aux_cons)
return aux_nodes, aux_cons
nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
return nodes, cons
@jit
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
"""
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
cycles are not allowed.
:param rand_key:
:param nodes:
:param cons:
:param input_keys:
:param output_keys:
:param jit_config:
:return:
"""
# randomly choose two nodes
k1, k2 = jax.random.split(rand_key, num=2)
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
i_key, from_idx = choice_node_key(k1, nodes, jit_config['input_idx'], jit_config['output_idx'],
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
o_key, to_idx = choice_node_key(k2, nodes, jit_config['input_idx'], jit_config['output_idx'],
allow_input_keys=False, allow_output_keys=True)
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
@@ -375,15 +272,14 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
return nodes, cons
is_already_exist = con_idx != I_INT
unflattened = unflatten_connections(nodes, cons)
is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx)
is_cycle = check_cycles(nodes, cons, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
return nodes, cons
@jit
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
"""
Randomly delete a connection.
@@ -406,7 +302,6 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
return nodes, cons
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
@@ -435,7 +330,6 @@ def choice_node_key(rand_key: Array, nodes: Array,
return key, idx
@jit
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
"""
Randomly choose a connection key from the given connections.
@@ -452,6 +346,5 @@ def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[A
return i_key, o_key, idx
@jit
def rand(rand_key):
return jax.random.uniform(rand_key, ())

View File

@@ -1,355 +0,0 @@
"""
Mutate a genome.
The calculation method is the same as the mutation operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
"""
from typing import Tuple, Dict
from functools import partial
import jax
from jax import numpy as jnp
from jax import jit, Array
from .utils import fetch_random, fetch_first, I_INT
from .genome_ import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
from .graph import check_cycles
@jit
def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict):
    """
    Apply the four structural mutations, then the value mutations, to one genome.

    Every structural mutation is computed unconditionally and then kept or discarded with
    jnp.where, depending on whether a uniform draw falls below the matching probability in
    jit_config. The order is: add node, add connection, delete node, delete connection,
    value mutations.

    :param rand_key: JAX PRNG key; it is split here and must not be reused by the caller.
    :param nodes: node table of the genome.
    :param connections: connection table of the genome.
        NOTE(review): the helpers in this module index it as a 2-D table
        (cons[:, 2] is the weight, cons[:, 3] the enabled flag), not as (2, N, N) - confirm.
    :param new_node_key: key given to the node created by the add-node mutation.
    :param jit_config: dict read for 'node_add_prob', 'conn_add_prob', 'node_delete_prob' and
        'conn_delete_prob'; it is also forwarded to the helper mutations.
    :return: (nodes, connections) after mutation.
    """
    # One independent key per random decision. Reusing a key for both the keep/discard draw
    # and the helper's internal choices would make those draws depend on each other.
    (gate_add_node, key_add_node,
     gate_add_conn, key_add_conn,
     gate_del_node, key_del_node,
     gate_del_conn, key_del_conn,
     key_values) = jax.random.split(rand_key, 9)

    # structural mutations
    # mutate add node
    r = rand(gate_add_node)
    aux_nodes, aux_connections = mutate_add_node(key_add_node, nodes, connections, new_node_key, jit_config)
    nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections)

    # mutate add connection
    r = rand(gate_add_conn)
    aux_nodes, aux_connections = mutate_add_connection(key_add_conn, nodes, connections, jit_config)
    nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections)

    # mutate delete node
    r = rand(gate_del_node)
    aux_nodes, aux_connections = mutate_delete_node(key_del_node, nodes, connections, jit_config)
    nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections)

    # mutate delete connection
    r = rand(gate_del_conn)
    aux_nodes, aux_connections = mutate_delete_connection(key_del_conn, nodes, connections)
    nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections)

    # value mutations
    nodes, connections = mutate_values(key_values, nodes, connections, jit_config)

    return nodes, connections
def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Mutate the attribute values of nodes and connections; keys and endpoints are left untouched.

    Args:
        rand_key: A random key for generating random values.
        nodes: A 2D array of nodes; columns 1-4 hold bias, response, activation and aggregation.
        cons: A 2D array of connections; column 2 holds the weight and column 3 the enabled flag.
        jit_config: A dict containing configuration for jit-able functions.

    Returns:
        A tuple containing mutated nodes and connections.
    """
    bias_key, response_key, weight_key, act_key, agg_key, enabled_key = jax.random.split(rand_key, num=6)

    # float attributes: perturb or re-draw, driven by the per-attribute config entries
    new_bias = mutate_float_values(
        bias_key, nodes[:, 1],
        jit_config['bias_init_mean'], jit_config['bias_init_std'],
        jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'],
        jit_config['bias_replace_rate'],
    )
    new_response = mutate_float_values(
        response_key, nodes[:, 2],
        jit_config['response_init_mean'], jit_config['response_init_std'],
        jit_config['response_mutate_power'], jit_config['response_mutate_rate'],
        jit_config['response_replace_rate'],
    )
    new_weight = mutate_float_values(
        weight_key, cons[:, 2],
        jit_config['weight_init_mean'], jit_config['weight_init_std'],
        jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'],
        jit_config['weight_replace_rate'],
    )

    # categorical attributes: occasionally replaced by a random option
    new_act = mutate_int_values(act_key, nodes[:, 3], jit_config['activation_options'],
                                jit_config['activation_replace_rate'])
    new_agg = mutate_int_values(agg_key, nodes[:, 4], jit_config['aggregation_options'],
                                jit_config['aggregation_replace_rate'])

    # enabled flag: flipped (1 - flag) with probability 'enable_mutate_rate'
    old_enabled = cons[:, 3]
    draw = jax.random.uniform(enabled_key, old_enabled.shape)
    new_enabled = jnp.where(draw < jit_config['enable_mutate_rate'], 1 - old_enabled, old_enabled)

    # rebuild both tables column by column
    mutated_nodes = jnp.column_stack([nodes[:, 0], new_bias, new_response, new_act, new_agg])
    mutated_cons = jnp.column_stack([cons[:, 0], cons[:, 1], new_weight, new_enabled])

    return mutated_nodes, mutated_cons
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
                        mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
    """
    Mutate float values of a given array.

    For every entry one uniform number r is drawn. When r < mutate_rate, Gaussian noise scaled by
    mutate_strength is added; when mutate_rate < r < mutate_rate + replace_rate, the entry is
    replaced by a fresh sample from N(mean, std). Entries that were NaN stay NaN.

    Args:
        rand_key: A random key for generating random values.
        old_vals: A 1D array of float values to be mutated.
        mean: Mean of the replacement distribution.
        std: Standard deviation of the replacement distribution.
        mutate_strength: Scale of the additive noise.
        mutate_rate: Rate of the mutation.
        replace_rate: Rate of the replacement.

    Returns:
        A mutated 1D array of float values.
    """
    noise_key, replace_key, draw_key, _ = jax.random.split(rand_key, num=4)
    shape = old_vals.shape

    perturbation = jax.random.normal(noise_key, shape) * mutate_strength
    fresh = jax.random.normal(replace_key, shape) * std + mean
    draw = jax.random.uniform(draw_key, shape)

    # draw below mutate_rate: perturb the old value
    perturbed = jnp.where(draw < mutate_rate, old_vals + perturbation, old_vals)

    # draw between mutate_rate and mutate_rate + replace_rate: take a fresh sample;
    # adding perturbed * 0.0 keeps NaN entries NaN
    in_replace_band = (mutate_rate < draw) & (draw < mutate_rate + replace_rate)
    mutated = jnp.where(in_replace_band, fresh + perturbed * 0.0, perturbed)

    # entries that were empty (NaN) must stay empty
    was_nan = jnp.isnan(old_vals)
    return jnp.where(~was_nan, mutated, jnp.nan)
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
    """
    Mutate integer values (act, agg) of a given array.

    Each entry is replaced, with probability replace_rate, by a value drawn at random from
    val_list. Entries that were NaN stay NaN.

    Args:
        rand_key: A random key for generating random values.
        old_vals: A 1D array of integer values to be mutated.
        val_list: List of the integer values.
        replace_rate: Rate of the replacement.

    Returns:
        A mutated 1D array of integer values.
    """
    choice_key, draw_key, _ = jax.random.split(rand_key, num=3)
    shape = old_vals.shape

    candidates = jax.random.choice(choice_key, val_list, shape)
    draw = jax.random.uniform(draw_key, shape)

    # adding old_vals * 0.0 keeps NaN entries NaN even when they are selected for replacement
    nan_preserving_candidates = candidates + old_vals * 0.0
    return jnp.where(draw < replace_rate, nan_preserving_candidates, old_vals)
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
                    jit_config: Dict) -> Tuple[Array, Array]:
    """
    Randomly add a new node by splitting an existing connection.

    The chosen connection i -> o is disabled and replaced by i -> new (weight 1)
    and new -> o (carrying the weight of the split connection).
    :param rand_key: random key used to pick the connection
    :param nodes: node array of the genome
    :param cons: connection array of the genome
    :param new_node_key: key given to the inserted node
    :param jit_config: provides 'activation_default' and 'aggregation_default'
    :return: the (possibly unchanged) nodes and connections
    """
    i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)

    def keep_genome():  # there is no connection to split
        return nodes, cons

    def split_connection():
        # column 3 holds the enabled flag: switch the split connection off
        disabled_cons = cons.at[idx, 3].set(False)

        grown_nodes, grown_cons = add_node(nodes, disabled_cons, new_node_key, bias=0, response=1,
                                           act=jit_config['activation_default'],
                                           agg=jit_config['aggregation_default'])

        # column 2 holds the weight: the outgoing half inherits it
        old_weight = grown_cons[idx, 2]
        grown_nodes, grown_cons = add_connection(grown_nodes, grown_cons, i_key, new_node_key,
                                                 weight=1, enabled=True)
        grown_nodes, grown_cons = add_connection(grown_nodes, grown_cons, new_node_key, o_key,
                                                 weight=old_weight, enabled=True)
        return grown_nodes, grown_cons

    # idx == I_INT means no connection could be chosen, so nothing is done
    result_nodes, result_cons = jax.lax.cond(idx == I_INT, keep_genome, split_connection)
    return result_nodes, result_cons
# TODO: Do we really need to delete a node?
@jit
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Randomly delete a node together with every connection touching it.
    Input and output nodes are never candidates for deletion.
    :param rand_key: random key used to pick the node
    :param nodes: node array of the genome
    :param cons: connection array of the genome
    :param jit_config: provides 'input_idx' and 'output_idx'
    :return: the (possibly unchanged) nodes and connections
    """
    key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'],
                               allow_input_keys=False, allow_output_keys=False)

    def keep_genome():
        return nodes, cons

    def remove_node():
        pruned_nodes, pruned_cons = delete_node_by_idx(nodes, cons, idx)
        # wipe (set to NaN) every connection row that starts or ends at the node
        touches_node = (pruned_cons[:, 0] == key) | (pruned_cons[:, 1] == key)
        pruned_cons = jnp.where(touches_node[:, None], jnp.nan, pruned_cons)
        return pruned_nodes, pruned_cons

    # idx == I_INT means no deletable node was found
    result_nodes, result_cons = jax.lax.cond(idx == I_INT, keep_genome, remove_node)
    return result_nodes, result_cons
@jit
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Randomly add a new connection between two existing nodes.

    The target node is never an input node. Three outcomes are possible:
    the connection already exists (it is re-enabled), the connection is
    reported as a cycle by check_cycles (nothing changes), or it is added.
    :param rand_key: random key used to pick both endpoints
    :param nodes: node array of the genome
    :param cons: connection array of the genome
    :param jit_config: provides 'input_idx' and 'output_idx'
    :return: the (possibly unchanged) nodes and connections
    """
    from_rng, to_rng = jax.random.split(rand_key, num=2)
    i_key, from_idx = choice_node_key(from_rng, nodes, jit_config['input_idx'], jit_config['output_idx'],
                                      allow_input_keys=True, allow_output_keys=True)
    o_key, to_idx = choice_node_key(to_rng, nodes, jit_config['input_idx'], jit_config['output_idx'],
                                    allow_input_keys=False, allow_output_keys=True)

    existing_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))

    def reenable_existing():
        # column 3 holds the enabled flag
        return nodes, cons.at[existing_idx, 3].set(True)

    def reject_cycle():
        return nodes, cons

    def insert_new():
        grown_nodes, grown_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)
        return grown_nodes, grown_cons

    exists = existing_idx != I_INT
    makes_cycle = check_cycles(nodes, cons, from_idx, to_idx)
    # "already exists" takes priority over the cycle check
    branch = jnp.where(exists, 0, jnp.where(makes_cycle, 1, 2))
    result_nodes, result_cons = jax.lax.switch(branch, [reenable_existing, reject_cycle, insert_new])
    return result_nodes, result_cons
@jit
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
    """
    Randomly delete a connection; the genome is returned unchanged when no
    connection can be chosen.
    :param rand_key: random key used to pick the connection
    :param nodes: node array of the genome
    :param cons: connection array of the genome
    :return: the (possibly unchanged) nodes and connections
    """
    _, _, idx = choice_connection_key(rand_key, nodes, cons)

    def keep_genome():
        return nodes, cons

    def drop_connection():
        return delete_connection_by_idx(nodes, cons, idx)

    # idx == I_INT means there was no connection to pick
    result_nodes, result_cons = jax.lax.cond(idx == I_INT, keep_genome, drop_connection)
    return result_nodes, result_cons
def choice_node_key(rand_key: Array, nodes: Array,
                    input_keys: Array, output_keys: Array,
                    allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
    """
    Randomly choose a node key from the given nodes. Empty (NaN) slots are never
    chosen; input and output nodes are excluded unless the matching allow flag is set.
    :param rand_key: random key used for the draw
    :param nodes: node array, column 0 holds the node keys
    :param input_keys: keys of the input nodes
    :param output_keys: keys of the output nodes
    :param allow_input_keys: whether input nodes may be chosen
    :param allow_output_keys: whether output nodes may be chosen
    :return: return its key and position(idx); (nan, I_INT) when nothing is eligible
    """
    node_keys = nodes[:, 0]

    # build the set of slots that must NOT be picked, then invert it
    excluded = jnp.isnan(node_keys)
    if not allow_input_keys:
        excluded = jnp.logical_or(excluded, jnp.isin(node_keys, input_keys))
    if not allow_output_keys:
        excluded = jnp.logical_or(excluded, jnp.isin(node_keys, output_keys))

    idx = fetch_random(rand_key, ~excluded)
    key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
    return key, idx
@jit
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
    """
    Randomly choose a connection key from the given connections.
    :param rand_key: random key used for the draw
    :param nodes: node array (not read here)
    :param cons: connection array, columns 0 and 1 hold the in/out node keys
    :return: i_key, o_key, idx; (nan, nan, I_INT) when there is no connection
    """
    occupied = ~jnp.isnan(cons[:, 0])
    idx = fetch_random(rand_key, occupied)

    found = idx != I_INT
    i_key = jnp.where(found, cons[idx, 0], jnp.nan)
    o_key = jnp.where(found, cons[idx, 1], jnp.nan)
    return i_key, o_key, idx
@jit
def rand(rand_key):
    """Draw one scalar sample from the uniform distribution using rand_key."""
    sample = jax.random.uniform(rand_key, shape=())
    return sample

View File

@@ -1,5 +1,4 @@
from functools import partial
from typing import Tuple
import jax
from jax import numpy as jnp, Array
@@ -11,20 +10,18 @@ EMPTY_CON = jnp.full((1, 4), jnp.nan)
@jit
def unflatten_connections(nodes, cons):
def unflatten_connections(nodes: Array, cons: Array):
"""
transform the (C, 4) connections to (2, N, N)
this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled
connections to nan, that means we dont consider such connection when forward;
:param cons:
:param nodes:
:param nodes: (N, 5)
:param cons: (C, 4)
:return:
"""
N = nodes.shape[0]
node_keys = nodes[:, 0]
i_keys, o_keys = cons[:, 0], cons[:, 1]
i_idxs = key_to_indices(i_keys, node_keys)
o_idxs = key_to_indices(o_keys, node_keys)
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
res = jnp.full((2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array
@@ -34,8 +31,6 @@ def unflatten_connections(nodes, cons):
return res
@partial(vmap, in_axes=(0, None))
def key_to_indices(key, keys):
return fetch_first(key == keys)
@@ -46,27 +41,12 @@ def fetch_first(mask, default=I_INT) -> Array:
fetch the first True index
:param mask: array of bool
:param default: the default value if no element satisfying the condition
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT
example:
>>> a = jnp.array([1, 2, 3, 4, 5])
>>> fetch_first(a > 3)
3
>>> fetch_first(a > 30)
I_INT
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value
"""
idx = jnp.argmax(mask)
return jnp.where(mask[idx], idx, default)
@jit
def fetch_last(mask, default=I_INT) -> Array:
"""
similar to fetch_first, but fetch the last True index
"""
reversed_idx = fetch_first(mask[::-1], default)
return jnp.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
"""
@@ -78,27 +58,8 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx
if __name__ == '__main__':
a = jnp.array([1, 2, 3, 4, 5])
print(fetch_first(a > 3))
print(fetch_first(a > 30))
print(fetch_last(a > 3))
print(fetch_last(a > 30))
rand_key = jax.random.PRNGKey(0)
for t in [-1, 0, 1, 2, 3, 4, 5]:
for _ in range(10):
rand_key, _ = jax.random.split(rand_key)
print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
print(t, fetch_random(rand_key, a > t))

View File

@@ -1,102 +0,0 @@
from functools import partial
import jax
from jax import numpy as jnp, Array
from jax import jit, vmap
I_INT = jnp.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
EMPTY_CON = jnp.full((1, 4), jnp.nan)
@jit
def unflatten_connections(nodes: Array, cons: Array):
    """
    transform the (C, 4) connections to (2, N, N)

    Layer 0 of the result holds column 2 of ``cons`` and layer 1 holds column 3,
    each placed at [index of the in-node, index of the out-node]; every other
    cell is NaN.
    :param nodes: (N, 5)
    :param cons: (C, 4)
    :return: array of shape (2, N, N)
    """
    N = nodes.shape[0]
    node_keys = nodes[:, 0]
    i_keys, o_keys = cons[:, 0], cons[:, 1]

    # A stray `vmap(fetch_first, ...)(i_keys, node_keys)` call used to sit here:
    # it passed raw keys as `mask` and `node_keys` as `default`, and its result
    # was overwritten on the next line. key_to_indices is the only lookup needed.
    i_idxs = key_to_indices(i_keys, node_keys)
    o_idxs = key_to_indices(o_keys, node_keys)

    res = jnp.full((2, N, N), jnp.nan)

    # Keys that are not found map to I_INT, an out-of-range index. Per the
    # original author's note, JAX clamps such indices on reads but ignores them
    # on .at[].set, so those rows write nothing.
    # NOTE(review): confirm against the JAX out-of-bounds indexing docs.
    res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
    res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
    return res
@partial(vmap, in_axes=(0, None))
def key_to_indices(key, keys):
    """Position of ``key`` inside ``keys`` (vectorized over ``key``); fetch_first's default when absent."""
    matches = keys == key
    return fetch_first(matches)
@jit
def fetch_first(mask, default=I_INT) -> Array:
    """
    fetch the first True index
    :param mask: array of bool
    :param default: the value returned when no element is True
    :return: the index of the first True element, or ``default`` when there is none
    example:
    >>> a = jnp.array([1, 2, 3, 4, 5])
    >>> fetch_first(a > 3)
    3
    >>> fetch_first(a > 30)
    I_INT
    """
    # argmax yields the first maximal position; it is 0 for an all-False mask,
    # so the value at that position tells whether anything was found at all
    first = jnp.argmax(mask)
    found = mask[first]
    return jnp.where(found, first, default)
@jit
def fetch_last(mask, default=I_INT) -> Array:
    """
    similar to fetch_first, but fetch the last True index
    :param mask: array of bool
    :param default: the value returned when no element is True
    :return: the index of the last True element, or ``default`` when there is none
    """
    # Search the reversed mask directly instead of comparing against a sentinel:
    # the previous check tested for -1 although not-found was signalled with
    # `default`, so the not-found case produced a meaningless negative index.
    reversed_mask = mask[::-1]
    reversed_idx = jnp.argmax(reversed_mask)
    found = reversed_mask[reversed_idx]
    return jnp.where(found, mask.shape[0] - reversed_idx - 1, default)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
    """
    similar to fetch_first, but fetch a random True index
    :param rand_key: random key used for the draw
    :param mask: array of bool
    :param default: the value returned when no element is True
    """
    total = jnp.sum(mask)
    running = jnp.cumsum(mask)

    # choose which of the True entries to take (1-based), then locate the
    # first position whose running count reaches it
    pick = jax.random.randint(rand_key, shape=(), minval=1, maxval=total + 1)
    reached = jnp.where(total == 0, False, running >= pick)
    return fetch_first(reached, default)
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
    """Index of the smallest entry of ``arr``, with entries where ``mask`` is False treated as +inf."""
    return jnp.argmin(jnp.where(mask, arr, jnp.inf))
if __name__ == '__main__':
    # manual smoke test of the fetch_* helpers
    a = jnp.array([1, 2, 3, 4, 5])
    for fetch in (fetch_first, fetch_last):
        for threshold in (3, 30):
            print(fetch(a > threshold))

    rand_key = jax.random.PRNGKey(0)
    for t in (-1, 0, 1, 2, 3, 4, 5):
        for _ in range(10):
            rand_key, _ = jax.random.split(rand_key)
            print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
            print(t, fetch_random(rand_key, a > t))

View File

@@ -1,158 +1,78 @@
from typing import List, Union, Tuple, Callable
import time
from functools import partial
import jax
import numpy as np
import jax
from .species import SpeciesController
from .genome import expand, expand_single
from configs.configer import Configer
from .genome.genome import initialize_genomes
from .function_factory import FunctionFactory
from .population import *
class Pipeline:
"""
Neat algorithm pipeline.
"""
def __init__(self, config, function_factory, seed=42):
self.time_dict = {}
self.function_factory = function_factory
def __init__(self, config, function_factory=None, seed=42):
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
self.config = config
self.N = config.basic.init_maximum_nodes
self.C = config.basic.init_maximum_connections
self.S = config.basic.init_maximum_species
self.expand_coe = config.basic.expands_coe
self.pop_size = config.neat.population.pop_size
self.config = config # global config
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
self.function_factory = function_factory or FunctionFactory(self.config)
self.species_controller = SpeciesController(config)
self.initialize_func = self.function_factory.create_initialize(self.N, self.C)
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func()
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
self.symbols = {
'P': self.config['pop_size'],
'N': self.config['init_maximum_nodes'],
'C': self.config['init_maximum_connections'],
'S': self.config['init_maximum_species'],
}
self.generation = 0
self.generation_time_list = []
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
self.best_fitness = float('-inf')
self.best_genome = None
self.generation_timestamp = time.time()
self.evaluate_time = 0
self.pop_nodes, self.pop_cons = initialize_genomes(self.symbols['N'], self.symbols['C'], self.config)
def ask(self):
"""
Create a forward function for the population.
:return:
Algorithm gives the population a forward function, then environment gives back the fitnesses.
Creates a function that receives a genome and returns a forward function.
There are 3 types of config['forward_way']: {'single', 'pop', 'common'}
single:
Create pop_size number of forward functions.
Each function receive (batch_size, input_size) and returns (batch_size, output_size)
e.g. RL task
pop:
Create a single forward function, which use only once calculation for the population.
The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)
common:
Special case of pop. The population has the same inputs.
The function receives (batch_size, input_size) and returns (pop_size, batch_size, output_size)
e.g. numerical regression; Hyper-NEAT
"""
return self.function_factory.ask_pop_batch_forward(self.pop_nodes, self.pop_cons)
u_pop_cons = self.get_func('pop_unflatten_connections')(self.pop_nodes, self.pop_cons)
pop_seqs = self.get_func('pop_topological_sort')(self.pop_nodes, u_pop_cons)
def tell(self, fitnesses):
if self.config['forward_way'] == 'single':
forward_funcs = []
for seq, nodes, cons in zip(pop_seqs, self.pop_nodes, u_pop_cons):
func = lambda x: self.get_func('forward')(x, seq, nodes, cons)
forward_funcs.append(func)
return forward_funcs
self.generation += 1
elif self.config['forward_way'] == 'pop':
func = lambda x: self.get_func('pop_batch_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
return func
winner_part, loser_part, elite_mask, pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start = self.species_controller.ask(
fitnesses,
self.generation,
self.S, self.N, self.C)
elif self.config['forward_way'] == 'common':
func = lambda x: self.get_func('common_forward')(x, pop_seqs, self.pop_nodes, u_pop_cons)
return func
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = self.create_and_speciate(
self.randkey, self.pop_nodes, self.pop_cons, winner_part, loser_part, elite_mask,
new_node_keys,
pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start)
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = \
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys])
self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation)
self.expand()
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config.neat.population.generation_limit):
forward_func = self.ask()
tic = time.time()
fitnesses = fitness_func(forward_func)
self.evaluate_time += time.time() - tic
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
if analysis is not None:
if analysis == "default":
self.default_analysis(fitnesses)
else:
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
analysis(fitnesses)
if max(fitnesses) >= self.config.neat.population.fitness_threshold:
print("Fitness limit reached!")
return self.best_genome
self.tell(fitnesses)
print("Generation limit reached!")
return self.best_genome
def expand(self):
"""
Expand the population if needed.
:return:
when the maximum node number of the population >= N
the population will expand
"""
pop_node_keys = self.pop_nodes[:, :, 0]
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
max_node_size = np.max(pop_node_sizes)
if max_node_size >= self.N:
self.N = int(self.N * self.expand_coe)
# self.C = int(self.C * self.expand_coe)
print(f"node expand to {self.N}!")
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
# don't forget to expand representation genome in species
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N, self.C)
pop_con_keys = self.pop_cons[:, :, 0]
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
max_con_size = np.max(pop_node_sizes)
if max_con_size >= self.C:
# self.N = int(self.N * self.expand_coe)
self.C = int(self.C * self.expand_coe)
print(f"connections expand to {self.C}!")
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
# don't forget to expand representation genome in species
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N, self.C)
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
self.generation_time_list.append(cost_time)
self.generation_timestamp = new_timestamp
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = (self.pop_nodes[max_idx], self.pop_cons[max_idx])
print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
else:
raise NotImplementedError
def get_func(self, name):
return self.function_factory.get(name, self.symbols)

View File

@@ -1,27 +0,0 @@
import jax
from configs.configer import Configer
from .genome.genome_ import initialize_genomes
class Pipeline:
    """
    Neat algorithm pipeline.
    """

    def __init__(self, config, seed=42):
        """
        Set up the random key, the configs and the initial population.
        :param config: global config; read with string keys such as "init_maximum_nodes"
        :param seed: seed for the JAX PRNG key
        """
        self.randkey = jax.random.PRNGKey(seed)

        # global config, plus the subset that is handed to jit-able functions
        self.config = config
        self.jit_config = Configer.create_jit_config(config)

        # initial capacities for nodes, connections and species
        self.N, self.C, self.S = (
            self.config["init_maximum_nodes"],
            self.config["init_maximum_connections"],
            self.config["init_maximum_species"],
        )

        self.generation = 0
        self.best_genome = None

        self.pop_nodes, self.pop_cons = initialize_genomes(self.N, self.C, self.config)

        # dump the initial population and jit config to stdout
        print(self.pop_nodes, self.pop_cons, sep='\n')
        print(self.jit_config)