Refactor genome.py: use a (C, 4) array instead of a (2, N, N) array to represent connections

faster, faster and faster!
This commit is contained in:
wls2002
2023-05-12 00:57:55 +08:00
parent e5fc1167d9
commit 47b1a1dbb2
16 changed files with 363 additions and 419 deletions

View File

@@ -8,7 +8,7 @@ import numpy as np
from jax import jit, vmap from jax import jit, vmap
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
from .genome import topological_sort, forward_single from .genome import topological_sort, forward_single, unflatten_connections
class FunctionFactory: class FunctionFactory:
@@ -17,19 +17,18 @@ class FunctionFactory:
self.debug = debug self.debug = debug
self.init_N = config.basic.init_maximum_nodes self.init_N = config.basic.init_maximum_nodes
self.init_C = config.basic.init_maximum_connections
self.expand_coe = config.basic.expands_coe self.expand_coe = config.basic.expands_coe
self.precompile_times = config.basic.pre_compile_times self.precompile_times = config.basic.pre_compile_times
self.compiled_function = {} self.compiled_function = {}
self.load_config_vals(config) self.load_config_vals(config)
self.precompile() self.precompile()
pass
def load_config_vals(self, config): def load_config_vals(self, config):
self.problem_batch = config.basic.problem_batch self.problem_batch = config.basic.problem_batch
self.pop_size = config.neat.population.pop_size self.pop_size = config.neat.population.pop_size
self.init_N = config.basic.init_maximum_nodes
self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient self.disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient self.compatibility_coe = config.neat.genome.compatibility_weight_coefficient
@@ -85,6 +84,7 @@ class FunctionFactory:
initialize_genomes, initialize_genomes,
pop_size=self.pop_size, pop_size=self.pop_size,
N=self.init_N, N=self.init_N,
C=self.init_C,
num_inputs=self.num_inputs, num_inputs=self.num_inputs,
num_outputs=self.num_outputs, num_outputs=self.num_outputs,
default_bias=self.bias_mean, default_bias=self.bias_mean,
@@ -107,24 +107,24 @@ class FunctionFactory:
self.create_crossover_with_args() self.create_crossover_with_args()
self.create_topological_sort_with_args() self.create_topological_sort_with_args()
self.create_single_forward_with_args() self.create_single_forward_with_args()
#
n = self.init_N # n, c = self.init_N, self.init_C
print("start precompile") # print("start precompile")
for _ in range(self.precompile_times): # for _ in range(self.precompile_times):
self.compile_mutate(n) # self.compile_mutate(n)
self.compile_distance(n) # self.compile_distance(n)
self.compile_crossover(n) # self.compile_crossover(n)
self.compile_topological_sort_batch(n) # self.compile_topological_sort_batch(n)
self.compile_pop_batch_forward(n) # self.compile_pop_batch_forward(n)
n = int(self.expand_coe * n) # n = int(self.expand_coe * n)
#
# precompile other functions used in jax # # precompile other functions used in jax
key = jax.random.PRNGKey(0) # key = jax.random.PRNGKey(0)
_ = jax.random.split(key, 3) # _ = jax.random.split(key, 3)
_ = jax.random.split(key, self.pop_size * 2) # _ = jax.random.split(key, self.pop_size * 2)
_ = jax.random.split(key, self.pop_size) # _ = jax.random.split(key, self.pop_size)
#
print("end precompile") # print("end precompile")
def create_mutate_with_args(self): def create_mutate_with_args(self):
func = partial( func = partial(
@@ -161,20 +161,20 @@ class FunctionFactory:
) )
self.mutate_with_args = func self.mutate_with_args = func
def compile_mutate(self, n): def compile_mutate(self, n, c):
func = self.mutate_with_args func = self.mutate_with_args
rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32) rand_key_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes_lower = np.zeros((self.pop_size, n, 5)) nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n)) connections_lower = np.zeros((self.pop_size, c, 4))
new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32) new_node_key_lower = np.zeros((self.pop_size,), dtype=np.int32)
batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower, batched_mutate_func = jit(vmap(func)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile() connections_lower, new_node_key_lower).compile()
self.compiled_function[('mutate', n)] = batched_mutate_func self.compiled_function[('mutate', n, c)] = batched_mutate_func
def create_mutate(self, n): def create_mutate(self, n, c):
key = ('mutate', n) key = ('mutate', n, c)
if key not in self.compiled_function: if key not in self.compiled_function:
self.compile_mutate(n) self.compile_mutate(n, c)
if self.debug: if self.debug:
def debug_mutate(*args): def debug_mutate(*args):
res_nodes, res_connections = self.compiled_function[key](*args) res_nodes, res_connections = self.compiled_function[key](*args)
@@ -192,28 +192,28 @@ class FunctionFactory:
) )
self.distance_with_args = func self.distance_with_args = func
def compile_distance(self, n): def compile_distance(self, n, c):
func = self.distance_with_args func = self.distance_with_args
o2o_nodes1_lower = np.zeros((n, 5)) o2o_nodes1_lower = np.zeros((n, 5))
o2o_connections1_lower = np.zeros((2, n, n)) o2o_connections1_lower = np.zeros((c, 4))
o2o_nodes2_lower = np.zeros((n, 5)) o2o_nodes2_lower = np.zeros((n, 5))
o2o_connections2_lower = np.zeros((2, n, n)) o2o_connections2_lower = np.zeros((c, 4))
o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower, o2o_distance = jit(func).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2o_nodes2_lower, o2o_connections2_lower).compile() o2o_nodes2_lower, o2o_connections2_lower).compile()
o2m_nodes2_lower = np.zeros((self.pop_size, n, 5)) o2m_nodes2_lower = np.zeros((self.pop_size, n, 5))
o2m_connections2_lower = np.zeros((self.pop_size, 2, n, n)) o2m_connections2_lower = np.zeros((self.pop_size, c, 4))
o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower, o2m_distance = jit(vmap(func, in_axes=(None, None, 0, 0))).lower(o2o_nodes1_lower, o2o_connections1_lower,
o2m_nodes2_lower, o2m_nodes2_lower,
o2m_connections2_lower).compile() o2m_connections2_lower).compile()
self.compiled_function[('o2o_distance', n)] = o2o_distance self.compiled_function[('o2o_distance', n, c)] = o2o_distance
self.compiled_function[('o2m_distance', n)] = o2m_distance self.compiled_function[('o2m_distance', n, c)] = o2m_distance
def create_distance(self, n): def create_distance(self, n, c):
key1, key2 = ('o2o_distance', n), ('o2m_distance', n) key1, key2 = ('o2o_distance', n, c), ('o2m_distance', n, c)
if key1 not in self.compiled_function: if key1 not in self.compiled_function:
self.compile_distance(n) self.compile_distance(n, c)
if self.debug: if self.debug:
def debug_o2o_distance(*args): def debug_o2o_distance(*args):
@@ -229,21 +229,21 @@ class FunctionFactory:
def create_crossover_with_args(self): def create_crossover_with_args(self):
self.crossover_with_args = crossover self.crossover_with_args = crossover
def compile_crossover(self, n): def compile_crossover(self, n, c):
func = self.crossover_with_args func = self.crossover_with_args
randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32) randkey_lower = np.zeros((self.pop_size, 2), dtype=np.uint32)
nodes1_lower = np.zeros((self.pop_size, n, 5)) nodes1_lower = np.zeros((self.pop_size, n, 5))
connections1_lower = np.zeros((self.pop_size, 2, n, n)) connections1_lower = np.zeros((self.pop_size, c, 4))
nodes2_lower = np.zeros((self.pop_size, n, 5)) nodes2_lower = np.zeros((self.pop_size, n, 5))
connections2_lower = np.zeros((self.pop_size, 2, n, n)) connections2_lower = np.zeros((self.pop_size, c, 4))
func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower, func = jit(vmap(func)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile() nodes2_lower, connections2_lower).compile()
self.compiled_function[('crossover', n)] = func self.compiled_function[('crossover', n, c)] = func
def create_crossover(self, n): def create_crossover(self, n, c):
key = ('crossover', n) key = ('crossover', n, c)
if key not in self.compiled_function: if key not in self.compiled_function:
self.compile_crossover(n) self.compile_crossover(n, c)
if self.debug: if self.debug:
def debug_crossover(*args): def debug_crossover(*args):
@@ -365,15 +365,17 @@ class FunctionFactory:
else: else:
return self.compiled_function[key] return self.compiled_function[key]
def ask_pop_batch_forward(self, pop_nodes, pop_connections): def ask_pop_batch_forward(self, pop_nodes, pop_cons):
n = pop_nodes.shape[1] n, c = pop_nodes.shape[1], pop_cons.shape[1]
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
ts = self.create_topological_sort_batch(n) ts = self.create_topological_sort_batch(n)
pop_cal_seqs = ts(pop_nodes, pop_connections) pop_cal_seqs = ts(pop_nodes, pop_cons)
forward_func = self.create_pop_batch_forward(n) forward_func = self.create_pop_batch_forward(n)
def debug_forward(inputs): def debug_forward(inputs):
return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections) return forward_func(inputs, pop_cal_seqs, pop_nodes, pop_cons)
return debug_forward return debug_forward
@@ -387,3 +389,23 @@ class FunctionFactory:
return forward_func(inputs, cal_seqs, nodes, connections) return forward_func(inputs, cal_seqs, nodes, connections)
return debug_forward return debug_forward
def compile_batch_unflatten_connections(self, n, c):
    """
    Ahead-of-time compile the population-batched ``unflatten_connections``.

    The compiled callable is stored in ``self.compiled_function`` under the
    key ``('batch_unflatten_connections', n, c)``.

    :param n: node capacity used for the example nodes array (pop_size, n, 5)
    :param c: connection capacity used for the example connections array (pop_size, c, 4)
    :return: None
    """
    # Zero-filled example arguments: they are only handed to lower() so the
    # function can be compiled for these shapes.
    example_nodes = np.zeros((self.pop_size, n, 5))
    example_cons = np.zeros((self.pop_size, c, 4))

    # vmap over the leading (population) axis, then lower and compile.
    batched = jit(vmap(unflatten_connections))
    compiled = batched.lower(example_nodes, example_cons).compile()

    self.compiled_function[('batch_unflatten_connections', n, c)] = compiled
