refactor genome.py use (C, 4) to replace (2, N, N) to represent connections
faster, faster and faster!
This commit is contained in:
@@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from jax import jit, vmap
|
from jax import jit, vmap
|
||||||
|
|
||||||
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
|
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
|
||||||
from .genome import topological_sort, forward_single
|
from .genome import topological_sort, forward_single, unflatten_connections
|
||||||
|
|
||||||
|
|
||||||
class FunctionFactory:
|
class FunctionFactory:
|
||||||
@@ -17,19 +17,18 @@ class FunctionFactory:
|
|||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
|
||||||
self.init_N = config.basic.init_maximum_nodes
|
self.init_N = config.basic.init_maximum_nodes
|
||||||
|
self.init_C = config.basic.init_maximum_connections
|
||||||
self.expand_coe = config.basic.expands_coe
|
self.expand_coe = config.basic.expands_coe
|
||||||
self.precompile_times = config.basic.pre_compile_times
|
self.precompile_times = config.basic.pre_compile_times
|
||||||
self.compiled_function = {}
|
self.compiled_function = {}
|
||||||
|
|
||||||
self.load_config_vals(config)
|
self.load_config_vals(config)
|
||||||
self.precompile()
|
self.precompile()
|
||||||
pass
|
|
||||||
|
|
||||||
def load_config_vals(self, config):
|
def load_config_vals(self, config):
|
||||||
self.problem_batch = config.basic.problem_batch
|
self.problem_batch = config.basic.problem_batch
|
||||||
|
|
||||||
self.pop_size = config.neat.population.pop_size
|
self.pop_size = config.neat.population.pop_size
|
||||||
self.init_N = config.basic.init_maximum_nodes
|
|
||||||
|
|
||||||
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||||
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||||
@@ -85,6 +84,7 @@ class FunctionFactory:
|
|||||||
initialize_genomes,
|
initialize_genomes,
|
||||||
pop_size=self.pop_size,
|
pop_size=self.pop_size,
|
||||||
N=self.init_N,
|
N=self.init_N,
|
||||||
|
C=self.init_C,
|
||||||
num_inputs=self.num_inputs,
|
num_inputs=self.num_inputs,
|
||||||
num_outputs=self.num_outputs,
|
num_outputs=self.num_outputs,
|
||||||
default_bias=self.bias_mean,
|
default_bias=self.bias_mean,
|
||||||
@@ -107,24 +107,24 @@ class FunctionFactory:
|
|||||||
self.create_crossover_with_args()
|
self.create_crossover_with_args()
|
||||||
self.create_topological_sort_with_args()
|
self.create_topological_sort_with_args()
|
||||||
self.create_single_forward_with_args()
|
self.create_single_forward_with_args()
|
||||||
|
#
|
||||||
n = self.init_N
|
# n, c = self.init_N, self.init_C
|
||||||
print("start precompile")
|
# print("start precompile")
|
||||||
for _ in range(self.precompile_times):
|
# for _ in range(self.precompile_times):
|
||||||
self.compile_mutate(n)
|
# self.compile_mutate(n)
|
||||||
self.compile_distance(n)
|
# self.compile_distance(n)
|
||||||
self.compile_crossover(n)
|
# self.compile_crossover(n)
|
||||||
self.compile_topological_sort_batch(n)
|
# self.compile_topological_sort_batch(n)
|
||||||
self.compile_pop_batch_forward(n)
|
# self.compile_pop_batch_forward(n)
|
||||||
n = int(self.expand_coe * n)
|
# n = int(self.expand_coe * n)
|
||||||
|
#
|
||||||
# precompile other functions used in jax
|
# # precompile other functions used in jax
|
||||||
key = jax.random.PRNGKey(0)
|
# key = jax.random.PRNGKey(0)
|
||||||
_ = jax.random.split(key, 3)
|
# _ = jax.random.split(key, 3)
|
||||||
_ = jax.random.split(key, self.pop_size * 2)
|
# _ = jax.random.split(key, self.pop_size * 2)
|
||||||
_ = jax.random.split(key, self.pop_size)
|
# _ = jax.random.split(key, self.pop_size)
|
||||||
|
#
|
||||||
print("end precompile")
|
# print("end precompile")
|
||||||
|
|
||||||
def create_mutate_with_args(self):
|
def create_mutate_with_args(self):
|
||||||
func = partial(
|
func = partial(
|
||||||
@@ -161,20 +161,20 @@ class FunctionFactory:
|
|||||||
)
|
)
|
||||||
self.mutate_with_args = func
|
self.mutate_with_args = func
|
||||||
|
|
||||||
def compile_mutate(self, n):
|
def compile_mutate(self, n, c):
|
||||||
func = self.mutate_with_args
|
func = self.mutate_with_args
|
||||||
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
|
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
|
||||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
connections_lower = np.zeros((self.pop_size, c, 4))
|
||||||
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||||
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
|
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
|
||||||
connections_lower, new_node_key_lower).compile()
|
connections_lower, new_node_key_lower).compile()
|
||||||
self.compiled_function[('mutate', n)] = batched_mutate_func
|
self.compiled_function[('mutate', n, c)] = batched_mutate_func
|
||||||
|
|
||||||
def create_mutate(self, n):
|
def create_mutate(self, n, c):
|
||||||
key = ('mutate', n)
|
key = ('mutate', n, c)
|
||||||
if key not in self.compiled_function:
|
if key not in self.compiled_function:
|
||||||
self.compile_mutate(n)
|
self.compile_mutate(n, c)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
def debug_mutate(*args):
|
def debug_mutate(*args):
|
||||||
res_nodes, res_connections = self.compiled_function[key](*args)
|
res_nodes, res_connections = self.compiled_function[key](*args)
|
||||||
@@ -192,28 +192,28 @@ class FunctionFactory:
|
|||||||
)
|
)
|
||||||
self.distance_with_args = func
|
self.distance_with_args = func
|
||||||
|
|
||||||
def compile_distance(self, n):
|
def compile_distance(self, n, c):
|
||||||
func = self.distance_with_args
|
func = self.distance_with_args
|
||||||
o2o_nodes1_lower = np.zeros((n, 5))
|
o2o_nodes1_lower = np.zeros((n, 5))
|
||||||
o2o_connections1_lower = np.zeros((2, n, n))
|
o2o_connections1_lower = np.zeros((c, 4))
|
||||||
o2o_nodes2_lower = np.zeros((n, 5))
|
o2o_nodes2_lower = np.zeros((n, 5))
|
||||||
o2o_connections2_lower = np.zeros((2, n, n))
|
o2o_connections2_lower = np.zeros((c, 4))
|
||||||
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
|
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
|
||||||
o2o_nodes2_lower, o2o_connections2_lower).compile()
|
o2o_nodes2_lower, o2o_connections2_lower).compile()
|
||||||
|
|
||||||
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
|
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
|
||||||
o2m_connections2_lower = np.zeros((self.pop_size, 2, n, n))
|
o2m_connections2_lower = np.zeros((self.pop_size, c, 4))
|
||||||
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
|
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
|
||||||
o2m_nodes2_lower,
|
o2m_nodes2_lower,
|
||||||
o2m_connections2_lower).compile()
|
o2m_connections2_lower).compile()
|
||||||
|
|
||||||
self.compiled_function[('o2o_distance', n)] = o2o_distance
|
self.compiled_function[('o2o_distance', n, c)] = o2o_distance
|
||||||
self.compiled_function[('o2m_distance', n)] = o2m_distance
|
self.compiled_function[('o2m_distance', n, c)] = o2m_distance
|
||||||
|
|
||||||
def create_distance(self, n):
|
def create_distance(self, n, c):
|
||||||
key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
|
key1, key2 = ('o2o_distance', n, c), ('o2m_distance', n, c)
|
||||||
if key1 not in self.compiled_function:
|
if key1 not in self.compiled_function:
|
||||||
self.compile_distance(n)
|
self.compile_distance(n, c)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
|
|
||||||
def debug_o2o_distance(*args):
|
def debug_o2o_distance(*args):
|
||||||
@@ -229,21 +229,21 @@ class FunctionFactory:
|
|||||||
def create_crossover_with_args(self):
|
def create_crossover_with_args(self):
|
||||||
self.crossover_with_args = crossover
|
self.crossover_with_args = crossover
|
||||||
|
|
||||||
def compile_crossover(self, n):
|
def compile_crossover(self, n, c):
|
||||||
func = self.crossover_with_args
|
func = self.crossover_with_args
|
||||||
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
|
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
|
||||||
nodes1_lower = np.zeros((self.pop_size, n, 5))
|
nodes1_lower = np.zeros((self.pop_size, n, 5))
|
||||||
connections1_lower = np.zeros((self.pop_size, 2, n, n))
|
connections1_lower = np.zeros((self.pop_size, c, 4))
|
||||||
nodes2_lower = np.zeros((self.pop_size, n, 5))
|
nodes2_lower = np.zeros((self.pop_size, n, 5))
|
||||||
connections2_lower = np.zeros((self.pop_size, 2, n, n))
|
connections2_lower = np.zeros((self.pop_size, c, 4))
|
||||||
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
|
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
|
||||||
nodes2_lower, connections2_lower).compile()
|
nodes2_lower, connections2_lower).compile()
|
||||||
self.compiled_function[('crossover', n)] = func
|
self.compiled_function[('crossover', n, c)] = func
|
||||||
|
|
||||||
def create_crossover(self, n):
|
def create_crossover(self, n, c):
|
||||||
key = ('crossover', n)
|
key = ('crossover', n, c)
|
||||||
if key not in self.compiled_function:
|
if key not in self.compiled_function:
|
||||||
self.compile_crossover(n)
|
self.compile_crossover(n, c)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
|
|
||||||
def debug_crossover(*args):
|
def debug_crossover(*args):
|
||||||
@@ -365,15 +365,17 @@ class FunctionFactory:
|
|||||||
else:
|
else:
|
||||||
return self.compiled_function[key]
|
return self.compiled_function[key]
|
||||||
|
|
||||||
def ask_pop_batch_forward(self, pop_nodes, pop_connections):
|
def ask_pop_batch_forward(self, pop_nodes, pop_cons):
|
||||||
n = pop_nodes.shape[1]
|
n, c = pop_nodes.shape[1], pop_cons.shape[1]
|
||||||
|
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
|
||||||
|
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
|
||||||
ts = self.create_topological_sort_batch(n)
|
ts = self.create_topological_sort_batch(n)
|
||||||
pop_cal_seqs = ts(pop_nodes, pop_connections)
|
pop_cal_seqs = ts(pop_nodes, pop_cons)
|
||||||
|
|
||||||
forward_func = self.create_pop_batch_forward(n)
|
forward_func = self.create_pop_batch_forward(n)
|
||||||
|
|
||||||
def debug_forward(inputs):
|
def debug_forward(inputs):
|
||||||
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
|
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_cons)
|
||||||
|
|
||||||
return debug_forward
|
return debug_forward
|
||||||
|
|
||||||
@@ -387,3 +389,23 @@ class FunctionFactory:
|
|||||||
return forward_func(inputs, cal_seqs, nodes, connections)
|
return forward_func(inputs, cal_seqs, nodes, connections)
|
||||||
|
|
||||||
return debug_forward
|
return debug_forward
|
||||||
|
|
||||||
|
def compile_batch_unflatten_connections(self, n, c):
|
||||||
|
func = unflatten_connections
|
||||||
|
func = vmap(func)
|
||||||
|
pop_nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||||
|
pop_connections_lower = np.zeros((self.pop_size, c, 4))
|
||||||
|
func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile()
|
||||||
|
self.compiled_function[('batch_unflatten_connections', n, c)] = func
|
||||||
|
|
||||||
|
def create_batch_unflatten_connections(self, n, c):
|
||||||
|
key = ('batch_unflatten_connections', n, c)
|
||||||
|
if key not in self.compiled_function:
|
||||||
|
self.compile_batch_unflatten_connections(n, c)
|
||||||
|
if self.debug:
|
||||||
|
def debug_batch_unflatten_connections(*args):
|
||||||
|
return self.compiled_function[key](*args).block_until_ready()
|
||||||
|
|
||||||
|
return debug_batch_unflatten_connections
|
||||||
|
else:
|
||||||
|
return self.compiled_function[key]
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from .genome import expand, expand_single, pop_analysis, initialize_genomes
|
from .genome import expand, expand_single, initialize_genomes
|
||||||
from .forward import create_forward_function, forward_single
|
from .forward import forward_single
|
||||||
from .activations import act_name2key
|
from .activations import act_name2key
|
||||||
from .aggregations import agg_name2key
|
from .aggregations import agg_name2key
|
||||||
from .crossover import crossover
|
from .crossover import crossover
|
||||||
from .mutate import mutate
|
from .mutate import mutate
|
||||||
from .distance import distance
|
from .distance import distance
|
||||||
from .graph import topological_sort
|
from .graph import topological_sort
|
||||||
|
from .utils import unflatten_connections
|
||||||
@@ -23,8 +23,8 @@ def sin_act(z):
|
|||||||
|
|
||||||
@jit
|
@jit
|
||||||
def gauss_act(z):
|
def gauss_act(z):
|
||||||
z = jnp.clip(z, -3.4, 3.4)
|
z = jnp.clip(z * 5, -3.4, 3.4)
|
||||||
return jnp.exp(-5 * z ** 2)
|
return jnp.exp(-z ** 2)
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
|
|||||||
@@ -7,16 +7,16 @@ from jax import numpy as jnp
|
|||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
|
||||||
-> Tuple[Array, Array]:
|
-> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
use genome1 and genome2 to generate a new genome
|
use genome1 and genome2 to generate a new genome
|
||||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||||
:param randkey:
|
:param randkey:
|
||||||
:param nodes1:
|
:param nodes1:
|
||||||
:param connections1:
|
:param cons1:
|
||||||
:param nodes2:
|
:param nodes2:
|
||||||
:param connections2:
|
:param cons2:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
randkey_1, randkey_2 = jax.random.split(randkey)
|
randkey_1, randkey_2 = jax.random.split(randkey)
|
||||||
@@ -27,15 +27,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
|||||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||||
|
|
||||||
# crossover connections
|
# crossover connections
|
||||||
con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2]
|
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
||||||
connections2 = align_array(con_keys1, con_keys2, connections2, 'connection')
|
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
|
||||||
new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections1), cons1, crossover_gene(randkey_2, cons1, cons2))
|
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||||
new_cons = unflatten_connections(len(keys1), new_cons)
|
|
||||||
|
|
||||||
return new_nodes, new_cons
|
return new_nodes, new_cons
|
||||||
|
|
||||||
|
|
||||||
@partial(jit, static_argnames=['gene_type'])
|
# @partial(jit, static_argnames=['gene_type'])
|
||||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||||
"""
|
"""
|
||||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||||
@@ -63,7 +62,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
|||||||
return refactor_ar2
|
return refactor_ar2
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||||
"""
|
"""
|
||||||
crossover two genes
|
crossover two genes
|
||||||
|
|||||||
0
algorithms/neat/genome/debug/__init__.py
Normal file
0
algorithms/neat/genome/debug/__init__.py
Normal file
88
algorithms/neat/genome/debug/tools.py
Normal file
88
algorithms/neat/genome/debug/tools.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def check_array_valid(nodes, cons, input_keys, output_keys):
|
||||||
|
nodes_dict, cons_dict = array2object(nodes, cons, input_keys, output_keys)
|
||||||
|
# assert is_DAG(cons_dict.keys()), "The genome is not a DAG!"
|
||||||
|
|
||||||
|
|
||||||
|
def array2object(nodes, cons, input_keys, output_keys):
|
||||||
|
"""
|
||||||
|
Convert a genome from array to dict.
|
||||||
|
:param nodes: (N, 5)
|
||||||
|
:param cons: (C, 4)
|
||||||
|
:param output_keys:
|
||||||
|
:param input_keys:
|
||||||
|
:return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
|
||||||
|
"""
|
||||||
|
# update nodes_dict
|
||||||
|
nodes_dict = {}
|
||||||
|
for i, node in enumerate(nodes):
|
||||||
|
if np.isnan(node[0]):
|
||||||
|
continue
|
||||||
|
key = int(node[0])
|
||||||
|
assert key not in nodes_dict, f"Duplicate node key: {key}!"
|
||||||
|
|
||||||
|
if key in input_keys:
|
||||||
|
assert np.all(np.isnan(node[1:])), f"Input node {key} must has None bias, response, act, or agg!"
|
||||||
|
nodes_dict[key] = (None,) * 4
|
||||||
|
else:
|
||||||
|
assert np.all(~np.isnan(node[1:])), f"Normal node {key} must has non-None bias, response, act, or agg!"
|
||||||
|
bias = node[1]
|
||||||
|
response = node[2]
|
||||||
|
act = node[3]
|
||||||
|
agg = node[4]
|
||||||
|
nodes_dict[key] = (bias, response, act, agg)
|
||||||
|
|
||||||
|
# check nodes_dict
|
||||||
|
for i in input_keys:
|
||||||
|
assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
|
||||||
|
|
||||||
|
for o in output_keys:
|
||||||
|
assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
|
||||||
|
|
||||||
|
# update connections
|
||||||
|
cons_dict = {}
|
||||||
|
for i, con in enumerate(cons):
|
||||||
|
if np.all(np.isnan(con)):
|
||||||
|
pass
|
||||||
|
elif np.all(~np.isnan(con)):
|
||||||
|
i_key = int(con[0])
|
||||||
|
o_key = int(con[1])
|
||||||
|
if (i_key, o_key) in cons_dict:
|
||||||
|
assert False, f"Duplicate connection: {(i_key, o_key)}!"
|
||||||
|
assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
|
||||||
|
assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
|
||||||
|
weight = con[2]
|
||||||
|
enabled = (con[3] == 1)
|
||||||
|
cons_dict[(i_key, o_key)] = (weight, enabled)
|
||||||
|
else:
|
||||||
|
assert False, f"Connection {i} must has all None or all non-None!"
|
||||||
|
|
||||||
|
return nodes_dict, cons_dict
|
||||||
|
|
||||||
|
|
||||||
|
def is_DAG(edges):
|
||||||
|
all_nodes = set()
|
||||||
|
for a, b in edges:
|
||||||
|
if a == b: # cycle
|
||||||
|
return False
|
||||||
|
all_nodes.union({a, b})
|
||||||
|
|
||||||
|
for node in all_nodes:
|
||||||
|
visited = {n: False for n in all_nodes}
|
||||||
|
def dfs(n):
|
||||||
|
if visited[n]:
|
||||||
|
return False
|
||||||
|
visited[n] = True
|
||||||
|
for a, b in edges:
|
||||||
|
if a == n:
|
||||||
|
if not dfs(b):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not dfs(node):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
from jax import jit, vmap, Array
|
from jax import jit, vmap, Array
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
|
|
||||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
from .utils import EMPTY_NODE, EMPTY_CON
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
|
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, disjoint_coe: float = 1.,
|
||||||
compatibility_coe: float = 0.5) -> Array:
|
compatibility_coe: float = 0.5) -> Array:
|
||||||
"""
|
"""
|
||||||
Calculate the distance between two genomes.
|
Calculate the distance between two genomes.
|
||||||
@@ -15,10 +15,6 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar
|
|||||||
|
|
||||||
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
|
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
|
||||||
|
|
||||||
# refactor connections
|
|
||||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
|
||||||
cons1 = flatten_connections(keys1, connections1)
|
|
||||||
cons2 = flatten_connections(keys2, connections2)
|
|
||||||
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
|
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
|
||||||
return nd + cd
|
return nd + cd
|
||||||
|
|
||||||
@@ -35,9 +31,8 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
|
|||||||
nodes = nodes[sorted_indices]
|
nodes = nodes[sorted_indices]
|
||||||
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
|
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
|
||||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||||
nan_mask = jnp.isnan(nodes[:, 0])
|
|
||||||
|
|
||||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1]
|
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||||
|
|
||||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||||
nd = batch_homologous_node_distance(fr, sr)
|
nd = batch_homologous_node_distance(fr, sr)
|
||||||
@@ -50,8 +45,8 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
|
|||||||
|
|
||||||
@jit
|
@jit
|
||||||
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
||||||
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 2])) # weight is not nan, means the connection exists
|
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
|
||||||
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 2]))
|
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
|
||||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||||
|
|
||||||
cons = jnp.concatenate((cons1, cons2), axis=0)
|
cons = jnp.concatenate((cons1, cons2), axis=0)
|
||||||
@@ -62,7 +57,7 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
|||||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||||
|
|
||||||
# both genome has such connection
|
# both genome has such connection
|
||||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 2]) & ~jnp.isnan(sr[:, 2])
|
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||||
|
|
||||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||||
cd = batch_homologous_connection_distance(fr, sr)
|
cd = batch_homologous_connection_distance(fr, sr)
|
||||||
|
|||||||
@@ -1,51 +1,12 @@
|
|||||||
from functools import partial
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import Array, numpy as jnp
|
from jax import Array, numpy as jnp
|
||||||
from jax import jit, vmap
|
from jax import jit, vmap
|
||||||
from numpy.typing import NDArray
|
|
||||||
|
|
||||||
from .aggregations import agg
|
from .aggregations import agg
|
||||||
from .activations import act
|
from .activations import act
|
||||||
from .graph import topological_sort, batch_topological_sort
|
|
||||||
from .utils import I_INT
|
from .utils import I_INT
|
||||||
|
|
||||||
|
# TODO: enabled information doesn't influence forward. That is wrong!
|
||||||
def create_forward_function(nodes: NDArray, connections: NDArray,
|
|
||||||
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
|
|
||||||
"""
|
|
||||||
create forward function for different situations
|
|
||||||
|
|
||||||
:param nodes: shape (N, 5) or (pop_size, N, 5)
|
|
||||||
:param connections: shape (2, N, N) or (pop_size, 2, N, N)
|
|
||||||
:param N:
|
|
||||||
:param input_idx:
|
|
||||||
:param output_idx:
|
|
||||||
:param batch: using batch or not
|
|
||||||
:param debug: debug mode
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
if nodes.ndim == 2: # single genome
|
|
||||||
cal_seqs = topological_sort(nodes, connections)
|
|
||||||
if not batch:
|
|
||||||
return lambda inputs: forward_single(inputs, N, input_idx, output_idx,
|
|
||||||
cal_seqs, nodes, connections)
|
|
||||||
else:
|
|
||||||
return lambda batch_inputs: forward_batch(batch_inputs, N, input_idx, output_idx,
|
|
||||||
cal_seqs, nodes, connections)
|
|
||||||
elif nodes.ndim == 3: # pop genome
|
|
||||||
pop_cal_seqs = batch_topological_sort(nodes, connections)
|
|
||||||
if not batch:
|
|
||||||
return lambda inputs: pop_forward_single(inputs, N, input_idx, output_idx,
|
|
||||||
pop_cal_seqs, nodes, connections)
|
|
||||||
else:
|
|
||||||
return lambda batch_inputs: pop_forward_batch(batch_inputs, N, input_idx, output_idx,
|
|
||||||
pop_cal_seqs, nodes, connections)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
|
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
|
||||||
input_idx: Array, output_idx: Array) -> Array:
|
input_idx: Array, output_idx: Array) -> Array:
|
||||||
@@ -84,66 +45,3 @@ def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Ar
|
|||||||
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
|
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
|
||||||
|
|
||||||
return vals[output_idx]
|
return vals[output_idx]
|
||||||
|
|
||||||
|
|
||||||
# @partial(jit, static_argnames=['N'])
|
|
||||||
# @partial(vmap, in_axes=(0, None, None, None, None, None, None))
|
|
||||||
# def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
|
||||||
# cal_seqs: Array, nodes: Array, connections: Array) -> Array:
|
|
||||||
# """
|
|
||||||
# jax forward for batch_inputs shaped (batch_size, input_num)
|
|
||||||
# nodes, connections are single genome
|
|
||||||
#
|
|
||||||
# :argument batch_inputs: (batch_size, input_num)
|
|
||||||
# :argument N: int
|
|
||||||
# :argument input_idx: (input_num, )
|
|
||||||
# :argument output_idx: (output_num, )
|
|
||||||
# :argument cal_seqs: (N, )
|
|
||||||
# :argument nodes: (N, 5)
|
|
||||||
# :argument connections: (2, N, N)
|
|
||||||
#
|
|
||||||
# :return (batch_size, output_num)
|
|
||||||
# """
|
|
||||||
# return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# @partial(jit, static_argnames=['N'])
|
|
||||||
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
|
|
||||||
# def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
|
||||||
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
|
|
||||||
# """
|
|
||||||
# jax forward for single input shaped (input_num, )
|
|
||||||
# pop_nodes, pop_connections are population of genomes
|
|
||||||
#
|
|
||||||
# :argument inputs: (input_num, )
|
|
||||||
# :argument N: int
|
|
||||||
# :argument input_idx: (input_num, )
|
|
||||||
# :argument output_idx: (output_num, )
|
|
||||||
# :argument pop_cal_seqs: (pop_size, N)
|
|
||||||
# :argument pop_nodes: (pop_size, N, 5)
|
|
||||||
# :argument pop_connections: (pop_size, 2, N, N)
|
|
||||||
#
|
|
||||||
# :return (pop_size, output_num)
|
|
||||||
# """
|
|
||||||
# return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# @partial(jit, static_argnames=['N'])
|
|
||||||
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
|
|
||||||
# def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
|
||||||
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
|
|
||||||
# """
|
|
||||||
# jax forward for batch input shaped (batch, input_num)
|
|
||||||
# pop_nodes, pop_connections are population of genomes
|
|
||||||
#
|
|
||||||
# :argument batch_inputs: (batch_size, input_num)
|
|
||||||
# :argument N: int
|
|
||||||
# :argument input_idx: (input_num, )
|
|
||||||
# :argument output_idx: (output_num, )
|
|
||||||
# :argument pop_cal_seqs: (pop_size, N)
|
|
||||||
# :argument pop_nodes: (pop_size, N, 5)
|
|
||||||
# :argument pop_connections: (pop_size, 2, N, N)
|
|
||||||
#
|
|
||||||
# :return (pop_size, batch_size, output_num)
|
|
||||||
# """
|
|
||||||
# return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from jax import numpy as jnp
|
|||||||
from jax import jit
|
from jax import jit
|
||||||
from jax import Array
|
from jax import Array
|
||||||
|
|
||||||
from .utils import fetch_first, EMPTY_NODE
|
from .utils import fetch_first
|
||||||
|
|
||||||
|
|
||||||
def initialize_genomes(pop_size: int,
|
def initialize_genomes(pop_size: int,
|
||||||
@@ -124,79 +124,6 @@ def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tupl
|
|||||||
return new_nodes, new_cons
|
return new_nodes, new_cons
|
||||||
|
|
||||||
|
|
||||||
def analysis(nodes: NDArray, cons: NDArray, input_keys, output_keys) -> \
|
|
||||||
Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
|
|
||||||
"""
|
|
||||||
Convert a genome from array to dict.
|
|
||||||
:param nodes: (N, 5)
|
|
||||||
:param cons: (C, 4)
|
|
||||||
:param output_keys:
|
|
||||||
:param input_keys:
|
|
||||||
:return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
|
|
||||||
"""
|
|
||||||
# update nodes_dict
|
|
||||||
try:
|
|
||||||
nodes_dict = {}
|
|
||||||
for i, node in enumerate(nodes):
|
|
||||||
if np.isnan(node[0]):
|
|
||||||
continue
|
|
||||||
key = int(node[0])
|
|
||||||
assert key not in nodes_dict, f"Duplicate node key: {key}!"
|
|
||||||
|
|
||||||
bias = node[1] if not np.isnan(node[1]) else None
|
|
||||||
response = node[2] if not np.isnan(node[2]) else None
|
|
||||||
act = node[3] if not np.isnan(node[3]) else None
|
|
||||||
agg = node[4] if not np.isnan(node[4]) else None
|
|
||||||
nodes_dict[key] = (bias, response, act, agg)
|
|
||||||
|
|
||||||
# check nodes_dict
|
|
||||||
for i in input_keys:
|
|
||||||
assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
|
|
||||||
bias, response, act, agg = nodes_dict[i]
|
|
||||||
assert bias is None and response is None and act is None and agg is None, \
|
|
||||||
f"Input node {i} must has None bias, response, act, or agg!"
|
|
||||||
|
|
||||||
for o in output_keys:
|
|
||||||
assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
|
|
||||||
|
|
||||||
for k, v in nodes_dict.items():
|
|
||||||
if k not in input_keys:
|
|
||||||
bias, response, act, agg = v
|
|
||||||
assert bias is not None and response is not None and act is not None and agg is not None, \
|
|
||||||
f"Normal node {k} must has non-None bias, response, act, or agg!"
|
|
||||||
|
|
||||||
# update connections
|
|
||||||
cons_dict = {}
|
|
||||||
for i, con in enumerate(cons):
|
|
||||||
if np.isnan(con[0]):
|
|
||||||
continue
|
|
||||||
assert ~np.isnan(con[1]), f"Connection {i} must has non-None o_key!"
|
|
||||||
i_key = int(con[0])
|
|
||||||
o_key = int(con[1])
|
|
||||||
assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
|
|
||||||
assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
|
|
||||||
key = (i_key, o_key)
|
|
||||||
weight = con[2] if not np.isnan(con[2]) else None
|
|
||||||
enabled = (con[3] == 1) if not np.isnan(con[3]) else None
|
|
||||||
assert weight is not None, f"Connection {key} must has non-None weight!"
|
|
||||||
assert enabled is not None, f"Connection {key} must has non-None enabled!"
|
|
||||||
|
|
||||||
cons_dict[key] = (weight, enabled)
|
|
||||||
|
|
||||||
return nodes_dict, cons_dict
|
|
||||||
except AssertionError:
|
|
||||||
print(nodes)
|
|
||||||
print(cons)
|
|
||||||
raise AssertionError
|
|
||||||
|
|
||||||
|
|
||||||
def pop_analysis(pop_nodes, pop_cons, input_keys, output_keys):
|
|
||||||
res = []
|
|
||||||
for nodes, cons in zip(pop_nodes, pop_cons):
|
|
||||||
res.append(analysis(nodes, cons, input_keys, output_keys))
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def count(nodes, cons):
|
def count(nodes, cons):
|
||||||
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
|
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
|
||||||
@@ -231,7 +158,7 @@ def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Arra
|
|||||||
"""
|
"""
|
||||||
use idx to delete a node from the genome. only delete the node, regardless of connections.
|
use idx to delete a node from the genome. only delete the node, regardless of connections.
|
||||||
"""
|
"""
|
||||||
nodes = nodes.at[idx].set(EMPTY_NODE)
|
nodes = nodes.at[idx].set(np.nan)
|
||||||
return nodes, cons
|
return nodes, cons
|
||||||
|
|
||||||
|
|
||||||
@@ -243,7 +170,7 @@ def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
|
|||||||
"""
|
"""
|
||||||
con_keys = cons[:, 0]
|
con_keys = cons[:, 0]
|
||||||
idx = fetch_first(jnp.isnan(con_keys))
|
idx = fetch_first(jnp.isnan(con_keys))
|
||||||
return add_connection_by_idx(idx, nodes, cons, i_key, o_key, weight, enabled)
|
return add_connection_by_idx(nodes, cons, idx, i_key, o_key, weight, enabled)
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
|
|||||||
@@ -6,11 +6,9 @@ import numpy as np
|
|||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
from jax import jit, vmap, Array
|
from jax import jit, vmap, Array
|
||||||
|
|
||||||
from .utils import fetch_random, fetch_first, I_INT
|
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||||
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx
|
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
|
||||||
from .graph import check_cycles
|
from .graph import check_cycles
|
||||||
from .activations import act_name2key
|
|
||||||
from .aggregations import agg_name2key
|
|
||||||
|
|
||||||
|
|
||||||
@partial(jit, static_argnames=('single_structure_mutate',))
|
@partial(jit, static_argnames=('single_structure_mutate',))
|
||||||
@@ -89,7 +87,7 @@ def mutate(rand_key: Array,
|
|||||||
return n, c
|
return n, c
|
||||||
|
|
||||||
def m_add_node(rk, n, c):
|
def m_add_node(rk, n, c):
|
||||||
return mutate_add_node(rk, new_node_key, n, c, bias_mean, response_mean, act_default, agg_default)
|
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
|
||||||
|
|
||||||
def m_delete_node(rk, n, c):
|
def m_delete_node(rk, n, c):
|
||||||
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||||
@@ -153,7 +151,7 @@ def mutate(rand_key: Array,
|
|||||||
@jit
|
@jit
|
||||||
def mutate_values(rand_key: Array,
|
def mutate_values(rand_key: Array,
|
||||||
nodes: Array,
|
nodes: Array,
|
||||||
connections: Array,
|
cons: Array,
|
||||||
bias_mean: float = 0,
|
bias_mean: float = 0,
|
||||||
bias_std: float = 1,
|
bias_std: float = 1,
|
||||||
bias_mutate_strength: float = 0.5,
|
bias_mutate_strength: float = 0.5,
|
||||||
@@ -180,7 +178,7 @@ def mutate_values(rand_key: Array,
|
|||||||
Args:
|
Args:
|
||||||
rand_key: A random key for generating random values.
|
rand_key: A random key for generating random values.
|
||||||
nodes: A 2D array representing nodes.
|
nodes: A 2D array representing nodes.
|
||||||
connections: A 3D array representing connections.
|
cons: A 3D array representing connections.
|
||||||
bias_mean: Mean of the bias values.
|
bias_mean: Mean of the bias values.
|
||||||
bias_std: Standard deviation of the bias values.
|
bias_std: Standard deviation of the bias values.
|
||||||
bias_mutate_strength: Strength of the bias mutation.
|
bias_mutate_strength: Strength of the bias mutation.
|
||||||
@@ -211,24 +209,23 @@ def mutate_values(rand_key: Array,
|
|||||||
bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
|
bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
|
||||||
response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
|
response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
|
||||||
response_mutate_strength, response_mutate_rate, response_replace_rate)
|
response_mutate_strength, response_mutate_rate, response_replace_rate)
|
||||||
weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std,
|
weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std,
|
||||||
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
|
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
|
||||||
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
|
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
|
||||||
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
|
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
|
||||||
|
|
||||||
# refactor enabled
|
# mutate enabled
|
||||||
r = jax.random.uniform(rand_key, connections[1, :, :].shape)
|
r = jax.random.uniform(rand_key, cons[:, 3].shape)
|
||||||
enabled_new = connections[1, :, :] == 1
|
enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3])
|
||||||
enabled_new = jnp.where(r < enabled_reverse_rate, ~enabled_new, enabled_new)
|
enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan)
|
||||||
enabled_new = jnp.where(~jnp.isnan(connections[0, :, :]), enabled_new, jnp.nan)
|
|
||||||
|
|
||||||
nodes = nodes.at[:, 1].set(bias_new)
|
nodes = nodes.at[:, 1].set(bias_new)
|
||||||
nodes = nodes.at[:, 2].set(response_new)
|
nodes = nodes.at[:, 2].set(response_new)
|
||||||
nodes = nodes.at[:, 3].set(act_new)
|
nodes = nodes.at[:, 3].set(act_new)
|
||||||
nodes = nodes.at[:, 4].set(agg_new)
|
nodes = nodes.at[:, 4].set(agg_new)
|
||||||
connections = connections.at[0, :, :].set(weight_new)
|
cons = cons.at[:, 2].set(weight_new)
|
||||||
connections = connections.at[1, :, :].set(enabled_new)
|
cons = cons.at[:, 3].set(enabled_new)
|
||||||
return nodes, connections
|
return nodes, cons
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
@@ -288,7 +285,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
|
|||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
|
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
|
||||||
default_bias: float = 0, default_response: float = 1,
|
default_bias: float = 0, default_response: float = 1,
|
||||||
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
|
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
@@ -296,7 +293,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
|||||||
:param rand_key:
|
:param rand_key:
|
||||||
:param new_node_key:
|
:param new_node_key:
|
||||||
:param nodes:
|
:param nodes:
|
||||||
:param connections:
|
:param cons:
|
||||||
:param default_bias:
|
:param default_bias:
|
||||||
:param default_response:
|
:param default_response:
|
||||||
:param default_act:
|
:param default_act:
|
||||||
@@ -304,44 +301,42 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# randomly choose a connection
|
# randomly choose a connection
|
||||||
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
|
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
|
||||||
|
|
||||||
def nothing():
|
def nothing(): # there is no connection to split
|
||||||
return nodes, connections
|
return nodes, cons
|
||||||
|
|
||||||
def successful_add_node():
|
def successful_add_node():
|
||||||
# disable the connection
|
# disable the connection
|
||||||
new_nodes, new_connections = nodes, connections
|
new_nodes, new_cons = nodes, cons
|
||||||
new_connections = new_connections.at[1, from_idx, to_idx].set(False)
|
new_cons = new_cons.at[idx, 3].set(False)
|
||||||
|
|
||||||
# add a new node
|
# add a new node
|
||||||
new_nodes, new_connections = \
|
new_nodes, new_cons = \
|
||||||
add_node(new_node_key, new_nodes, new_connections,
|
add_node(new_nodes, new_cons, new_node_key,
|
||||||
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
|
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
|
||||||
new_idx = fetch_first(new_nodes[:, 0] == new_node_key)
|
|
||||||
|
|
||||||
# add two new connections
|
# add two new connections
|
||||||
weight = new_connections[0, from_idx, to_idx]
|
w = new_cons[idx, 2]
|
||||||
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
|
new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True)
|
||||||
new_nodes, new_connections, weight=1., enabled=True)
|
new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True)
|
||||||
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
|
return new_nodes, new_cons
|
||||||
new_nodes, new_connections, weight=weight, enabled=True)
|
|
||||||
return new_nodes, new_connections
|
|
||||||
|
|
||||||
# if from_idx == I_INT, that means no connection exist, do nothing
|
# if from_idx == I_INT, that means no connection exist, do nothing
|
||||||
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node)
|
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node)
|
||||||
|
|
||||||
return nodes, connections
|
return nodes, cons
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Need we really need to delete a node?
|
||||||
@jit
|
@jit
|
||||||
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array,
|
||||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
Randomly delete a node. Input and output nodes are not allowed to be deleted.
|
Randomly delete a node. Input and output nodes are not allowed to be deleted.
|
||||||
:param rand_key:
|
:param rand_key:
|
||||||
:param nodes:
|
:param nodes:
|
||||||
:param connections:
|
:param cons:
|
||||||
:param input_keys:
|
:param input_keys:
|
||||||
:param output_keys:
|
:param output_keys:
|
||||||
:return:
|
:return:
|
||||||
@@ -351,83 +346,86 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
|||||||
allow_input_keys=False, allow_output_keys=False)
|
allow_input_keys=False, allow_output_keys=False)
|
||||||
|
|
||||||
def nothing():
|
def nothing():
|
||||||
return nodes, connections
|
return nodes, cons
|
||||||
|
|
||||||
def successful_delete_node():
|
def successful_delete_node():
|
||||||
# delete the node
|
# delete the node
|
||||||
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
|
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx)
|
||||||
|
|
||||||
# delete connections
|
# delete all connections
|
||||||
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
|
aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis],
|
||||||
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
|
jnp.nan, aux_cons)
|
||||||
|
|
||||||
return aux_nodes, aux_connections
|
return aux_nodes, aux_cons
|
||||||
|
|
||||||
nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
|
nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
|
||||||
|
|
||||||
return nodes, connections
|
return nodes, cons
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
|
||||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
|
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
|
||||||
cycles are not allowed.
|
cycles are not allowed.
|
||||||
:param rand_key:
|
:param rand_key:
|
||||||
:param nodes:
|
:param nodes:
|
||||||
:param connections:
|
:param cons:
|
||||||
:param input_keys:
|
:param input_keys:
|
||||||
:param output_keys:
|
:param output_keys:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# randomly choose two nodes
|
# randomly choose two nodes
|
||||||
k1, k2 = jax.random.split(rand_key, num=2)
|
k1, k2 = jax.random.split(rand_key, num=2)
|
||||||
from_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
|
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
|
||||||
allow_input_keys=True, allow_output_keys=True)
|
allow_input_keys=True, allow_output_keys=True)
|
||||||
to_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
|
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
|
||||||
allow_input_keys=False, allow_output_keys=True)
|
allow_input_keys=False, allow_output_keys=True)
|
||||||
|
|
||||||
|
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
||||||
|
|
||||||
def successful():
|
def successful():
|
||||||
new_nodes, new_connections = add_connection_by_idx(from_idx, to_idx, nodes, connections)
|
new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)
|
||||||
return new_nodes, new_connections
|
return new_nodes, new_cons
|
||||||
|
|
||||||
def already_exist():
|
def already_exist():
|
||||||
new_connections = connections.at[1, from_idx, to_idx].set(True)
|
new_cons = cons.at[con_idx, 3].set(True)
|
||||||
return nodes, new_connections
|
return nodes, new_cons
|
||||||
|
|
||||||
def cycle():
|
def cycle():
|
||||||
return nodes, connections
|
return nodes, cons
|
||||||
|
|
||||||
is_already_exist = ~jnp.isnan(connections[0, from_idx, to_idx])
|
is_already_exist = con_idx != I_INT
|
||||||
is_cycle = check_cycles(nodes, connections, from_idx, to_idx)
|
unflattened = unflatten_connections(nodes, cons)
|
||||||
|
is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx)
|
||||||
|
|
||||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||||
nodes, connections = jax.lax.switch(choice, [already_exist, cycle, successful])
|
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
|
||||||
return nodes, connections
|
return nodes, cons
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
|
||||||
"""
|
"""
|
||||||
Randomly delete a connection.
|
Randomly delete a connection.
|
||||||
:param rand_key:
|
:param rand_key:
|
||||||
:param nodes:
|
:param nodes:
|
||||||
:param connections:
|
:param cons:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# randomly choose a connection
|
# randomly choose a connection
|
||||||
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
|
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
|
||||||
|
|
||||||
def nothing():
|
def nothing():
|
||||||
return nodes, connections
|
return nodes, cons
|
||||||
|
|
||||||
def successfully_delete_connection():
|
def successfully_delete_connection():
|
||||||
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
|
return delete_connection_by_idx(nodes, cons, idx)
|
||||||
|
|
||||||
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection)
|
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
|
||||||
|
|
||||||
return nodes, connections
|
return nodes, cons
|
||||||
|
|
||||||
|
|
||||||
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
||||||
@@ -460,31 +458,20 @@ def choice_node_key(rand_key: Array, nodes: Array,
|
|||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
|
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
|
||||||
"""
|
"""
|
||||||
Randomly choose a connection key from the given connections.
|
Randomly choose a connection key from the given connections.
|
||||||
:param rand_key:
|
:param rand_key:
|
||||||
:param nodes:
|
:param nodes:
|
||||||
:param connection:
|
:param cons:
|
||||||
:return: from_key, to_key, from_idx, to_idx
|
:return: i_key, o_key, idx
|
||||||
"""
|
"""
|
||||||
|
|
||||||
k1, k2 = jax.random.split(rand_key, num=2)
|
idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0]))
|
||||||
|
i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan)
|
||||||
|
o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan)
|
||||||
|
|
||||||
has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1)
|
return i_key, o_key, idx
|
||||||
|
|
||||||
def nothing():
|
|
||||||
return jnp.nan, jnp.nan, I_INT, I_INT
|
|
||||||
|
|
||||||
def has_connection():
|
|
||||||
f_idx = fetch_random(k1, has_connections_row)
|
|
||||||
col = connection[0, f_idx, :]
|
|
||||||
t_idx = fetch_random(k2, ~jnp.isnan(col))
|
|
||||||
f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0]
|
|
||||||
return f_key, t_key, f_idx, t_idx
|
|
||||||
|
|
||||||
from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing)
|
|
||||||
return from_key, to_key, from_idx, to_idx
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
|
|||||||
@@ -3,84 +3,38 @@ from typing import Tuple
|
|||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import numpy as jnp, Array
|
from jax import numpy as jnp, Array
|
||||||
from jax import jit
|
from jax import jit, vmap
|
||||||
|
|
||||||
I_INT = jnp.iinfo(jnp.int32).max # infinite int
|
I_INT = jnp.iinfo(jnp.int32).max # infinite int
|
||||||
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
|
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
|
||||||
EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def flatten_connections(keys, connections):
|
def unflatten_connections(nodes, cons):
|
||||||
"""
|
"""
|
||||||
flatten the (2, N, N) connections to (N * N, 4)
|
transform the (C, 4) connections to (2, N, N)
|
||||||
:param keys:
|
|
||||||
:param connections:
|
|
||||||
:return:
|
|
||||||
the first two columns are the index of the node
|
|
||||||
the 3rd column is the weight, and the 4th column is the enabled status
|
|
||||||
"""
|
|
||||||
indices_x, indices_y = jnp.meshgrid(keys, keys, indexing='ij')
|
|
||||||
indices = jnp.stack((indices_x, indices_y), axis=-1).reshape(-1, 2)
|
|
||||||
|
|
||||||
# make (2, N, N) to (N, N, 2)
|
|
||||||
con = jnp.transpose(connections, (1, 2, 0))
|
|
||||||
# make (N, N, 2) to (N * N, 2)
|
|
||||||
con = jnp.reshape(con, (-1, 2))
|
|
||||||
|
|
||||||
con = jnp.concatenate((indices, con), axis=1)
|
|
||||||
return con
|
|
||||||
|
|
||||||
|
|
||||||
@partial(jit, static_argnames=['N'])
|
|
||||||
def unflatten_connections(N, cons):
|
|
||||||
"""
|
|
||||||
restore the (N * N, 4) connections to (2, N, N)
|
|
||||||
:param N:
|
|
||||||
:param cons:
|
:param cons:
|
||||||
|
:param nodes:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
cons = cons[:, 2:] # remove the indices
|
N = nodes.shape[0]
|
||||||
unflatten_cons = jnp.moveaxis(cons.reshape(N, N, 2), -1, 0)
|
node_keys = nodes[:, 0]
|
||||||
return unflatten_cons
|
i_keys, o_keys = cons[:, 0], cons[:, 1]
|
||||||
|
i_idxs = key_to_indices(i_keys, node_keys)
|
||||||
|
o_idxs = key_to_indices(o_keys, node_keys)
|
||||||
|
res = jnp.full((2, N, N), jnp.nan)
|
||||||
|
|
||||||
|
# Is interesting that jax use clip when attach data in array
|
||||||
|
# however, it will do nothing set values in an array
|
||||||
|
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
||||||
|
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@partial(vmap, in_axes=(0, None))
|
||||||
def set_operation_analysis(ar1: Array, ar2: Array) -> Tuple[Array, Array, Array]:
|
def key_to_indices(key, keys):
|
||||||
"""
|
return fetch_first(key == keys)
|
||||||
Analyze the intersection and union of two arrays by returning their sorted concatenation indices,
|
|
||||||
intersection mask, and union mask.
|
|
||||||
|
|
||||||
:param ar1: JAX array of shape (N, M)
|
|
||||||
First input array. Should have the same shape as ar2.
|
|
||||||
:param ar2: JAX array of shape (N, M)
|
|
||||||
Second input array. Should have the same shape as ar1.
|
|
||||||
:return: tuple of 3 arrays
|
|
||||||
- sorted_indices: Indices that would sort the concatenation of ar1 and ar2.
|
|
||||||
- intersect_mask: A boolean array indicating the positions of the common elements between ar1 and ar2
|
|
||||||
in the sorted concatenation.
|
|
||||||
- union_mask: A boolean array indicating the positions of the unique elements in the union of ar1 and ar2
|
|
||||||
in the sorted concatenation.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
a = jnp.array([[1, 2], [3, 4], [5, 6]])
|
|
||||||
b = jnp.array([[1, 2], [7, 8], [9, 10]])
|
|
||||||
|
|
||||||
sorted_indices, intersect_mask, union_mask = set_operation_analysis(a, b)
|
|
||||||
|
|
||||||
sorted_indices -> array([0, 1, 2, 3, 4, 5])
|
|
||||||
intersect_mask -> array([True, False, False, False, False, False])
|
|
||||||
union_mask -> array([False, True, True, True, True, True])
|
|
||||||
"""
|
|
||||||
ar = jnp.concatenate((ar1, ar2), axis=0)
|
|
||||||
sorted_indices = jnp.lexsort(ar.T[::-1])
|
|
||||||
aux = ar[sorted_indices]
|
|
||||||
aux = jnp.concatenate((aux, jnp.full((1, ar1.shape[1]), jnp.nan)), axis=0)
|
|
||||||
nan_mask = jnp.any(jnp.isnan(aux), axis=1)
|
|
||||||
|
|
||||||
fr, sr = aux[:-1], aux[1:] # first row, second row
|
|
||||||
intersect_mask = jnp.all(fr == sr, axis=1) & ~nan_mask[:-1]
|
|
||||||
union_mask = jnp.any(fr != sr, axis=1) & ~nan_mask[:-1]
|
|
||||||
return sorted_indices, intersect_mask, union_mask
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from .species import SpeciesController
|
|||||||
from .genome import expand, expand_single
|
from .genome import expand, expand_single
|
||||||
from .function_factory import FunctionFactory
|
from .function_factory import FunctionFactory
|
||||||
from .genome.genome import count
|
from .genome.genome import count
|
||||||
|
from .genome.debug.tools import check_array_valid
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
"""
|
"""
|
||||||
@@ -23,6 +24,7 @@ class Pipeline:
|
|||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.N = config.basic.init_maximum_nodes
|
self.N = config.basic.init_maximum_nodes
|
||||||
|
self.C = config.basic.init_maximum_connections
|
||||||
self.expand_coe = config.basic.expands_coe
|
self.expand_coe = config.basic.expands_coe
|
||||||
self.pop_size = config.neat.population.pop_size
|
self.pop_size = config.neat.population.pop_size
|
||||||
|
|
||||||
@@ -57,6 +59,8 @@ class Pipeline:
|
|||||||
|
|
||||||
self.update_next_generation(winner_part, loser_part, elite_mask)
|
self.update_next_generation(winner_part, loser_part, elite_mask)
|
||||||
|
|
||||||
|
# pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
|
||||||
|
|
||||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
|
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
|
||||||
self.o2o_distance, self.o2m_distance)
|
self.o2o_distance, self.o2m_distance)
|
||||||
|
|
||||||
@@ -105,16 +109,25 @@ class Pipeline:
|
|||||||
|
|
||||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
||||||
lpc) # new pop nodes, new pop connections
|
lpc) # new pop nodes, new pop connections
|
||||||
|
|
||||||
|
# for i in range(self.pop_size):
|
||||||
|
# n, c = np.array(npn[i]), np.array(npc[i])
|
||||||
|
# check_array_valid(n, c, self.input_idx, self.output_idx)
|
||||||
|
|
||||||
# mutate
|
# mutate
|
||||||
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
|
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
|
||||||
|
|
||||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||||
|
|
||||||
|
# for i in range(self.pop_size):
|
||||||
|
# n, c = np.array(m_npn[i]), np.array(m_npc[i])
|
||||||
|
# check_array_valid(n, c, self.input_idx, self.output_idx)
|
||||||
|
|
||||||
# elitism don't mutate
|
# elitism don't mutate
|
||||||
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
||||||
|
|
||||||
self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn)
|
self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn)
|
||||||
self.pop_connections = np.where(elite_mask[:, None, None, None], npc, m_npc)
|
self.pop_connections = np.where(elite_mask[:, None, None], npc, m_npc)
|
||||||
|
|
||||||
def expand(self):
|
def expand(self):
|
||||||
"""
|
"""
|
||||||
@@ -128,20 +141,38 @@ class Pipeline:
|
|||||||
max_node_size = np.max(pop_node_sizes)
|
max_node_size = np.max(pop_node_sizes)
|
||||||
if max_node_size >= self.N:
|
if max_node_size >= self.N:
|
||||||
self.N = int(self.N * self.expand_coe)
|
self.N = int(self.N * self.expand_coe)
|
||||||
print(f"expand to {self.N}!")
|
print(f"node expand to {self.N}!")
|
||||||
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N)
|
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C)
|
||||||
|
|
||||||
# don't forget to expand representation genome in species
|
# don't forget to expand representation genome in species
|
||||||
for s in self.species_controller.species.values():
|
for s in self.species_controller.species.values():
|
||||||
s.representative = expand_single(*s.representative, self.N)
|
s.representative = expand_single(*s.representative, self.N, self.C)
|
||||||
|
|
||||||
# update functions
|
# update functions
|
||||||
self.compile_functions(debug=True)
|
self.compile_functions(debug=True)
|
||||||
|
|
||||||
|
|
||||||
|
pop_con_keys = self.pop_connections[:, :, 0]
|
||||||
|
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||||
|
max_con_size = np.max(pop_node_sizes)
|
||||||
|
if max_con_size >= self.C:
|
||||||
|
self.C = int(self.C * self.expand_coe)
|
||||||
|
print(f"connections expand to {self.C}!")
|
||||||
|
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C)
|
||||||
|
|
||||||
|
# don't forget to expand representation genome in species
|
||||||
|
for s in self.species_controller.species.values():
|
||||||
|
s.representative = expand_single(*s.representative, self.N, self.C)
|
||||||
|
|
||||||
|
# update functions
|
||||||
|
self.compile_functions(debug=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def compile_functions(self, debug=False):
|
def compile_functions(self, debug=False):
|
||||||
self.mutate_func = self.function_factory.create_mutate(self.N)
|
self.mutate_func = self.function_factory.create_mutate(self.N, self.C)
|
||||||
self.crossover_func = self.function_factory.create_crossover(self.N)
|
self.crossover_func = self.function_factory.create_crossover(self.N, self.C)
|
||||||
self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N)
|
self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N, self.C)
|
||||||
|
|
||||||
def default_analysis(self, fitnesses):
|
def default_analysis(self, fitnesses):
|
||||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||||
|
|||||||
49
examples/function_tests.py
Normal file
49
examples/function_tests.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
from algorithms.neat.function_factory import FunctionFactory
|
||||||
|
from algorithms.neat.genome.debug.tools import check_array_valid
|
||||||
|
from utils import Configer
|
||||||
|
|
||||||
|
from algorithms.neat.genome.crossover import crossover
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = Configer.load_config()
|
||||||
|
function_factory = FunctionFactory(config, debug=True)
|
||||||
|
initialize_func = function_factory.create_initialize()
|
||||||
|
pop_nodes, pop_connections, input_idx, output_idx = initialize_func()
|
||||||
|
mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1])
|
||||||
|
crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1])
|
||||||
|
key = jax.random.PRNGKey(0)
|
||||||
|
new_node_idx = 100
|
||||||
|
while True:
|
||||||
|
key, subkey = jax.random.split(key)
|
||||||
|
mutate_keys = jax.random.split(subkey, len(pop_nodes))
|
||||||
|
new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes))
|
||||||
|
new_node_idx += len(pop_nodes)
|
||||||
|
pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
|
||||||
|
pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
|
||||||
|
# for i in range(len(pop_nodes)):
|
||||||
|
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
||||||
|
idx1 = np.random.permutation(len(pop_nodes))
|
||||||
|
idx2 = np.random.permutation(len(pop_nodes))
|
||||||
|
|
||||||
|
n1, c1 = pop_nodes[idx1], pop_connections[idx1]
|
||||||
|
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
|
||||||
|
crossover_keys = jax.random.split(subkey, len(pop_nodes))
|
||||||
|
|
||||||
|
# for idx, (zn1, zc1, zn2, zc2) in enumerate(zip(n1, c1, n2, c2)):
|
||||||
|
# n, c = crossover(crossover_keys[idx], zn1, zc1, zn2, zc2)
|
||||||
|
# try:
|
||||||
|
# check_array_valid(n, c, input_idx, output_idx)
|
||||||
|
# except AssertionError as e:
|
||||||
|
# crossover(crossover_keys[idx], zn1, zc1, zn2, zc2)
|
||||||
|
|
||||||
|
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
|
||||||
|
|
||||||
|
for i in range(len(pop_nodes)):
|
||||||
|
check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
||||||
|
|
||||||
|
print(new_node_idx)
|
||||||
|
|
||||||
|
|
||||||
@@ -1,11 +1,3 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# 输入
|
print(np.random.permutation(10))
|
||||||
a = np.array([1, 2, 3, 4])
|
|
||||||
b = np.array([5, 6])
|
|
||||||
|
|
||||||
# 创建一个网格,其中包含所有可能的组合
|
|
||||||
aa, bb = np.meshgrid(a, b)
|
|
||||||
aa = aa.flatten()
|
|
||||||
bb = bb.flatten()
|
|
||||||
print(aa, bb)
|
|
||||||
@@ -6,8 +6,8 @@ from time_utils import using_cprofile
|
|||||||
from problems import Sin, Xor, DIY
|
from problems import Sin, Xor, DIY
|
||||||
|
|
||||||
|
|
||||||
# @using_cprofile
|
@using_cprofile
|
||||||
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
|
||||||
def main():
|
def main():
|
||||||
config = Configer.load_config()
|
config = Configer.load_config()
|
||||||
problem = Xor()
|
problem = Xor()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
"num_outputs": 1,
|
"num_outputs": 1,
|
||||||
"problem_batch": 4,
|
"problem_batch": 4,
|
||||||
"init_maximum_nodes": 10,
|
"init_maximum_nodes": 10,
|
||||||
|
"init_maximum_connections": 10,
|
||||||
"expands_coe": 2,
|
"expands_coe": 2,
|
||||||
"pre_compile_times": 3,
|
"pre_compile_times": 3,
|
||||||
"forward_way": "pop_batch"
|
"forward_way": "pop_batch"
|
||||||
@@ -13,7 +14,7 @@
|
|||||||
"fitness_criterion": "max",
|
"fitness_criterion": "max",
|
||||||
"fitness_threshold": -0.001,
|
"fitness_threshold": -0.001,
|
||||||
"generation_limit": 1000,
|
"generation_limit": 1000,
|
||||||
"pop_size": 1000,
|
"pop_size": 5000,
|
||||||
"reset_on_extinction": "False"
|
"reset_on_extinction": "False"
|
||||||
},
|
},
|
||||||
"gene": {
|
"gene": {
|
||||||
@@ -57,9 +58,9 @@
|
|||||||
"compatibility_weight_coefficient": 0.5,
|
"compatibility_weight_coefficient": 0.5,
|
||||||
"single_structural_mutation": "False",
|
"single_structural_mutation": "False",
|
||||||
"conn_add_prob": 0.5,
|
"conn_add_prob": 0.5,
|
||||||
"conn_delete_prob": 0,
|
"conn_delete_prob": 0.5,
|
||||||
"node_add_prob": 0.2,
|
"node_add_prob": 0.2,
|
||||||
"node_delete_prob": 0
|
"node_delete_prob": 0.2
|
||||||
},
|
},
|
||||||
"species": {
|
"species": {
|
||||||
"compatibility_threshold": 2.5,
|
"compatibility_threshold": 2.5,
|
||||||
|
|||||||
Reference in New Issue
Block a user