refactor genome.py: use (C, 4) instead of (2, N, N) to represent connections
faster, faster and faster!
This commit is contained in:
@@ -8,7 +8,7 @@ import numpy as np
|
||||
from jax import jit, vmap
|
||||
|
||||
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
|
||||
from .genome import topological_sort, forward_single
|
||||
from .genome import topological_sort, forward_single, unflatten_connections
|
||||
|
||||
|
||||
class FunctionFactory:
|
||||
@@ -17,19 +17,18 @@ class FunctionFactory:
|
||||
self.debug = debug
|
||||
|
||||
self.init_N = config.basic.init_maximum_nodes
|
||||
self.init_C = config.basic.init_maximum_connections
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.precompile_times = config.basic.pre_compile_times
|
||||
self.compiled_function = {}
|
||||
|
||||
self.load_config_vals(config)
|
||||
self.precompile()
|
||||
pass
|
||||
|
||||
def load_config_vals(self, config):
|
||||
self.problem_batch = config.basic.problem_batch
|
||||
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
self.init_N = config.basic.init_maximum_nodes
|
||||
|
||||
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||
@@ -85,6 +84,7 @@ class FunctionFactory:
|
||||
initialize_genomes,
|
||||
pop_size=self.pop_size,
|
||||
N=self.init_N,
|
||||
C=self.init_C,
|
||||
num_inputs=self.num_inputs,
|
||||
num_outputs=self.num_outputs,
|
||||
default_bias=self.bias_mean,
|
||||
@@ -107,24 +107,24 @@ class FunctionFactory:
|
||||
self.create_crossover_with_args()
|
||||
self.create_topological_sort_with_args()
|
||||
self.create_single_forward_with_args()
|
||||
|
||||
n = self.init_N
|
||||
print("start precompile")
|
||||
for _ in range(self.precompile_times):
|
||||
self.compile_mutate(n)
|
||||
self.compile_distance(n)
|
||||
self.compile_crossover(n)
|
||||
self.compile_topological_sort_batch(n)
|
||||
self.compile_pop_batch_forward(n)
|
||||
n = int(self.expand_coe * n)
|
||||
|
||||
# precompile other functions used in jax
|
||||
key = jax.random.PRNGKey(0)
|
||||
_ = jax.random.split(key, 3)
|
||||
_ = jax.random.split(key, self.pop_size * 2)
|
||||
_ = jax.random.split(key, self.pop_size)
|
||||
|
||||
print("end precompile")
|
||||
#
|
||||
# n, c = self.init_N, self.init_C
|
||||
# print("start precompile")
|
||||
# for _ in range(self.precompile_times):
|
||||
# self.compile_mutate(n)
|
||||
# self.compile_distance(n)
|
||||
# self.compile_crossover(n)
|
||||
# self.compile_topological_sort_batch(n)
|
||||
# self.compile_pop_batch_forward(n)
|
||||
# n = int(self.expand_coe * n)
|
||||
#
|
||||
# # precompile other functions used in jax
|
||||
# key = jax.random.PRNGKey(0)
|
||||
# _ = jax.random.split(key, 3)
|
||||
# _ = jax.random.split(key, self.pop_size * 2)
|
||||
# _ = jax.random.split(key, self.pop_size)
|
||||
#
|
||||
# print("end precompile")
|
||||
|
||||
def create_mutate_with_args(self):
|
||||
func = partial(
|
||||
@@ -161,20 +161,20 @@ class FunctionFactory:
|
||||
)
|
||||
self.mutate_with_args = func
|
||||
|
||||
def compile_mutate(self, n):
|
||||
def compile_mutate(self, n, c):
|
||||
func = self.mutate_with_args
|
||||
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
|
||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
connections_lower = np.zeros((self.pop_size, c, 4))
|
||||
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
|
||||
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
|
||||
connections_lower, new_node_key_lower).compile()
|
||||
self.compiled_function[('mutate', n)] = batched_mutate_func
|
||||
self.compiled_function[('mutate', n, c)] = batched_mutate_func
|
||||
|
||||
def create_mutate(self, n):
|
||||
key = ('mutate', n)
|
||||
def create_mutate(self, n, c):
|
||||
key = ('mutate', n, c)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_mutate(n)
|
||||
self.compile_mutate(n, c)
|
||||
if self.debug:
|
||||
def debug_mutate(*args):
|
||||
res_nodes, res_connections = self.compiled_function[key](*args)
|
||||
@@ -192,28 +192,28 @@ class FunctionFactory:
|
||||
)
|
||||
self.distance_with_args = func
|
||||
|
||||
def compile_distance(self, n):
|
||||
def compile_distance(self, n, c):
|
||||
func = self.distance_with_args
|
||||
o2o_nodes1_lower = np.zeros((n, 5))
|
||||
o2o_connections1_lower = np.zeros((2, n, n))
|
||||
o2o_connections1_lower = np.zeros((c, 4))
|
||||
o2o_nodes2_lower = np.zeros((n, 5))
|
||||
o2o_connections2_lower = np.zeros((2, n, n))
|
||||
o2o_connections2_lower = np.zeros((c, 4))
|
||||
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
|
||||
o2o_nodes2_lower, o2o_connections2_lower).compile()
|
||||
|
||||
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
|
||||
o2m_connections2_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
o2m_connections2_lower = np.zeros((self.pop_size, c, 4))
|
||||
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
|
||||
o2m_nodes2_lower,
|
||||
o2m_connections2_lower).compile()
|
||||
|
||||
self.compiled_function[('o2o_distance', n)] = o2o_distance
|
||||
self.compiled_function[('o2m_distance', n)] = o2m_distance
|
||||
self.compiled_function[('o2o_distance', n, c)] = o2o_distance
|
||||
self.compiled_function[('o2m_distance', n, c)] = o2m_distance
|
||||
|
||||
def create_distance(self, n):
|
||||
key1, key2 = ('o2o_distance', n), ('o2m_distance', n)
|
||||
def create_distance(self, n, c):
|
||||
key1, key2 = ('o2o_distance', n, c), ('o2m_distance', n, c)
|
||||
if key1 not in self.compiled_function:
|
||||
self.compile_distance(n)
|
||||
self.compile_distance(n, c)
|
||||
if self.debug:
|
||||
|
||||
def debug_o2o_distance(*args):
|
||||
@@ -229,21 +229,21 @@ class FunctionFactory:
|
||||
def create_crossover_with_args(self):
|
||||
self.crossover_with_args = crossover
|
||||
|
||||
def compile_crossover(self, n):
|
||||
def compile_crossover(self, n, c):
|
||||
func = self.crossover_with_args
|
||||
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
|
||||
nodes1_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections1_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
connections1_lower = np.zeros((self.pop_size, c, 4))
|
||||
nodes2_lower = np.zeros((self.pop_size, n, 5))
|
||||
connections2_lower = np.zeros((self.pop_size, 2, n, n))
|
||||
connections2_lower = np.zeros((self.pop_size, c, 4))
|
||||
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
|
||||
nodes2_lower, connections2_lower).compile()
|
||||
self.compiled_function[('crossover', n)] = func
|
||||
self.compiled_function[('crossover', n, c)] = func
|
||||
|
||||
def create_crossover(self, n):
|
||||
key = ('crossover', n)
|
||||
def create_crossover(self, n, c):
|
||||
key = ('crossover', n, c)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_crossover(n)
|
||||
self.compile_crossover(n, c)
|
||||
if self.debug:
|
||||
|
||||
def debug_crossover(*args):
|
||||
@@ -365,15 +365,17 @@ class FunctionFactory:
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
|
||||
def ask_pop_batch_forward(self, pop_nodes, pop_connections):
|
||||
n = pop_nodes.shape[1]
|
||||
def ask_pop_batch_forward(self, pop_nodes, pop_cons):
|
||||
n, c = pop_nodes.shape[1], pop_cons.shape[1]
|
||||
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
|
||||
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
|
||||
ts = self.create_topological_sort_batch(n)
|
||||
pop_cal_seqs = ts(pop_nodes, pop_connections)
|
||||
pop_cal_seqs = ts(pop_nodes, pop_cons)
|
||||
|
||||
forward_func = self.create_pop_batch_forward(n)
|
||||
|
||||
def debug_forward(inputs):
|
||||
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
|
||||
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_cons)
|
||||
|
||||
return debug_forward
|
||||
|
||||
@@ -387,3 +389,23 @@ class FunctionFactory:
|
||||
return forward_func(inputs, cal_seqs, nodes, connections)
|
||||
|
||||
return debug_forward
|
||||
|
||||
def compile_batch_unflatten_connections(self, n, c):
|
||||
func = unflatten_connections
|
||||
func = vmap(func)
|
||||
pop_nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||
pop_connections_lower = np.zeros((self.pop_size, c, 4))
|
||||
func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile()
|
||||
self.compiled_function[('batch_unflatten_connections', n, c)] = func
|
||||
|
||||
def create_batch_unflatten_connections(self, n, c):
|
||||
key = ('batch_unflatten_connections', n, c)
|
||||
if key not in self.compiled_function:
|
||||
self.compile_batch_unflatten_connections(n, c)
|
||||
if self.debug:
|
||||
def debug_batch_unflatten_connections(*args):
|
||||
return self.compiled_function[key](*args).block_until_ready()
|
||||
|
||||
return debug_batch_unflatten_connections
|
||||
else:
|
||||
return self.compiled_function[key]
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from .genome import expand, expand_single, pop_analysis, initialize_genomes
|
||||
from .forward import create_forward_function, forward_single
|
||||
from .genome import expand, expand_single, initialize_genomes
|
||||
from .forward import forward_single
|
||||
from .activations import act_name2key
|
||||
from .aggregations import agg_name2key
|
||||
from .crossover import crossover
|
||||
from .mutate import mutate
|
||||
from .distance import distance
|
||||
from .graph import topological_sort
|
||||
from .utils import unflatten_connections
|
||||
@@ -23,8 +23,8 @@ def sin_act(z):
|
||||
|
||||
@jit
|
||||
def gauss_act(z):
|
||||
z = jnp.clip(z, -3.4, 3.4)
|
||||
return jnp.exp(-5 * z ** 2)
|
||||
z = jnp.clip(z * 5, -3.4, 3.4)
|
||||
return jnp.exp(-z ** 2)
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
@@ -7,16 +7,16 @@ from jax import numpy as jnp
|
||||
|
||||
|
||||
@jit
|
||||
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
||||
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
|
||||
-> Tuple[Array, Array]:
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
:param randkey:
|
||||
:param nodes1:
|
||||
:param connections1:
|
||||
:param cons1:
|
||||
:param nodes2:
|
||||
:param connections2:
|
||||
:param cons2:
|
||||
:return:
|
||||
"""
|
||||
randkey_1, randkey_2 = jax.random.split(randkey)
|
||||
@@ -27,15 +27,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2]
|
||||
connections2 = align_array(con_keys1, con_keys2, connections2, 'connection')
|
||||
new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections1), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
new_cons = unflatten_connections(len(keys1), new_cons)
|
||||
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
||||
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
|
||||
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['gene_type'])
|
||||
# @partial(jit, static_argnames=['gene_type'])
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||
@@ -63,7 +62,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
return refactor_ar2
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
"""
|
||||
crossover two genes
|
||||
|
||||
0
algorithms/neat/genome/debug/__init__.py
Normal file
0
algorithms/neat/genome/debug/__init__.py
Normal file
88
algorithms/neat/genome/debug/tools.py
Normal file
88
algorithms/neat/genome/debug/tools.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def check_array_valid(nodes, cons, input_keys, output_keys):
    """Sanity-check genome arrays; array2object raises AssertionError on malformed data."""
    # the conversion itself performs all structural validation
    converted = array2object(nodes, cons, input_keys, output_keys)
    # assert is_DAG(converted[1].keys()), "The genome is not a DAG!"