def create_batch_unflatten_connections(self, n, c):
    """
    Return the compiled batched ``unflatten_connections`` for sizes (n, c).

    The function is compiled on first request and cached in
    ``self.compiled_function``. When ``self.debug`` is set, the returned
    callable is a wrapper that calls ``block_until_ready()`` on the result
    before returning it.

    :param n: node capacity of the genomes
    :param c: connection capacity of the genomes
    :return: a callable taking the same arguments as the compiled function
    """
    cache_key = ('batch_unflatten_connections', n, c)
    if cache_key not in self.compiled_function:
        self.compile_batch_unflatten_connections(n, c)

    if not self.debug:
        return self.compiled_function[cache_key]

    def debug_batch_unflatten_connections(*args):
        # Look the function up at call time, then wait for the result.
        return self.compiled_function[cache_key](*args).block_until_ready()

    return debug_batch_unflatten_connections

View File

@@ -1,8 +1,9 @@
from .genome import expand, expand_single, pop_analysis, initialize_genomes from .genome import expand, expand_single, initialize_genomes
from .forward import create_forward_function, forward_single from .forward import forward_single
from .activations import act_name2key from .activations import act_name2key
from .aggregations import agg_name2key from .aggregations import agg_name2key
from .crossover import crossover from .crossover import crossover
from .mutate import mutate from .mutate import mutate
from .distance import distance from .distance import distance
from .graph import topological_sort from .graph import topological_sort
from .utils import unflatten_connections

View File

@@ -23,8 +23,8 @@ def sin_act(z):
@jit @jit
def gauss_act(z): def gauss_act(z):
z = jnp.clip(z, -3.4, 3.4) z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-5 * z ** 2) return jnp.exp(-z ** 2)
@jit @jit

View File

@@ -7,16 +7,16 @@ from jax import numpy as jnp
@jit @jit
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
-> Tuple[Array, Array]: -> Tuple[Array, Array]:
""" """
use genome1 and genome2 to generate a new genome use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!) notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
:param randkey: :param randkey:
:param nodes1: :param nodes1:
:param connections1: :param cons1:
:param nodes2: :param nodes2:
:param connections2: :param cons2:
:return: :return:
""" """
randkey_1, randkey_2 = jax.random.split(randkey) randkey_1, randkey_2 = jax.random.split(randkey)
@@ -27,15 +27,14 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2)) new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections # crossover connections
con_keys1, con_keys2 = connections1[:, :2], connections2[:, :2] con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
connections2 = align_array(con_keys1, con_keys2, connections2, 'connection') cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
new_cons = jnp.where(jnp.isnan(connections1) | jnp.isnan(connections1), cons1, crossover_gene(randkey_2, cons1, cons2)) new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
new_cons = unflatten_connections(len(keys1), new_cons)
return new_nodes, new_cons return new_nodes, new_cons
@partial(jit, static_argnames=['gene_type']) # @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
""" """
After I review this code, I found that it is the most difficult part of the code. Please never change it! After I review this code, I found that it is the most difficult part of the code. Please never change it!
@@ -63,7 +62,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
return refactor_ar2 return refactor_ar2
@jit # @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
""" """
crossover two genes crossover two genes

View File

View File

@@ -0,0 +1,88 @@
from collections import defaultdict
import numpy as np
def check_array_valid(nodes, cons, input_keys, output_keys):
    """
    Validate a genome given in array form.

    All checks are the assertions performed by ``array2object`` while it
    converts the arrays; the converted dictionaries are discarded.

    :param nodes: nodes array of the genome
    :param cons: connections array of the genome
    :param input_keys: keys of the input nodes
    :param output_keys: keys of the output nodes
    :return: None
    """
    array2object(nodes, cons, input_keys, output_keys)
    # The DAG check below is currently disabled.
    # assert is_DAG(cons_dict.keys()), "The genome is not a DAG!"
def array2object(nodes, cons, input_keys, output_keys):
    """
    Convert a genome from array form to dictionaries, validating it on the way.

    :param nodes: (N, 5) rows of (key, bias, response, act, agg); a NaN key marks an empty row
    :param cons: (C, 4) rows of (i_key, o_key, weight, enabled); an all-NaN row is empty
    :param input_keys: keys of the input nodes
    :param output_keys: keys of the output nodes
    :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
    """
    # ---- nodes -----------------------------------------------------------
    nodes_dict = {}
    for row in nodes:
        if np.isnan(row[0]):
            continue  # empty slot

        key = int(row[0])
        assert key not in nodes_dict, f"Duplicate node key: {key}!"

        attrs = row[1:]
        if key in input_keys:
            # input nodes carry no attributes at all
            assert np.all(np.isnan(attrs)), f"Input node {key} must has None bias, response, act, or agg!"
            nodes_dict[key] = (None, None, None, None)
        else:
            assert np.all(~np.isnan(attrs)), f"Normal node {key} must has non-None bias, response, act, or agg!"
            nodes_dict[key] = (row[1], row[2], row[3], row[4])

    # every declared input / output key must be present
    for i in input_keys:
        assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
    for o in output_keys:
        assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"

    # ---- connections -----------------------------------------------------
    cons_dict = {}
    for i, con in enumerate(cons):
        missing = np.isnan(con)
        if np.all(missing):
            continue  # empty slot

        # a row must be either entirely empty or entirely filled
        assert np.all(~missing), f"Connection {i} must has all None or all non-None!"

        i_key, o_key = int(con[0]), int(con[1])
        assert (i_key, o_key) not in cons_dict, f"Duplicate connection: {(i_key, o_key)}!"
        assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
        assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"

        cons_dict[(i_key, o_key)] = (con[2], con[3] == 1)

    return nodes_dict, cons_dict
def is_DAG(edges):
    """
    Check whether a set of directed edges forms a directed acyclic graph.

    Uses Kahn's algorithm: repeatedly remove nodes with no remaining incoming
    edges; the graph is acyclic exactly when every node can be removed.

    :param edges: iterable of (from_key, to_key) pairs, e.g. ``cons_dict.keys()``
    :return: True if the edges contain no cycle (self-loops count as cycles),
             False otherwise. An empty edge set is a DAG.
    """
    successors = {}  # node -> list of nodes it points to
    in_degree = {}   # node -> number of incoming edges not yet removed

    for a, b in edges:
        if a == b:  # self-loop is a cycle
            return False
        successors.setdefault(a, []).append(b)
        in_degree.setdefault(a, 0)
        in_degree[b] = in_degree.get(b, 0) + 1

    # start from every node that nothing points to
    ready = [node for node, degree in in_degree.items() if degree == 0]
    removed = 0
    while ready:
        node = ready.pop()
        removed += 1
        for nxt in successors.get(node, ()):
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                ready.append(nxt)

    # nodes that could never be removed lie on (or behind) a cycle
    return removed == len(in_degree)