|
||||
|
||||
|
||||
def array2object(nodes, cons, input_keys, output_keys):
    """
    Convert a genome from array to dict.
    :param nodes: (N, 5)
    :param cons: (C, 4)
    :param output_keys:
    :param input_keys:
    :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
    """
    # build the node dictionary, skipping empty (nan-key) rows
    nodes_dict = {}
    for row in nodes:
        if np.isnan(row[0]):
            continue
        key = int(row[0])
        assert key not in nodes_dict, f"Duplicate node key: {key}!"

        if key in input_keys:
            # input nodes carry no attributes at all
            assert np.all(np.isnan(row[1:])), f"Input node {key} must has None bias, response, act, or agg!"
            nodes_dict[key] = (None, None, None, None)
        else:
            # normal nodes must have every attribute present
            assert np.all(~np.isnan(row[1:])), f"Normal node {key} must has non-None bias, response, act, or agg!"
            nodes_dict[key] = (row[1], row[2], row[3], row[4])

    # every declared input/output key must exist among the nodes
    for i in input_keys:
        assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"

    for o in output_keys:
        assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"

    # build the connection dictionary; a row is either fully nan (empty) or fully set
    cons_dict = {}
    for idx, row in enumerate(cons):
        nan_mask = np.isnan(row)
        if np.all(nan_mask):
            continue
        assert np.all(~nan_mask), f"Connection {idx} must has all None or all non-None!"
        i_key = int(row[0])
        o_key = int(row[1])
        if (i_key, o_key) in cons_dict:
            assert False, f"Duplicate connection: {(i_key, o_key)}!"
        assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
        assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
        cons_dict[(i_key, o_key)] = (row[2], row[3] == 1)

    return nodes_dict, cons_dict