View File

@@ -1,11 +1,11 @@
from jax import jit, vmap, Array from jax import jit, vmap, Array
from jax import numpy as jnp from jax import numpy as jnp
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON from .utils import EMPTY_NODE, EMPTY_CON
@jit @jit
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1., def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, disjoint_coe: float = 1.,
compatibility_coe: float = 0.5) -> Array: compatibility_coe: float = 0.5) -> Array:
""" """
Calculate the distance between two genomes. Calculate the distance between two genomes.
@@ -15,10 +15,6 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
# refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
return nd + cd return nd + cd
@@ -35,9 +31,8 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
nodes = nodes[sorted_indices] nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row fr, sr = nodes[:-1], nodes[1:] # first row, second row
nan_mask = jnp.isnan(nodes[:, 0])
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1] intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
nd = batch_homologous_node_distance(fr, sr) nd = batch_homologous_node_distance(fr, sr)
@@ -50,8 +45,8 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
@jit @jit
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5): def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 2])) # weight is not nan, means the connection exists con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 2])) con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2) max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((cons1, cons2), axis=0) cons = jnp.concatenate((cons1, cons2), axis=0)
@@ -62,7 +57,7 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
fr, sr = cons[:-1], cons[1:] # first row, second row fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection # both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 2]) & ~jnp.isnan(sr[:, 2]) intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
cd = batch_homologous_connection_distance(fr, sr) cd = batch_homologous_connection_distance(fr, sr)

View File

@@ -1,51 +1,12 @@
from functools import partial
import jax import jax
from jax import Array, numpy as jnp from jax import Array, numpy as jnp
from jax import jit, vmap from jax import jit, vmap
from numpy.typing import NDArray
from .aggregations import agg from .aggregations import agg
from .activations import act from .activations import act
from .graph import topological_sort, batch_topological_sort
from .utils import I_INT from .utils import I_INT
# TODO: enabled information doesn't influence forward. That is wrong!
def create_forward_function(nodes: NDArray, connections: NDArray,
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
"""
create forward function for different situations
:param nodes: shape (N, 5) or (pop_size, N, 5)
:param connections: shape (2, N, N) or (pop_size, 2, N, N)
:param N:
:param input_idx:
:param output_idx:
:param batch: using batch or not
:param debug: debug mode
:return:
"""
if nodes.ndim == 2: # single genome
cal_seqs = topological_sort(nodes, connections)
if not batch:
return lambda inputs: forward_single(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
else:
return lambda batch_inputs: forward_batch(batch_inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
elif nodes.ndim == 3: # pop genome
pop_cal_seqs = batch_topological_sort(nodes, connections)
if not batch:
return lambda inputs: pop_forward_single(inputs, N, input_idx, output_idx,
pop_cal_seqs, nodes, connections)
else:
return lambda batch_inputs: pop_forward_batch(batch_inputs, N, input_idx, output_idx,
pop_cal_seqs, nodes, connections)
else:
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
@jit @jit
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array, def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
input_idx: Array, output_idx: Array) -> Array: input_idx: Array, output_idx: Array) -> Array:
@@ -84,66 +45,3 @@ def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Ar
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs) vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
return vals[output_idx] return vals[output_idx]
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(0, None, None, None, None, None, None))
# def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
# cal_seqs: Array, nodes: Array, connections: Array) -> Array:
# """
# jax forward for batch_inputs shaped (batch_size, input_num)
# nodes, connections are single genome
#
# :argument batch_inputs: (batch_size, input_num)
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument cal_seqs: (N, )
# :argument nodes: (N, 5)
# :argument connections: (2, N, N)
#
# :return (batch_size, output_num)
# """
# return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
#
#
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
# def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
# """
# jax forward for single input shaped (input_num, )
# pop_nodes, pop_connections are population of genomes
#
# :argument inputs: (input_num, )
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument pop_cal_seqs: (pop_size, N)
# :argument pop_nodes: (pop_size, N, 5)
# :argument pop_connections: (pop_size, 2, N, N)
#
# :return (pop_size, output_num)
# """
# return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
#
#
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
# def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
# """
# jax forward for batch input shaped (batch, input_num)
# pop_nodes, pop_connections are population of genomes
#
# :argument batch_inputs: (batch_size, input_num)
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument pop_cal_seqs: (pop_size, N)
# :argument pop_nodes: (pop_size, N, 5)
# :argument pop_connections: (pop_size, 2, N, N)
#
# :return (pop_size, batch_size, output_num)
# """
# return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)

View File

@@ -20,7 +20,7 @@ from jax import numpy as jnp
from jax import jit from jax import jit
from jax import Array from jax import Array
from .utils import fetch_first, EMPTY_NODE from .utils import fetch_first
def initialize_genomes(pop_size: int, def initialize_genomes(pop_size: int,
@@ -124,79 +124,6 @@ def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tupl
return new_nodes, new_cons return new_nodes, new_cons
def analysis(nodes: NDArray, cons: NDArray, input_keys, output_keys) -> \
Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
"""
Convert a genome from array to dict.
:param nodes: (N, 5)
:param cons: (C, 4)
:param output_keys:
:param input_keys:
:return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
"""
# update nodes_dict
try:
nodes_dict = {}
for i, node in enumerate(nodes):
if np.isnan(node[0]):
continue
key = int(node[0])
assert key not in nodes_dict, f"Duplicate node key: {key}!"
bias = node[1] if not np.isnan(node[1]) else None
response = node[2] if not np.isnan(node[2]) else None
act = node[3] if not np.isnan(node[3]) else None
agg = node[4] if not np.isnan(node[4]) else None
nodes_dict[key] = (bias, response, act, agg)
# check nodes_dict
for i in input_keys:
assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
bias, response, act, agg = nodes_dict[i]
assert bias is None and response is None and act is None and agg is None, \
f"Input node {i} must has None bias, response, act, or agg!"
for o in output_keys:
assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
for k, v in nodes_dict.items():
if k not in input_keys:
bias, response, act, agg = v
assert bias is not None and response is not None and act is not None and agg is not None, \
f"Normal node {k} must has non-None bias, response, act, or agg!"
# update connections
cons_dict = {}
for i, con in enumerate(cons):
if np.isnan(con[0]):
continue
assert ~np.isnan(con[1]), f"Connection {i} must has non-None o_key!"
i_key = int(con[0])
o_key = int(con[1])
assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
key = (i_key, o_key)
weight = con[2] if not np.isnan(con[2]) else None
enabled = (con[3] == 1) if not np.isnan(con[3]) else None
assert weight is not None, f"Connection {key} must has non-None weight!"
assert enabled is not None, f"Connection {key} must has non-None enabled!"
cons_dict[key] = (weight, enabled)
return nodes_dict, cons_dict
except AssertionError:
print(nodes)
print(cons)
raise AssertionError
def pop_analysis(pop_nodes, pop_cons, input_keys, output_keys):
res = []
for nodes, cons in zip(pop_nodes, pop_cons):
res.append(analysis(nodes, cons, input_keys, output_keys))
return res
@jit @jit
def count(nodes, cons): def count(nodes, cons):
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0])) node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
@@ -231,7 +158,7 @@ def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Arra
""" """
use idx to delete a node from the genome. only delete the node, regardless of connections. use idx to delete a node from the genome. only delete the node, regardless of connections.
""" """
nodes = nodes.at[idx].set(EMPTY_NODE) nodes = nodes.at[idx].set(np.nan)
return nodes, cons return nodes, cons
@@ -243,7 +170,7 @@ def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
""" """
con_keys = cons[:, 0] con_keys = cons[:, 0]
idx = fetch_first(jnp.isnan(con_keys)) idx = fetch_first(jnp.isnan(con_keys))
return add_connection_by_idx(idx, nodes, cons, i_key, o_key, weight, enabled) return add_connection_by_idx(nodes, cons, idx, i_key, o_key, weight, enabled)
@jit @jit

View File