|
||||
|
||||
|
||||
def is_DAG(edges):
    """
    Return True iff the directed graph formed by `edges` contains no cycle.

    :param edges: iterable of (from_node, to_node) pairs
    :return: bool

    Fixes two bugs in the previous version: `all_nodes.union({a, b})` did not
    mutate the set (so the node set stayed empty and no DFS ever ran), and the
    DFS treated ANY revisited node as a cycle, which would have rejected
    diamond-shaped DAGs once the first bug was fixed.
    """
    edges = list(edges)  # may be an iterator / dict view; we traverse it repeatedly
    successors = defaultdict(list)
    all_nodes = set()
    for a, b in edges:
        successors[a].append(b)
        all_nodes.update((a, b))

    # three-color DFS: GRAY marks nodes on the current path; reaching a GRAY
    # node again means a back edge, i.e. a cycle. BLACK nodes are fully done.
    WHITE, GRAY, BLACK = 0, 1, 2
    color = {n: WHITE for n in all_nodes}

    def _has_cycle(n):
        color[n] = GRAY
        for m in successors[n]:
            if color[m] == GRAY:
                return True  # back edge: cycle found
            if color[m] == WHITE and _has_cycle(m):
                return True
        color[n] = BLACK
        return False

    return not any(color[n] == WHITE and _has_cycle(n) for n in all_nodes)