@@ -6,11 +6,9 @@ import numpy as np
from jax import numpy as jnp from jax import numpy as jnp
from jax import jit, vmap, Array from jax import jit, vmap, Array
from .utils import fetch_random, fetch_first, I_INT from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
from .graph import check_cycles from .graph import check_cycles
from .activations import act_name2key
from .aggregations import agg_name2key
@partial(jit, static_argnames=('single_structure_mutate',)) @partial(jit, static_argnames=('single_structure_mutate',))
@@ -89,7 +87,7 @@ def mutate(rand_key: Array,
return n, c return n, c
def m_add_node(rk, n, c): def m_add_node(rk, n, c):
return mutate_add_node(rk, new_node_key, n, c, bias_mean, response_mean, act_default, agg_default) return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
def m_delete_node(rk, n, c): def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_idx, output_idx) return mutate_delete_node(rk, n, c, input_idx, output_idx)
@@ -153,7 +151,7 @@ def mutate(rand_key: Array,
@jit @jit
def mutate_values(rand_key: Array, def mutate_values(rand_key: Array,
nodes: Array, nodes: Array,
connections: Array, cons: Array,
bias_mean: float = 0, bias_mean: float = 0,
bias_std: float = 1, bias_std: float = 1,
bias_mutate_strength: float = 0.5, bias_mutate_strength: float = 0.5,
@@ -180,7 +178,7 @@ def mutate_values(rand_key: Array,
Args: Args:
rand_key: A random key for generating random values. rand_key: A random key for generating random values.
nodes: A 2D array representing nodes. nodes: A 2D array representing nodes.
connections: A 3D array representing connections. cons: A 3D array representing connections.
bias_mean: Mean of the bias values. bias_mean: Mean of the bias values.
bias_std: Standard deviation of the bias values. bias_std: Standard deviation of the bias values.
bias_mutate_strength: Strength of the bias mutation. bias_mutate_strength: Strength of the bias mutation.
@@ -211,24 +209,23 @@ def mutate_values(rand_key: Array,
bias_mutate_strength, bias_mutate_rate, bias_replace_rate) bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std, response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate) response_mutate_strength, response_mutate_rate, response_replace_rate)
weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std, weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std,
weight_mutate_strength, weight_mutate_rate, weight_replace_rate) weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate) act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate) agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
# refactor enabled # mutate enabled
r = jax.random.uniform(rand_key, connections[1, :, :].shape) r = jax.random.uniform(rand_key, cons[:, 3].shape)
enabled_new = connections[1, :, :] == 1 enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3])
enabled_new = jnp.where(r < enabled_reverse_rate, ~enabled_new, enabled_new) enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan)
enabled_new = jnp.where(~jnp.isnan(connections[0, :, :]), enabled_new, jnp.nan)
nodes = nodes.at[:, 1].set(bias_new) nodes = nodes.at[:, 1].set(bias_new)
nodes = nodes.at[:, 2].set(response_new) nodes = nodes.at[:, 2].set(response_new)
nodes = nodes.at[:, 3].set(act_new) nodes = nodes.at[:, 3].set(act_new)
nodes = nodes.at[:, 4].set(agg_new) nodes = nodes.at[:, 4].set(agg_new)
connections = connections.at[0, :, :].set(weight_new) cons = cons.at[:, 2].set(weight_new)
connections = connections.at[1, :, :].set(enabled_new) cons = cons.at[:, 3].set(enabled_new)
return nodes, connections return nodes, cons
@jit @jit
@@ -288,7 +285,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
@jit @jit
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array, def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
default_bias: float = 0, default_response: float = 1, default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]: default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
""" """
@@ -296,7 +293,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
:param rand_key: :param rand_key:
:param new_node_key: :param new_node_key:
:param nodes: :param nodes:
:param connections: :param cons:
:param default_bias: :param default_bias:
:param default_response: :param default_response:
:param default_act: :param default_act:
@@ -304,44 +301,42 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
:return: :return:
""" """
# randomly choose a connection # randomly choose a connection
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections) i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
def nothing(): def nothing(): # there is no connection to split
return nodes, connections return nodes, cons
def successful_add_node(): def successful_add_node():
# disable the connection # disable the connection
new_nodes, new_connections = nodes, connections new_nodes, new_cons = nodes, cons
new_connections = new_connections.at[1, from_idx, to_idx].set(False) new_cons = new_cons.at[idx, 3].set(False)
# add a new node # add a new node
new_nodes, new_connections = \ new_nodes, new_cons = \
add_node(new_node_key, new_nodes, new_connections, add_node(new_nodes, new_cons, new_node_key,
bias=default_bias, response=default_response, act=default_act, agg=default_agg) bias=default_bias, response=default_response, act=default_act, agg=default_agg)
new_idx = fetch_first(new_nodes[:, 0] == new_node_key)
# add two new connections # add two new connections
weight = new_connections[0, from_idx, to_idx] w = new_cons[idx, 2]
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx, new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True)
new_nodes, new_connections, weight=1., enabled=True) new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True)
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx, return new_nodes, new_cons
new_nodes, new_connections, weight=weight, enabled=True)
return new_nodes, new_connections
# if from_idx == I_INT, no connection exists, so do nothing # if idx == I_INT, no connection exists, so do nothing
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node) nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node)
return nodes, connections return nodes, cons
# TODO: Do we really need to delete a node?
@jit @jit
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
""" """
Randomly delete a node. Input and output nodes are not allowed to be deleted. Randomly delete a node. Input and output nodes are not allowed to be deleted.
:param rand_key: :param rand_key:
:param nodes: :param nodes:
:param connections: :param cons:
:param input_keys: :param input_keys:
:param output_keys: :param output_keys:
:return: :return:
@@ -351,83 +346,86 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
allow_input_keys=False, allow_output_keys=False) allow_input_keys=False, allow_output_keys=False)
def nothing(): def nothing():
return nodes, connections return nodes, cons
def successful_delete_node(): def successful_delete_node():
# delete the node # delete the node
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections) aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx)
# delete connections # delete all connections
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan) aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis],
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan) jnp.nan, aux_cons)
return aux_nodes, aux_connections return aux_nodes, aux_cons
nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node) nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
return nodes, connections return nodes, cons
@jit @jit
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
""" """
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks, Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
cycles are not allowed. cycles are not allowed.
:param rand_key: :param rand_key:
:param nodes: :param nodes:
:param connections: :param cons:
:param input_keys: :param input_keys:
:param output_keys: :param output_keys:
:return: :return:
""" """
# randomly choose two nodes # randomly choose two nodes
k1, k2 = jax.random.split(rand_key, num=2) k1, k2 = jax.random.split(rand_key, num=2)
from_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys, i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
allow_input_keys=True, allow_output_keys=True) allow_input_keys=True, allow_output_keys=True)
to_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys, o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=True) allow_input_keys=False, allow_output_keys=True)
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
def successful(): def successful():
new_nodes, new_connections = add_connection_by_idx(from_idx, to_idx, nodes, connections) new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)
return new_nodes, new_connections return new_nodes, new_cons
def already_exist(): def already_exist():
new_connections = connections.at[1, from_idx, to_idx].set(True) new_cons = cons.at[con_idx, 3].set(True)
return nodes, new_connections return nodes, new_cons
def cycle(): def cycle():
return nodes, connections return nodes, cons
is_already_exist = ~jnp.isnan(connections[0, from_idx, to_idx]) is_already_exist = con_idx != I_INT
is_cycle = check_cycles(nodes, connections, from_idx, to_idx) unflattened = unflatten_connections(nodes, cons)
is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
nodes, connections = jax.lax.switch(choice, [already_exist, cycle, successful]) nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
return nodes, connections return nodes, cons
@jit @jit
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
""" """
Randomly delete a connection. Randomly delete a connection.
:param rand_key: :param rand_key:
:param nodes: :param nodes:
:param connections: :param cons:
:return: :return:
""" """
# randomly choose a connection # randomly choose a connection
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections) i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
def nothing(): def nothing():
return nodes, connections return nodes, cons
def successfully_delete_connection(): def successfully_delete_connection():
return delete_connection_by_idx(from_idx, to_idx, nodes, connections) return delete_connection_by_idx(nodes, cons, idx)
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection) nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
return nodes, connections return nodes, cons
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys')) @partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
@@ -460,31 +458,20 @@ def choice_node_key(rand_key: Array, nodes: Array,
@jit @jit
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]: def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
""" """
Randomly choose a connection key from the given connections. Randomly choose a connection key from the given connections.
:param rand_key: :param rand_key:
:param nodes: :param nodes:
:param connection: :param cons:
:return: from_key, to_key, from_idx, to_idx :return: i_key, o_key, idx
""" """
k1, k2 = jax.random.split(rand_key, num=2) idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0]))
i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan)
has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1) return i_key, o_key, idx
def nothing():
return jnp.nan, jnp.nan, I_INT, I_INT
def has_connection():
f_idx = fetch_random(k1, has_connections_row)
col = connection[0, f_idx, :]
t_idx = fetch_random(k2, ~jnp.isnan(col))
f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0]
return f_key, t_key, f_idx, t_idx
from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing)
return from_key, to_key, from_idx, to_idx
@jit @jit

View File

@@ -3,84 +3,38 @@ from typing import Tuple
import jax import jax
from jax import numpy as jnp, Array from jax import numpy as jnp, Array
from jax import jit from jax import jit, vmap
I_INT = jnp.iinfo(jnp.int32).max # infinite int I_INT = jnp.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = jnp.full((1, 5), jnp.nan) EMPTY_NODE = jnp.full((1, 5), jnp.nan)
EMPTY_CON = jnp.full((1, 4), jnp.nan) EMPTY_CON = jnp.full((1, 4), jnp.nan)
@jit @jit
def flatten_connections(keys, connections): def unflatten_connections(nodes, cons):
""" """
flatten the (2, N, N) connections to (N * N, 4) transform the (C, 4) connections to (2, N, N)
:param keys:
:param connections:
:return:
the first two columns are the index of the node
the 3rd column is the weight, and the 4th column is the enabled status
"""
indices_x, indices_y = jnp.meshgrid(keys, keys, indexing='ij')
indices = jnp.stack((indices_x, indices_y), axis=-1).reshape(-1, 2)
# make (2, N, N) to (N, N, 2)
con = jnp.transpose(connections, (1, 2, 0))
# make (N, N, 2) to (N * N, 2)
con = jnp.reshape(con, (-1, 2))
con = jnp.concatenate((indices, con), axis=1)
return con
@partial(jit, static_argnames=['N'])
def unflatten_connections(N, cons):
"""
restore the (N * N, 4) connections to (2, N, N)
:param N:
:param cons: :param cons:
:param nodes:
:return: :return:
""" """
cons = cons[:, 2:] # remove the indices N = nodes.shape[0]
unflatten_cons = jnp.moveaxis(cons.reshape(N, N, 2), -1, 0) node_keys = nodes[:, 0]
return unflatten_cons i_keys, o_keys = cons[:, 0], cons[:, 1]
i_idxs = key_to_indices(i_keys, node_keys)
o_idxs = key_to_indices(o_keys, node_keys)
res = jnp.full((2, N, N), jnp.nan)
# Interestingly, JAX clamps out-of-bounds indices when reading from an array,
# but silently drops out-of-bounds updates when setting values in an array.
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
return res
@jit @partial(vmap, in_axes=(0, None))
def set_operation_analysis(ar1: Array, ar2: Array) -> Tuple[Array, Array, Array]: def key_to_indices(key, keys):
""" return fetch_first(key == keys)
Analyze the intersection and union of two arrays by returning their sorted concatenation indices,
intersection mask, and union mask.
:param ar1: JAX array of shape (N, M)
First input array. Should have the same shape as ar2.
:param ar2: JAX array of shape (N, M)
Second input array. Should have the same shape as ar1.
:return: tuple of 3 arrays
- sorted_indices: Indices that would sort the concatenation of ar1 and ar2.
- intersect_mask: A boolean array indicating the positions of the common elements between ar1 and ar2
in the sorted concatenation.
- union_mask: A boolean array indicating the positions of the unique elements in the union of ar1 and ar2
in the sorted concatenation.
Examples:
a = jnp.array([[1, 2], [3, 4], [5, 6]])
b = jnp.array([[1, 2], [7, 8], [9, 10]])
sorted_indices, intersect_mask, union_mask = set_operation_analysis(a, b)
sorted_indices -> array([0, 1, 2, 3, 4, 5])
intersect_mask -> array([True, False, False, False, False, False])
union_mask -> array([False, True, True, True, True, True])
"""
ar = jnp.concatenate((ar1, ar2), axis=0)
sorted_indices = jnp.lexsort(ar.T[::-1])
aux = ar[sorted_indices]
aux = jnp.concatenate((aux, jnp.full((1, ar1.shape[1]), jnp.nan)), axis=0)
nan_mask = jnp.any(jnp.isnan(aux), axis=1)
fr, sr = aux[:-1], aux[1:] # first row, second row
intersect_mask = jnp.all(fr == sr, axis=1) & ~nan_mask[:-1]
union_mask = jnp.any(fr != sr, axis=1) & ~nan_mask[:-1]
return sorted_indices, intersect_mask, union_mask
@jit @jit

View File