|
||||
@@ -1,11 +1,11 @@
|
||||
from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||
from .utils import EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
@jit
|
||||
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
|
||||
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, disjoint_coe: float = 1.,
|
||||
compatibility_coe: float = 0.5) -> Array:
|
||||
"""
|
||||
Calculate the distance between two genomes.
|
||||
@@ -15,10 +15,6 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar
|
||||
|
||||
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
|
||||
|
||||
# refactor connections
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
cons1 = flatten_connections(keys1, connections1)
|
||||
cons2 = flatten_connections(keys2, connections2)
|
||||
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
|
||||
return nd + cd
|
||||
|
||||
@@ -35,9 +31,8 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
nodes = nodes[sorted_indices]
|
||||
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
nan_mask = jnp.isnan(nodes[:, 0])
|
||||
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1]
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
nd = batch_homologous_node_distance(fr, sr)
|
||||
@@ -50,8 +45,8 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
|
||||
@jit
|
||||
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 2])) # weight is not nan, means the connection exists
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 2]))
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
|
||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||
|
||||
cons = jnp.concatenate((cons1, cons2), axis=0)
|
||||
@@ -62,7 +57,7 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||
|
||||
# both genome has such connection
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 2]) & ~jnp.isnan(sr[:, 2])
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
cd = batch_homologous_connection_distance(fr, sr)
|
||||
|
||||
@@ -1,51 +1,12 @@
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import Array, numpy as jnp
|
||||
from jax import jit, vmap
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .aggregations import agg
|
||||
from .activations import act
|
||||
from .graph import topological_sort, batch_topological_sort
|
||||
from .utils import I_INT
|
||||
|
||||
|
||||
def create_forward_function(nodes: NDArray, connections: NDArray,
|
||||
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
|
||||
"""
|
||||
create forward function for different situations
|
||||
|
||||
:param nodes: shape (N, 5) or (pop_size, N, 5)
|
||||
:param connections: shape (2, N, N) or (pop_size, 2, N, N)
|
||||
:param N:
|
||||
:param input_idx:
|
||||
:param output_idx:
|
||||
:param batch: using batch or not
|
||||
:param debug: debug mode
|
||||
:return:
|
||||
"""
|
||||
|
||||
if nodes.ndim == 2: # single genome
|
||||
cal_seqs = topological_sort(nodes, connections)
|
||||
if not batch:
|
||||
return lambda inputs: forward_single(inputs, N, input_idx, output_idx,
|
||||
cal_seqs, nodes, connections)
|
||||
else:
|
||||
return lambda batch_inputs: forward_batch(batch_inputs, N, input_idx, output_idx,
|
||||
cal_seqs, nodes, connections)
|
||||
elif nodes.ndim == 3: # pop genome
|
||||
pop_cal_seqs = batch_topological_sort(nodes, connections)
|
||||
if not batch:
|
||||
return lambda inputs: pop_forward_single(inputs, N, input_idx, output_idx,
|
||||
pop_cal_seqs, nodes, connections)
|
||||
else:
|
||||
return lambda batch_inputs: pop_forward_batch(batch_inputs, N, input_idx, output_idx,
|
||||
pop_cal_seqs, nodes, connections)
|
||||
else:
|
||||
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
|
||||
|
||||
|
||||
# TODO: enabled information doesn't influence forward. That is wrong!
|
||||
@jit
|
||||
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
|
||||
input_idx: Array, output_idx: Array) -> Array:
|
||||
@@ -84,66 +45,3 @@ def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Ar
|
||||
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
|
||||
# @partial(jit, static_argnames=['N'])
|
||||
# @partial(vmap, in_axes=(0, None, None, None, None, None, None))
|
||||
# def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
# cal_seqs: Array, nodes: Array, connections: Array) -> Array:
|
||||
# """
|
||||
# jax forward for batch_inputs shaped (batch_size, input_num)
|
||||
# nodes, connections are single genome
|
||||
#
|
||||
# :argument batch_inputs: (batch_size, input_num)
|
||||
# :argument N: int
|
||||
# :argument input_idx: (input_num, )
|
||||
# :argument output_idx: (output_num, )
|
||||
# :argument cal_seqs: (N, )
|
||||
# :argument nodes: (N, 5)
|
||||
# :argument connections: (2, N, N)
|
||||
#
|
||||
# :return (batch_size, output_num)
|
||||
# """
|
||||
# return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
|
||||
#
|
||||
#
|
||||
# @partial(jit, static_argnames=['N'])
|
||||
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
|
||||
# def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
# """
|
||||
# jax forward for single input shaped (input_num, )
|
||||
# pop_nodes, pop_connections are population of genomes
|
||||
#
|
||||
# :argument inputs: (input_num, )
|
||||
# :argument N: int
|
||||
# :argument input_idx: (input_num, )
|
||||
# :argument output_idx: (output_num, )
|
||||
# :argument pop_cal_seqs: (pop_size, N)
|
||||
# :argument pop_nodes: (pop_size, N, 5)
|
||||
# :argument pop_connections: (pop_size, 2, N, N)
|
||||
#
|
||||
# :return (pop_size, output_num)
|
||||
# """
|
||||
# return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
||||
#
|
||||
#
|
||||
# @partial(jit, static_argnames=['N'])
|
||||
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
|
||||
# def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
# """
|
||||
# jax forward for batch input shaped (batch, input_num)
|
||||
# pop_nodes, pop_connections are population of genomes
|
||||
#
|
||||
# :argument batch_inputs: (batch_size, input_num)
|
||||
# :argument N: int
|
||||
# :argument input_idx: (input_num, )
|
||||
# :argument output_idx: (output_num, )
|
||||
# :argument pop_cal_seqs: (pop_size, N)
|
||||
# :argument pop_nodes: (pop_size, N, 5)
|
||||
# :argument pop_connections: (pop_size, 2, N, N)
|
||||
#
|
||||
# :return (pop_size, batch_size, output_num)
|
||||
# """
|
||||
# return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
||||
|
||||
@@ -20,7 +20,7 @@ from jax import numpy as jnp
|
||||
from jax import jit
|
||||
from jax import Array
|
||||
|
||||
from .utils import fetch_first, EMPTY_NODE
|
||||
from .utils import fetch_first
|
||||
|
||||
|
||||
def initialize_genomes(pop_size: int,
|
||||
@@ -124,79 +124,6 @@ def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tupl
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
def analysis(nodes: NDArray, cons: NDArray, input_keys, output_keys) -> \
        Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
    """
    Convert a genome from array to dict, validating its structure.

    :param nodes: (N, 5) array; each row is (key, bias, response, act, agg), a nan key marks an empty slot
    :param cons: (C, 4) array; each row is (i_key, o_key, weight, enabled), a nan i_key marks an empty slot
    :param input_keys: keys of input nodes (must have all-None attributes)
    :param output_keys: keys of output nodes (must be present)
    :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
    :raises AssertionError: if the genome violates any structural invariant
    """
    try:
        # build nodes_dict, mapping nan attributes to None
        nodes_dict = {}
        for i, node in enumerate(nodes):
            if np.isnan(node[0]):
                continue  # empty slot
            key = int(node[0])
            assert key not in nodes_dict, f"Duplicate node key: {key}!"

            bias = node[1] if not np.isnan(node[1]) else None
            response = node[2] if not np.isnan(node[2]) else None
            act = node[3] if not np.isnan(node[3]) else None
            agg = node[4] if not np.isnan(node[4]) else None
            nodes_dict[key] = (bias, response, act, agg)

        # input nodes must exist and carry no attributes
        for i in input_keys:
            assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
            bias, response, act, agg = nodes_dict[i]
            assert bias is None and response is None and act is None and agg is None, \
                f"Input node {i} must has None bias, response, act, or agg!"

        # output nodes must exist
        for o in output_keys:
            assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"

        # every non-input node must have all attributes set
        for k, v in nodes_dict.items():
            if k not in input_keys:
                bias, response, act, agg = v
                assert bias is not None and response is not None and act is not None and agg is not None, \
                    f"Normal node {k} must has non-None bias, response, act, or agg!"

        # build cons_dict; both endpoints must reference existing nodes
        cons_dict = {}
        for i, con in enumerate(cons):
            if np.isnan(con[0]):
                continue  # empty slot
            assert not np.isnan(con[1]), f"Connection {i} must has non-None o_key!"
            i_key = int(con[0])
            o_key = int(con[1])
            assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
            assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
            key = (i_key, o_key)
            weight = con[2] if not np.isnan(con[2]) else None
            enabled = (con[3] == 1) if not np.isnan(con[3]) else None
            assert weight is not None, f"Connection {key} must has non-None weight!"
            assert enabled is not None, f"Connection {key} must has non-None enabled!"

            cons_dict[key] = (weight, enabled)

        return nodes_dict, cons_dict
    except AssertionError:
        # dump the offending arrays for debugging, then re-raise the ORIGINAL
        # error: a bare `raise` preserves the assertion message and traceback,
        # which `raise AssertionError` (the old code) discarded.
        print(nodes)
        print(cons)
        raise
||||
|
||||
|
||||
def pop_analysis(pop_nodes, pop_cons, input_keys, output_keys):
    """Analyse every genome in the population; returns one (nodes_dict, cons_dict) per genome."""
    return [
        analysis(nodes, cons, input_keys, output_keys)
        for nodes, cons in zip(pop_nodes, pop_cons)
    ]
|
||||
|
||||
|
||||
@jit
|
||||
def count(nodes, cons):
|
||||
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
|
||||
@@ -231,7 +158,7 @@ def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Arra
|
||||
"""
|
||||
use idx to delete a node from the genome. only delete the node, regardless of connections.
|
||||
"""
|
||||
nodes = nodes.at[idx].set(EMPTY_NODE)
|
||||
nodes = nodes.at[idx].set(np.nan)
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@@ -243,7 +170,7 @@ def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
|
||||
"""
|
||||
con_keys = cons[:, 0]
|
||||
idx = fetch_first(jnp.isnan(con_keys))
|
||||
return add_connection_by_idx(idx, nodes, cons, i_key, o_key, weight, enabled)
|
||||
return add_connection_by_idx(nodes, cons, idx, i_key, o_key, weight, enabled)
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
@@ -6,11 +6,9 @@ import numpy as np
|
||||
from jax import numpy as jnp
|
||||
from jax import jit, vmap, Array
|
||||
|
||||
from .utils import fetch_random, fetch_first, I_INT
|
||||
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx
|
||||
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
|
||||
from .graph import check_cycles
|
||||
from .activations import act_name2key
|
||||
from .aggregations import agg_name2key
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('single_structure_mutate',))
|
||||
@@ -89,7 +87,7 @@ def mutate(rand_key: Array,
|
||||
return n, c
|
||||
|
||||
def m_add_node(rk, n, c):
|
||||
return mutate_add_node(rk, new_node_key, n, c, bias_mean, response_mean, act_default, agg_default)
|
||||
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
|
||||
|
||||
def m_delete_node(rk, n, c):
|
||||
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||
@@ -153,7 +151,7 @@ def mutate(rand_key: Array,
|
||||
@jit
|
||||
def mutate_values(rand_key: Array,
|
||||
nodes: Array,
|
||||
connections: Array,
|
||||
cons: Array,
|
||||
bias_mean: float = 0,
|
||||
bias_std: float = 1,
|
||||
bias_mutate_strength: float = 0.5,
|
||||
@@ -180,7 +178,7 @@ def mutate_values(rand_key: Array,
|
||||
Args:
|
||||
rand_key: A random key for generating random values.
|
||||
nodes: A 2D array representing nodes.
|
||||
connections: A 3D array representing connections.
|
||||
cons: A 3D array representing connections.
|
||||
bias_mean: Mean of the bias values.
|
||||
bias_std: Standard deviation of the bias values.
|
||||
bias_mutate_strength: Strength of the bias mutation.
|
||||
@@ -211,24 +209,23 @@ def mutate_values(rand_key: Array,
|
||||
bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
|
||||
response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
|
||||
response_mutate_strength, response_mutate_rate, response_replace_rate)
|
||||
weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std,
|
||||
weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std,
|
||||
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
|
||||
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
|
||||
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
|
||||
|
||||
# refactor enabled
|
||||
r = jax.random.uniform(rand_key, connections[1, :, :].shape)
|
||||
enabled_new = connections[1, :, :] == 1
|
||||
enabled_new = jnp.where(r < enabled_reverse_rate, ~enabled_new, enabled_new)
|
||||
enabled_new = jnp.where(~jnp.isnan(connections[0, :, :]), enabled_new, jnp.nan)
|
||||
# mutate enabled
|
||||
r = jax.random.uniform(rand_key, cons[:, 3].shape)
|
||||
enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3])
|
||||
enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan)
|
||||
|
||||
nodes = nodes.at[:, 1].set(bias_new)
|
||||
nodes = nodes.at[:, 2].set(response_new)
|
||||
nodes = nodes.at[:, 3].set(act_new)
|
||||
nodes = nodes.at[:, 4].set(agg_new)
|
||||
connections = connections.at[0, :, :].set(weight_new)
|
||||
connections = connections.at[1, :, :].set(enabled_new)
|
||||
return nodes, connections
|
||||
cons = cons.at[:, 2].set(weight_new)
|
||||
cons = cons.at[:, 3].set(enabled_new)
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@jit
|
||||
@@ -288,7 +285,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
|
||||
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
|
||||
default_bias: float = 0, default_response: float = 1,
|
||||
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
|
||||
"""
|
||||
@@ -296,7 +293,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
||||
:param rand_key:
|
||||
:param new_node_key:
|
||||
:param nodes:
|
||||
:param connections:
|
||||
:param cons:
|
||||
:param default_bias:
|
||||
:param default_response:
|
||||
:param default_act:
|
||||
@@ -304,44 +301,42 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
||||
:return:
|
||||
"""
|
||||
# randomly choose a connection
|
||||
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
|
||||
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
|
||||
|
||||
def nothing():
|
||||
return nodes, connections
|
||||
def nothing(): # there is no connection to split
|
||||
return nodes, cons
|
||||
|
||||
def successful_add_node():
|
||||
# disable the connection
|
||||
new_nodes, new_connections = nodes, connections
|
||||
new_connections = new_connections.at[1, from_idx, to_idx].set(False)
|
||||
new_nodes, new_cons = nodes, cons
|
||||
new_cons = new_cons.at[idx, 3].set(False)
|
||||
|
||||
# add a new node
|
||||
new_nodes, new_connections = \
|
||||
add_node(new_node_key, new_nodes, new_connections,
|
||||
new_nodes, new_cons = \
|
||||
add_node(new_nodes, new_cons, new_node_key,
|
||||
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
|
||||
new_idx = fetch_first(new_nodes[:, 0] == new_node_key)
|
||||
|
||||
# add two new connections
|
||||
weight = new_connections[0, from_idx, to_idx]
|
||||
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
|
||||
new_nodes, new_connections, weight=1., enabled=True)
|
||||
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
|
||||
new_nodes, new_connections, weight=weight, enabled=True)
|
||||
return new_nodes, new_connections
|
||||
w = new_cons[idx, 2]
|
||||
new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True)
|
||||
new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True)
|
||||
return new_nodes, new_cons
|
||||
|
||||
# if from_idx == I_INT, that means no connection exist, do nothing
|
||||
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node)
|
||||
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node)
|
||||
|
||||
return nodes, connections
|
||||
return nodes, cons
|
||||
|
||||
|
||||
# TODO: Need we really need to delete a node?
|
||||
@jit
|
||||
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
||||
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array,
|
||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly delete a node. Input and output nodes are not allowed to be deleted.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param connections:
|
||||
:param cons:
|
||||
:param input_keys:
|
||||
:param output_keys:
|
||||
:return:
|
||||
@@ -351,83 +346,86 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
||||
allow_input_keys=False, allow_output_keys=False)
|
||||
|
||||
def nothing():
|
||||
return nodes, connections
|
||||
return nodes, cons
|
||||
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
|
||||
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx)
|
||||
|
||||
# delete connections
|
||||
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
|
||||
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
|
||||
# delete all connections
|
||||
aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis],
|
||||
jnp.nan, aux_cons)
|
||||
|
||||
return aux_nodes, aux_connections
|
||||
return aux_nodes, aux_cons
|
||||
|
||||
nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
|
||||
nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
|
||||
|
||||
return nodes, connections
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
||||
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
|
||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
|
||||
cycles are not allowed.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param connections:
|
||||
:param cons:
|
||||
:param input_keys:
|
||||
:param output_keys:
|
||||
:return:
|
||||
"""
|
||||
# randomly choose two nodes
|
||||
k1, k2 = jax.random.split(rand_key, num=2)
|
||||
from_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
|
||||
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
|
||||
allow_input_keys=True, allow_output_keys=True)
|
||||
to_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
|
||||
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
|
||||
allow_input_keys=False, allow_output_keys=True)
|
||||
|
||||
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
||||
|
||||
def successful():
|
||||
new_nodes, new_connections = add_connection_by_idx(from_idx, to_idx, nodes, connections)
|
||||
return new_nodes, new_connections
|
||||
new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)
|
||||
return new_nodes, new_cons
|
||||
|
||||
def already_exist():
|
||||
new_connections = connections.at[1, from_idx, to_idx].set(True)
|
||||
return nodes, new_connections
|
||||
new_cons = cons.at[con_idx, 3].set(True)
|
||||
return nodes, new_cons
|
||||
|
||||
def cycle():
|
||||
return nodes, connections
|
||||
return nodes, cons
|
||||
|
||||
is_already_exist = ~jnp.isnan(connections[0, from_idx, to_idx])
|
||||
is_cycle = check_cycles(nodes, connections, from_idx, to_idx)
|
||||
is_already_exist = con_idx != I_INT
|
||||
unflattened = unflatten_connections(nodes, cons)
|
||||
is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx)
|
||||
|
||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||
nodes, connections = jax.lax.switch(choice, [already_exist, cycle, successful])
|
||||
return nodes, connections
|
||||
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
||||
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
|
||||
"""
|
||||
Randomly delete a connection.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param connections:
|
||||
:param cons:
|
||||
:return:
|
||||
"""
|
||||
# randomly choose a connection
|
||||
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
|
||||
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
|
||||
|
||||
def nothing():
|
||||
return nodes, connections
|
||||
return nodes, cons
|
||||
|
||||
def successfully_delete_connection():
|
||||
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
|
||||
return delete_connection_by_idx(nodes, cons, idx)
|
||||
|
||||
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection)
|
||||
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
|
||||
|
||||
return nodes, connections
|
||||
return nodes, cons
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
||||
@@ -460,31 +458,20 @@ def choice_node_key(rand_key: Array, nodes: Array,
|
||||
|
||||
|
||||
@jit
|
||||
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
|
||||
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
|
||||
"""
|
||||
Randomly choose a connection key from the given connections.
|
||||
:param rand_key:
|
||||
:param nodes:
|
||||
:param connection:
|
||||
:return: from_key, to_key, from_idx, to_idx
|
||||
:param cons:
|
||||
:return: i_key, o_key, idx
|
||||
"""
|
||||
|
||||
k1, k2 = jax.random.split(rand_key, num=2)
|
||||
idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0]))
|
||||
i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan)
|
||||
o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan)
|
||||
|
||||
has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1)
|
||||
|
||||
def nothing():
|
||||
return jnp.nan, jnp.nan, I_INT, I_INT
|
||||
|
||||
def has_connection():
|
||||
f_idx = fetch_random(k1, has_connections_row)
|
||||
col = connection[0, f_idx, :]
|
||||
t_idx = fetch_random(k2, ~jnp.isnan(col))
|
||||
f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0]
|
||||
return f_key, t_key, f_idx, t_idx
|
||||
|
||||
from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing)
|
||||
return from_key, to_key, from_idx, to_idx
|
||||
return i_key, o_key, idx
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
@@ -3,84 +3,38 @@ from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, Array
|
||||
from jax import jit
|
||||
from jax import jit, vmap
|
||||
|
||||
I_INT = jnp.iinfo(jnp.int32).max # infinite int
|
||||
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
|
||||
EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
||||
|
||||
|
||||
@jit
|
||||
def flatten_connections(keys, connections):
|
||||
def unflatten_connections(nodes, cons):
|
||||
"""
|
||||
flatten the (2, N, N) connections to (N * N, 4)
|
||||
:param keys:
|
||||
:param connections:
|
||||
:return:
|
||||
the first two columns are the index of the node
|
||||
the 3rd column is the weight, and the 4th column is the enabled status
|
||||
"""
|
||||
indices_x, indices_y = jnp.meshgrid(keys, keys, indexing='ij')
|
||||
indices = jnp.stack((indices_x, indices_y), axis=-1).reshape(-1, 2)
|
||||
|
||||
# make (2, N, N) to (N, N, 2)
|
||||
con = jnp.transpose(connections, (1, 2, 0))
|
||||
# make (N, N, 2) to (N * N, 2)
|
||||
con = jnp.reshape(con, (-1, 2))
|
||||
|
||||
con = jnp.concatenate((indices, con), axis=1)
|
||||
return con
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['N'])
|
||||
def unflatten_connections(N, cons):
|
||||
"""
|
||||
restore the (N * N, 4) connections to (2, N, N)
|
||||
:param N:
|
||||
transform the (C, 4) connections to (2, N, N)
|
||||
:param cons:
|
||||
:param nodes:
|
||||
:return:
|
||||
"""
|
||||
cons = cons[:, 2:] # remove the indices
|
||||
unflatten_cons = jnp.moveaxis(cons.reshape(N, N, 2), -1, 0)
|
||||
return unflatten_cons
|
||||
N = nodes.shape[0]
|
||||
node_keys = nodes[:, 0]
|
||||
i_keys, o_keys = cons[:, 0], cons[:, 1]
|
||||
i_idxs = key_to_indices(i_keys, node_keys)
|
||||
o_idxs = key_to_indices(o_keys, node_keys)
|
||||
res = jnp.full((2, N, N), jnp.nan)
|
||||
|
||||
# Is interesting that jax use clip when attach data in array
|
||||
# however, it will do nothing set values in an array
|
||||
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
||||
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
||||
return res
|
||||
|
||||
|
||||
@jit
|
||||
def set_operation_analysis(ar1: Array, ar2: Array) -> Tuple[Array, Array, Array]:
|
||||
"""
|
||||
Analyze the intersection and union of two arrays by returning their sorted concatenation indices,
|
||||
intersection mask, and union mask.
|
||||
|
||||
:param ar1: JAX array of shape (N, M)
|
||||
First input array. Should have the same shape as ar2.
|
||||
:param ar2: JAX array of shape (N, M)
|
||||
Second input array. Should have the same shape as ar1.
|
||||
:return: tuple of 3 arrays
|
||||
- sorted_indices: Indices that would sort the concatenation of ar1 and ar2.
|
||||
- intersect_mask: A boolean array indicating the positions of the common elements between ar1 and ar2
|
||||
in the sorted concatenation.
|
||||
- union_mask: A boolean array indicating the positions of the unique elements in the union of ar1 and ar2
|
||||
in the sorted concatenation.
|
||||
|
||||
Examples:
|
||||
a = jnp.array([[1, 2], [3, 4], [5, 6]])
|
||||
b = jnp.array([[1, 2], [7, 8], [9, 10]])
|
||||
|
||||
sorted_indices, intersect_mask, union_mask = set_operation_analysis(a, b)
|
||||
|
||||
sorted_indices -> array([0, 1, 2, 3, 4, 5])
|
||||
intersect_mask -> array([True, False, False, False, False, False])
|
||||
union_mask -> array([False, True, True, True, True, True])
|
||||
"""
|
||||
ar = jnp.concatenate((ar1, ar2), axis=0)
|
||||
sorted_indices = jnp.lexsort(ar.T[::-1])
|
||||
aux = ar[sorted_indices]
|
||||
aux = jnp.concatenate((aux, jnp.full((1, ar1.shape[1]), jnp.nan)), axis=0)
|
||||
nan_mask = jnp.any(jnp.isnan(aux), axis=1)
|
||||
|
||||
fr, sr = aux[:-1], aux[1:] # first row, second row
|
||||
intersect_mask = jnp.all(fr == sr, axis=1) & ~nan_mask[:-1]
|
||||
union_mask = jnp.any(fr != sr, axis=1) & ~nan_mask[:-1]
|
||||
return sorted_indices, intersect_mask, union_mask
|
||||
@partial(vmap, in_axes=(0, None))
|
||||
def key_to_indices(key, keys):
|
||||
return fetch_first(key == keys)
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
@@ -8,6 +8,7 @@ from .species import SpeciesController
|
||||
from .genome import expand, expand_single
|
||||
from .function_factory import FunctionFactory
|
||||
from .genome.genome import count
|
||||
from .genome.debug.tools import check_array_valid
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
@@ -23,6 +24,7 @@ class Pipeline:
|
||||
|
||||
self.config = config
|
||||
self.N = config.basic.init_maximum_nodes
|
||||
self.C = config.basic.init_maximum_connections
|
||||
self.expand_coe = config.basic.expands_coe
|
||||
self.pop_size = config.neat.population.pop_size
|
||||
|
||||
@@ -57,6 +59,8 @@ class Pipeline:
|
||||
|
||||
self.update_next_generation(winner_part, loser_part, elite_mask)
|
||||
|
||||
# pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
|
||||
|
||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
|
||||
self.o2o_distance, self.o2m_distance)
|
||||
|
||||
@@ -105,16 +109,25 @@ class Pipeline:
|
||||
|
||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
||||
lpc) # new pop nodes, new pop connections
|
||||
|
||||
# for i in range(self.pop_size):
|
||||
# n, c = np.array(npn[i]), np.array(npc[i])
|
||||
# check_array_valid(n, c, self.input_idx, self.output_idx)
|
||||
|
||||
# mutate
|
||||
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
|
||||
|
||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||
|
||||
# for i in range(self.pop_size):
|
||||
# n, c = np.array(m_npn[i]), np.array(m_npc[i])
|
||||
# check_array_valid(n, c, self.input_idx, self.output_idx)
|
||||
|
||||
# elitism don't mutate
|
||||
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
||||
|
||||
self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn)
|
||||
self.pop_connections = np.where(elite_mask[:, None, None, None], npc, m_npc)
|
||||
self.pop_connections = np.where(elite_mask[:, None, None], npc, m_npc)
|
||||
|
||||
def expand(self):
|
||||
"""
|
||||
@@ -128,20 +141,38 @@ class Pipeline:
|
||||
max_node_size = np.max(pop_node_sizes)
|
||||
if max_node_size >= self.N:
|
||||
self.N = int(self.N * self.expand_coe)
|
||||
print(f"expand to {self.N}!")
|
||||
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N)
|
||||
print(f"node expand to {self.N}!")
|
||||
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C)
|
||||
|
||||
# don't forget to expand representation genome in species
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.N)
|
||||
s.representative = expand_single(*s.representative, self.N, self.C)
|
||||
|
||||
# update functions
|
||||
self.compile_functions(debug=True)
|
||||
|
||||
|
||||
pop_con_keys = self.pop_connections[:, :, 0]
|
||||
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||
max_con_size = np.max(pop_node_sizes)
|
||||
if max_con_size >= self.C:
|
||||
self.C = int(self.C * self.expand_coe)
|
||||
print(f"connections expand to {self.C}!")
|
||||
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C)
|
||||
|
||||
# don't forget to expand representation genome in species
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.N, self.C)
|
||||
|
||||
# update functions
|
||||
self.compile_functions(debug=True)
|
||||
|
||||
|
||||
|
||||
def compile_functions(self, debug=False):
|
||||
self.mutate_func = self.function_factory.create_mutate(self.N)
|
||||
self.crossover_func = self.function_factory.create_crossover(self.N)
|
||||
self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N)
|
||||
self.mutate_func = self.function_factory.create_mutate(self.N, self.C)
|
||||
self.crossover_func = self.function_factory.create_crossover(self.N, self.C)
|
||||
self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N, self.C)
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
|
||||
49
examples/function_tests.py
Normal file
49
examples/function_tests.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
from algorithms.neat.function_factory import FunctionFactory
|
||||
from algorithms.neat.genome.debug.tools import check_array_valid
|
||||
from utils import Configer
|
||||
|
||||
from algorithms.neat.genome.crossover import crossover
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = Configer.load_config()
|
||||
function_factory = FunctionFactory(config, debug=True)
|
||||
initialize_func = function_factory.create_initialize()
|
||||
pop_nodes, pop_connections, input_idx, output_idx = initialize_func()
|
||||
mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1])
|
||||
crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1])
|
||||
key = jax.random.PRNGKey(0)
|
||||
new_node_idx = 100
|
||||
while True:
|
||||
key, subkey = jax.random.split(key)
|
||||
mutate_keys = jax.random.split(subkey, len(pop_nodes))
|
||||
new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes))
|
||||
new_node_idx += len(pop_nodes)
|
||||
pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
|
||||
pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
|
||||
# for i in range(len(pop_nodes)):
|
||||
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
||||
idx1 = np.random.permutation(len(pop_nodes))
|
||||
idx2 = np.random.permutation(len(pop_nodes))
|
||||
|
||||
n1, c1 = pop_nodes[idx1], pop_connections[idx1]
|
||||
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
|
||||
crossover_keys = jax.random.split(subkey, len(pop_nodes))
|
||||
|
||||
# for idx, (zn1, zc1, zn2, zc2) in enumerate(zip(n1, c1, n2, c2)):
|
||||
# n, c = crossover(crossover_keys[idx], zn1, zc1, zn2, zc2)
|
||||
# try:
|
||||
# check_array_valid(n, c, input_idx, output_idx)
|
||||
# except AssertionError as e:
|
||||
# crossover(crossover_keys[idx], zn1, zc1, zn2, zc2)
|
||||
|
||||
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
|
||||
|
||||
for i in range(len(pop_nodes)):
|
||||
check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
||||
|
||||
print(new_node_idx)
|
||||
|
||||
|
||||
@@ -1,11 +1,3 @@
|
||||
import numpy as np
|
||||
|
||||
# 输入
|
||||
a = np.array([1, 2, 3, 4])
|
||||
b = np.array([5, 6])
|
||||
|
||||
# 创建一个网格,其中包含所有可能的组合
|
||||
aa, bb = np.meshgrid(a, b)
|
||||
aa = aa.flatten()
|
||||
bb = bb.flatten()
|
||||
print(aa, bb)
|
||||
print(np.random.permutation(10))
|
||||
@@ -6,8 +6,8 @@ from time_utils import using_cprofile
|
||||
from problems import Sin, Xor, DIY
|
||||
|
||||
|
||||
# @using_cprofile
|
||||
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
@using_cprofile
|
||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
def main():
|
||||
config = Configer.load_config()
|
||||
problem = Xor()
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"num_outputs": 1,
|
||||
"problem_batch": 4,
|
||||
"init_maximum_nodes": 10,
|
||||
"init_maximum_connections": 10,
|
||||
"expands_coe": 2,
|
||||
"pre_compile_times": 3,
|
||||
"forward_way": "pop_batch"
|
||||
@@ -13,7 +14,7 @@
|
||||
"fitness_criterion": "max",
|
||||
"fitness_threshold": -0.001,
|
||||
"generation_limit": 1000,
|
||||
"pop_size": 1000,
|
||||
"pop_size": 5000,
|
||||
"reset_on_extinction": "False"
|
||||
},
|
||||
"gene": {
|
||||
@@ -57,9 +58,9 @@
|
||||
"compatibility_weight_coefficient": 0.5,
|
||||
"single_structural_mutation": "False",
|
||||
"conn_add_prob": 0.5,
|
||||
"conn_delete_prob": 0,
|
||||
"conn_delete_prob": 0.5,
|
||||
"node_add_prob": 0.2,
|
||||
"node_delete_prob": 0
|
||||
"node_delete_prob": 0.2
|
||||
},
|
||||
"species": {
|
||||
"compatibility_threshold": 2.5,
|
||||
|
||||
Reference in New Issue
Block a user