@@ -8,6 +8,7 @@ from .species import SpeciesController
from .genome import expand, expand_single from .genome import expand, expand_single
from .function_factory import FunctionFactory from .function_factory import FunctionFactory
from .genome.genome import count from .genome.genome import count
from .genome.debug.tools import check_array_valid
class Pipeline: class Pipeline:
""" """
@@ -23,6 +24,7 @@ class Pipeline:
self.config = config self.config = config
self.N = config.basic.init_maximum_nodes self.N = config.basic.init_maximum_nodes
self.C = config.basic.init_maximum_connections
self.expand_coe = config.basic.expands_coe self.expand_coe = config.basic.expands_coe
self.pop_size = config.neat.population.pop_size self.pop_size = config.neat.population.pop_size
@@ -57,6 +59,8 @@ class Pipeline:
self.update_next_generation(winner_part, loser_part, elite_mask) self.update_next_generation(winner_part, loser_part, elite_mask)
# pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation, self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation,
self.o2o_distance, self.o2m_distance) self.o2o_distance, self.o2m_distance)
@@ -105,16 +109,25 @@ class Pipeline:
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
lpc) # new pop nodes, new pop connections lpc) # new pop nodes, new pop connections
# for i in range(self.pop_size):
# n, c = np.array(npn[i]), np.array(npc[i])
# check_array_valid(n, c, self.input_idx, self.output_idx)
# mutate # mutate
new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size) new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size)
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# for i in range(self.pop_size):
# n, c = np.array(m_npn[i]), np.array(m_npc[i])
# check_array_valid(n, c, self.input_idx, self.output_idx)
# elites are not mutated # elites are not mutated
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc]) npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn) self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn)
self.pop_connections = np.where(elite_mask[:, None, None, None], npc, m_npc) self.pop_connections = np.where(elite_mask[:, None, None], npc, m_npc)
def expand(self): def expand(self):
""" """
@@ -128,20 +141,38 @@ class Pipeline:
max_node_size = np.max(pop_node_sizes) max_node_size = np.max(pop_node_sizes)
if max_node_size >= self.N: if max_node_size >= self.N:
self.N = int(self.N * self.expand_coe) self.N = int(self.N * self.expand_coe)
print(f"expand to {self.N}!") print(f"node expand to {self.N}!")
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N) self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C)
# don't forget to expand the representative genome of each species # don't forget to expand the representative genome of each species
for s in self.species_controller.species.values(): for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N) s.representative = expand_single(*s.representative, self.N, self.C)
# update functions # update functions
self.compile_functions(debug=True) self.compile_functions(debug=True)
pop_con_keys = self.pop_connections[:, :, 0]
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
max_con_size = np.max(pop_node_sizes)
if max_con_size >= self.C:
self.C = int(self.C * self.expand_coe)
print(f"connections expand to {self.C}!")
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N, self.C)
# don't forget to expand the representative genome of each species
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N, self.C)
# update functions
self.compile_functions(debug=True)
def compile_functions(self, debug=False): def compile_functions(self, debug=False):
self.mutate_func = self.function_factory.create_mutate(self.N) self.mutate_func = self.function_factory.create_mutate(self.N, self.C)
self.crossover_func = self.function_factory.create_crossover(self.N) self.crossover_func = self.function_factory.create_crossover(self.N, self.C)
self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N) self.o2o_distance, self.o2m_distance = self.function_factory.create_distance(self.N, self.C)
def default_analysis(self, fitnesses): def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)

View File

@@ -0,0 +1,49 @@
import jax
import numpy as np
from algorithms.neat.function_factory import FunctionFactory
from algorithms.neat.genome.debug.tools import check_array_valid
from utils import Configer
from algorithms.neat.genome.crossover import crossover
if __name__ == '__main__':
config = Configer.load_config()
function_factory = FunctionFactory(config, debug=True)
initialize_func = function_factory.create_initialize()
pop_nodes, pop_connections, input_idx, output_idx = initialize_func()
mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1])
crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1])
key = jax.random.PRNGKey(0)
new_node_idx = 100
while True:
key, subkey = jax.random.split(key)
mutate_keys = jax.random.split(subkey, len(pop_nodes))
new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes))
new_node_idx += len(pop_nodes)
pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
# for i in range(len(pop_nodes)):
# check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
idx1 = np.random.permutation(len(pop_nodes))
idx2 = np.random.permutation(len(pop_nodes))
n1, c1 = pop_nodes[idx1], pop_connections[idx1]
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
crossover_keys = jax.random.split(subkey, len(pop_nodes))
# for idx, (zn1, zc1, zn2, zc2) in enumerate(zip(n1, c1, n2, c2)):
# n, c = crossover(crossover_keys[idx], zn1, zc1, zn2, zc2)
# try:
# check_array_valid(n, c, input_idx, output_idx)
# except AssertionError as e:
# crossover(crossover_keys[idx], zn1, zc1, zn2, zc2)
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
for i in range(len(pop_nodes)):
check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
print(new_node_idx)

View File

@@ -1,11 +1,3 @@
import numpy as np import numpy as np
# 输入 print(np.random.permutation(10))
a = np.array([1, 2, 3, 4])
b = np.array([5, 6])
# Create a grid that contains all possible combinations
aa, bb = np.meshgrid(a, b)
aa = aa.flatten()
bb = bb.flatten()
print(aa, bb)

View File

@@ -6,8 +6,8 @@ from time_utils import using_cprofile
from problems import Sin, Xor, DIY from problems import Sin, Xor, DIY
# @using_cprofile @using_cprofile
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") # @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
def main(): def main():
config = Configer.load_config() config = Configer.load_config()
problem = Xor() problem = Xor()

View File

@@ -4,6 +4,7 @@
"num_outputs": 1, "num_outputs": 1,
"problem_batch": 4, "problem_batch": 4,
"init_maximum_nodes": 10, "init_maximum_nodes": 10,
"init_maximum_connections": 10,
"expands_coe": 2, "expands_coe": 2,
"pre_compile_times": 3, "pre_compile_times": 3,
"forward_way": "pop_batch" "forward_way": "pop_batch"
@@ -13,7 +14,7 @@
"fitness_criterion": "max", "fitness_criterion": "max",
"fitness_threshold": -0.001, "fitness_threshold": -0.001,
"generation_limit": 1000, "generation_limit": 1000,
"pop_size": 1000, "pop_size": 5000,
"reset_on_extinction": "False" "reset_on_extinction": "False"
}, },
"gene": { "gene": {
@@ -57,9 +58,9 @@
"compatibility_weight_coefficient": 0.5, "compatibility_weight_coefficient": 0.5,
"single_structural_mutation": "False", "single_structural_mutation": "False",
"conn_add_prob": 0.5, "conn_add_prob": 0.5,
"conn_delete_prob": 0, "conn_delete_prob": 0.5,
"node_add_prob": 0.2, "node_add_prob": 0.2,
"node_delete_prob": 0 "node_delete_prob": 0.2
}, },
"species": { "species": {
"compatibility_threshold": 2.5, "compatibility_threshold": 2.5,