finish the ask part of the algorithm;

use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
This commit is contained in:
wls2002
2023-06-25 00:26:52 +08:00
parent 86820db5a6
commit 0cb2f9473d
24 changed files with 485 additions and 1623 deletions

View File

@@ -104,11 +104,23 @@ def cube_act(z):
return z ** 3
@jit
def act(idx, z):
    """Apply the activation function selected by ``idx`` to ``z``.

    :param idx: index into ACT_TOTAL_LIST; stored as a float in the genome,
        so it is converted to int32 here (lax.switch needs an integer index).
    :param z: input value(s) for the activation function.
    :return: activated value, with NaN inputs kept as NaN.
    """
    idx = jnp.asarray(idx, dtype=jnp.int32)
    # change idx from float to int
    res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
    # NOTE(review): this where() maps NaN results back to NaN, which is an
    # identity on the value itself — presumably kept to make the NaN
    # propagation contract explicit; confirm before removing.
    return jnp.where(jnp.isnan(res), jnp.nan, res)
# Map from human-readable activation name to its implementation above.
# Keys must stay in sync with the activation options used in the config.
act_name2func = {
    'sigmoid': sigmoid_act,
    'tanh': tanh_act,
    'sin': sin_act,
    'gauss': gauss_act,
    'relu': relu_act,
    'elu': elu_act,
    'lelu': lelu_act,
    'selu': selu_act,
    'softplus': softplus_act,
    'identity': identity_act,
    'clamped': clamped_act,
    'inv': inv_act,
    'log': log_act,
    'exp': exp_act,
    'abs': abs_act,
    'hat': hat_act,
    'square': square_act,
    'cube': cube_act,
}

View File

@@ -1,9 +1,3 @@
"""
aggregations, two special case need to consider:
1. extra 0s
2. full of 0s
"""
import jax
import jax.numpy as jnp
import numpy as np
@@ -44,19 +38,13 @@ def maxabs_agg(z):
@jit
def median_agg(z):
non_zero_mask = ~jnp.isnan(z)
n = jnp.sum(non_zero_mask, axis=0)
non_nan_mask = ~jnp.isnan(z)
n = jnp.sum(non_nan_mask, axis=0)
z = jnp.where(jnp.isnan(z), jnp.inf, z)
sorted_valid_values = jnp.sort(z)
z = jnp.sort(z) # sort
def _even_case():
return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2
def _odd_case():
return sorted_valid_values[n // 2]
median = jax.lax.cond(n % 2 == 0, _even_case, _odd_case)
idx1, idx2 = (n - 1) // 2, n // 2
median = (z[idx1] + z[idx2]) / 2
return median
@@ -70,25 +58,12 @@ def mean_agg(z):
return mean_without_zeros
@jit
def agg(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
def full_nan():
return 0.
def not_full_nan():
return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
if __name__ == '__main__':
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array))
array2 = jnp.asarray([0, 0, 0, 0], dtype=jnp.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array2))
# Map from human-readable aggregation name to its implementation above.
# Keys must stay in sync with the aggregation options used in the config.
agg_name2func = {
    'sum': sum_agg,
    'product': product_agg,
    'max': max_agg,
    'min': min_agg,
    'maxabs': maxabs_agg,
    'median': median_agg,
    'mean': mean_agg,
}

View File

@@ -1,14 +1,17 @@
from functools import partial
"""
Crossover two genomes to generate a new genome.
The calculation method is the same as the crossover operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
"""
from typing import Tuple
import jax
from jax import jit, vmap, Array
from jax import jit, Array
from jax import numpy as jnp
@jit
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
-> Tuple[Array, Array]:
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
@@ -23,7 +26,11 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2:
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = align_array(keys1, keys2, nodes2, 'node')
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
@@ -34,7 +41,6 @@ def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2:
return new_nodes, new_cons
# @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
@@ -62,7 +68,6 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
return refactor_ar2
# @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
"""
crossover two genes

View File

@@ -1,81 +0,0 @@
"""
Crossover two genomes to generate a new genome.
The calculation method is the same as the crossover operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
"""
from typing import Tuple
import jax
from jax import jit, Array
from jax import numpy as jnp
@jit
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
    """
    Build a child genome from two parents.

    genome1 must be the fitter parent (the "winner"): genes that exist in
    only one parent are always taken from genome1, while homologous genes
    are mixed element-wise at random.

    :param randkey: jax PRNG key
    :param nodes1: node genes of the winner
    :param cons1: connection genes of the winner
    :param nodes2: node genes of the loser
    :param cons2: connection genes of the loser
    :return: (new_nodes, new_cons) of the child genome
    """
    node_key, con_key = jax.random.split(randkey)

    def _merge(key, winner, loser):
        # Where either side lacks the gene (NaN), keep the winner's row;
        # otherwise take the random element-wise crossover of both rows.
        missing = jnp.isnan(winner) | jnp.isnan(loser)
        return jnp.where(missing, winner, crossover_gene(key, winner, loser))

    # align homologous node genes of genome2 with genome1 by node key
    aligned_nodes2 = align_array(nodes1[:, 0], nodes2[:, 0], nodes2, 'node')
    new_nodes = _merge(node_key, nodes1, aligned_nodes2)

    # connections are keyed by their first two columns (in-node, out-node)
    aligned_cons2 = align_array(cons1[:, :2], cons2[:, :2], cons2, 'connection')
    new_cons = _merge(con_key, cons1, aligned_cons2)

    return new_nodes, new_cons
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
    """
    After I review this code, I found that it is the most difficult part of the code. Please never change it!
    Make ar2 align with ar1: rows of ar2 whose key also occurs in seq1 are
    moved to the position of that key in seq1; all other rows become NaN.

    :param seq1: keys of ar1 — shape (N,) for nodes, (N, 2) for connections
    :param seq2: keys of ar2, same shape convention as seq1
    :param ar2: the gene array to be realigned
    :param gene_type: 'node' or 'connection' (connections match on both key columns)
    :return: realigned copy of ar2
    """
    # pairwise key comparison: mask[i, j] == True iff seq1[i] matches seq2[j]
    seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
    mask = (seq1 == seq2) & (~jnp.isnan(seq1))
    if gene_type == 'connection':
        # connection keys are (in, out) pairs — require both columns to match
        mask = jnp.all(mask, axis=2)

    # intersect_mask[i]: does key i of seq1 appear anywhere in seq2?
    intersect_mask = mask.any(axis=1)

    # each mask row has at most one True, so this dot product yields the
    # matching index in ar2 (and 0 for non-matches, which is masked out below)
    idx = jnp.arange(0, len(seq1))
    idx_fixed = jnp.dot(mask, idx)

    refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
    return refactor_ar2
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
    """
    Element-wise random crossover of two gene arrays.

    Each element of the child is taken from g1 or g2 with equal probability.
    Only genes with the same key are crossed over, so the key columns never
    need special handling.

    :param rand_key: jax PRNG key
    :param g1: gene array of parent 1
    :param g2: gene array of parent 2
    :return: child gene array, same shape as g1/g2
    """
    take_first = jax.random.uniform(rand_key, shape=g1.shape) > 0.5
    return jnp.where(take_first, g1, g2)

View File

@@ -1,6 +1,9 @@
"""
Calculate the distance between two genomes.
The calculation method is the same as the distance calculation in NEAT-python.
See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
"""
from typing import Dict
from jax import jit, vmap, Array
from jax import numpy as jnp
@@ -9,26 +12,34 @@ from .utils import EMPTY_NODE, EMPTY_CON
@jit
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, disjoint_coe: float = 1.,
compatibility_coe: float = 0.5) -> Array:
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
"""
Calculate the distance between two genomes.
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
args:
nodes1: Array(N, 5)
cons1: Array(C, 4)
nodes2: Array(N, 5)
cons2: Array(C, 4)
returns:
distance: Array(, )
"""
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
nd = node_distance(nodes1, nodes2, jit_config) # node distance
cd = connection_distance(cons1, cons2, jit_config) # connection distance
return nd + cd
@jit
def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
"""
Calculate the distance between nodes of two genomes.
"""
# statistics nodes count of two genomes
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
@@ -36,19 +47,29 @@ def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
nd = batch_homologous_node_distance(fr, sr)
nd = jnp.where(jnp.isnan(nd), 0, nd)
homologous_distance = jnp.sum(nd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return jnp.where(max_cnt == 0, 0, val / max_cnt)
# calculate the distance of homologous nodes
hnd = vmap(homologous_node_distance)(fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
'compatibility_weight']
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
@jit
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
def connection_distance(cons1: Array, cons2: Array, jit_config: Dict):
"""
Calculate the distance between connections of two genomes.
Similar process as node_distance.
"""
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
@@ -64,37 +85,34 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
cd = batch_homologous_connection_distance(fr, sr)
cd = jnp.where(jnp.isnan(cd), 0, cd)
homologous_distance = jnp.sum(cd * intersect_mask)
hcd = vmap(homologous_connection_distance)(fr, sr)
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
'compatibility_weight']
return jnp.where(max_cnt == 0, 0, val / max_cnt)
@vmap
def batch_homologous_node_distance(b_n1, b_n2):
return homologous_node_distance(b_n1, b_n2)
@vmap
def batch_homologous_connection_distance(b_c1, b_c2):
return homologous_connection_distance(b_c1, b_c2)
@jit
def homologous_node_distance(n1, n2):
def homologous_node_distance(n1: Array, n2: Array):
"""
Calculate the distance between two homologous nodes.
"""
d = 0
d += jnp.abs(n1[1] - n2[1]) # bias
d += jnp.abs(n1[2] - n2[2]) # response
d += n1[3] != n2[3] # activation
d += n1[4] != n2[4]
d += n1[4] != n2[4] # aggregation
return d
@jit
def homologous_connection_distance(c1, c2):
def homologous_connection_distance(c1: Array, c2: Array):
"""
Calculate the distance between two homologous connections.
"""
d = 0
d += jnp.abs(c1[2] - c2[2]) # weight
d += c1[3] != c2[3] # enable

View File

@@ -1,119 +0,0 @@
"""
Calculate the distance between two genomes.
The calculation method is the same as the distance calculation in NEAT-python.
See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
"""
from typing import Dict
from jax import jit, vmap, Array
from jax import numpy as jnp
from .utils import EMPTY_NODE, EMPTY_CON
@jit
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
    """
    Total compatibility distance between two genomes, defined as the sum of
    the node-gene distance and the connection-gene distance (same definition
    as in NEAT-python).

    args:
        nodes1: Array(N, 5)
        cons1: Array(C, 4)
        nodes2: Array(N, 5)
        cons2: Array(C, 4)
        jit_config: dict of coefficients used by the sub-distances
    returns:
        distance: scalar Array
    """
    return (node_distance(nodes1, nodes2, jit_config)
            + connection_distance(cons1, cons2, jit_config))
@jit
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
    """
    Calculate the compatibility distance between the node genes of two genomes.

    Node rows are ``[key, bias, response, act, agg]``; padding rows carry a
    NaN key.

    :param nodes1: Array(N, 5)
    :param nodes2: Array(N, 5)
    :param jit_config: dict providing 'compatibility_disjoint' and
        'compatibility_weight' coefficients
    :return: scalar Array, the node part of the genome distance
    """
    # statistics nodes count of two genomes (rows with NaN key are padding)
    node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
    node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
    max_cnt = jnp.maximum(node_cnt1, node_cnt2)

    # align homologous nodes
    # this process is similar to np.intersect1d: after sorting the
    # concatenated keys, two nodes sharing a key land on adjacent rows
    nodes = jnp.concatenate((nodes1, nodes2), axis=0)
    keys = nodes[:, 0]
    sorted_indices = jnp.argsort(keys, axis=0)
    nodes = nodes[sorted_indices]
    nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0)  # add a nan row to the end
    fr, sr = nodes[:-1], nodes[1:]  # first row, second row

    # flag location of homologous nodes (same key on adjacent rows)
    intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])

    # calculate the count of non_homologous of two genomes
    non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)

    # calculate the distance of homologous nodes
    hnd = vmap(homologous_node_distance)(fr, sr)
    hnd = jnp.where(jnp.isnan(hnd), 0, hnd)  # pairs involving padding give NaN; drop them
    homologous_distance = jnp.sum(hnd * intersect_mask)

    val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
        'compatibility_weight']
    return jnp.where(max_cnt == 0, 0, val / max_cnt)  # avoid zero division
@jit
def connection_distance(cons1: Array, cons2: Array, jit_config: Dict):
    """
    Calculate the compatibility distance between the connection genes of two
    genomes. Similar process as node_distance, except a connection is keyed
    by the pair in its first two columns rather than a single key.

    :param cons1: Array(C, 4), rows are [in_key, out_key, weight, enabled]
    :param cons2: Array(C, 4)
    :param jit_config: dict providing 'compatibility_disjoint' and
        'compatibility_weight' coefficients
    :return: scalar Array, the connection part of the genome distance
    """
    # count real (non-padding) connections of each genome
    con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
    con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
    max_cnt = jnp.maximum(con_cnt1, con_cnt2)

    cons = jnp.concatenate((cons1, cons2), axis=0)
    keys = cons[:, :2]
    # lexicographic sort on (in_key, out_key) so homologous connections
    # become adjacent rows
    sorted_indices = jnp.lexsort(keys.T[::-1])
    cons = cons[sorted_indices]
    cons = jnp.concatenate([cons, EMPTY_CON], axis=0)  # add a nan row to the end
    fr, sr = cons[:-1], cons[1:]  # first row, second row

    # both genome has such connection
    intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
    non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)

    hcd = vmap(homologous_connection_distance)(fr, sr)
    hcd = jnp.where(jnp.isnan(hcd), 0, hcd)  # pairs involving padding give NaN; drop them
    homologous_distance = jnp.sum(hcd * intersect_mask)

    val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
        'compatibility_weight']
    return jnp.where(max_cnt == 0, 0, val / max_cnt)  # avoid zero division
@jit
def homologous_node_distance(n1: Array, n2: Array):
    """
    Distance between two homologous node genes.

    A node row is [key, bias, response, act, agg]: bias and response
    contribute their absolute difference, the discrete act/agg fields
    contribute 1 each when they differ.

    :param n1: Array(5,)
    :param n2: Array(5,)
    :return: scalar Array
    """
    continuous_part = jnp.abs(n1[1] - n2[1]) + jnp.abs(n1[2] - n2[2])  # bias + response
    discrete_part = (n1[3] != n2[3]) + (n1[4] != n2[4])  # activation + aggregation
    return continuous_part + discrete_part
@jit
def homologous_connection_distance(c1: Array, c2: Array):
    """
    Distance between two homologous connection genes.

    A connection row is [in_key, out_key, weight, enabled]: the weight
    contributes its absolute difference, and a differing enabled flag
    contributes 1.

    :param c1: Array(4,)
    :param c2: Array(4,)
    :return: scalar Array
    """
    weight_part = jnp.abs(c1[2] - c2[2])  # weight
    enabled_part = c1[3] != c2[3]  # enable
    return weight_part + enabled_part

View File

@@ -2,47 +2,82 @@ import jax
from jax import Array, numpy as jnp
from jax import jit, vmap
from .aggregations import agg
from .activations import act
from .utils import I_INT
# TODO: enabled information doesn't influence forward. That is wrong!
@jit
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
input_idx: Array, output_idx: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are single genome
def create_forward(config):
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, config['activation_funcs'], z)
return res
:argument inputs: (input_num, )
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
def agg(idx, z):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
:return (output_num, )
"""
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
def all_nan():
return 0.
def scan_body(carry, i):
def hit():
ins = carry * connections[0, :, i]
z = agg(nodes[i, 4], ins)
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
def not_all_nan():
return jax.lax.switch(idx, config['aggregation_funcs'], z)
new_vals = carry.at[i].set(z)
return new_vals
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
def miss():
return carry
def forward(inputs: Array, cal_seqs: Array, nodes: Array, cons: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are a single genome
return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
:argument inputs: (input_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
:return (output_num, )
"""
return vals[output_idx]
input_idx = config['input_idx']
output_idx = config['output_idx']
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
weights = jnp.where(jnp.isnan(cons[1, :, :]), jnp.nan, cons[0, :, :]) # enabled
def cond_fun(carry):
values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def hit():
ins = values * weights[:, i]
z = agg(nodes[i, 4], ins) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z) # z = act(z)
new_values = values.at[i].set(z)
return new_values
def miss():
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
return values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[output_idx]
return forward

View File

@@ -44,10 +44,13 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
pop_nodes[:, input_idx, 0] = input_idx
pop_nodes[:, output_idx, 0] = output_idx
pop_nodes[:, output_idx, 1] = config['bias_init_mean']
pop_nodes[:, output_idx, 2] = config['response_init_mean']
pop_nodes[:, output_idx, 3] = config['activation_default']
pop_nodes[:, output_idx, 4] = config['aggregation_default']
# pop_nodes[:, output_idx, 1] = config['bias_init_mean']
pop_nodes[:, output_idx, 1] = np.random.normal(loc=config['bias_init_mean'], scale=config['bias_init_std'],
size=(config['pop_size'], 1))
pop_nodes[:, output_idx, 2] = np.random.normal(loc=config['response_init_mean'], scale=config['response_init_std'],
size=(config['pop_size'], 1))
pop_nodes[:, output_idx, 3] = np.random.choice(config['activation_options'], size=(config['pop_size'], 1))
pop_nodes[:, output_idx, 4] = np.random.choice(config['aggregation_options'], size=(config['pop_size'], 1))
grid_a, grid_b = np.meshgrid(input_idx, output_idx)
grid_a, grid_b = grid_a.flatten(), grid_b.flatten()
@@ -55,7 +58,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
p = config['num_inputs'] * config['num_outputs']
pop_cons[:, :p, 0] = grid_a
pop_cons[:, :p, 1] = grid_b
pop_cons[:, :p, 2] = config['weight_init_mean']
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
size=(config['pop_size'], p))
pop_cons[:, :p, 3] = 1
return pop_nodes, pop_cons

View File

@@ -8,8 +8,7 @@ from jax import jit, vmap, Array
from jax import numpy as jnp
# from .configs import fetch_first, I_INT
from neat.genome.utils import fetch_first, I_INT
from .utils import unflatten_connections
from neat.genome.utils import fetch_first, I_INT, unflatten_connections
@jit
@@ -44,49 +43,32 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
topological_sort(nodes, connections) -> [0, 1, 2, 3]
"""
connections_enable = connections[1, :, :] == 1
connections_enable = connections[1, :, :] == 1 # forward function. thus use enable
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
res = jnp.full(in_degree.shape, I_INT)
idx = 0
def scan_body(carry, _):
def cond_fun(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
return i != I_INT
def body_func(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
def hit():
# add to res and flag it is already in it
new_res = res_.at[idx_].set(i)
new_idx = idx_ + 1
new_in_degree = in_degree_.at[i].set(-1)
# add to res and flag it is already in it
res_ = res_.at[idx_].set(i)
in_degree_ = in_degree_.at[i].set(-1)
# decrease in_degree of all its children
children = connections_enable[i, :]
new_in_degree = jnp.where(children, new_in_degree - 1, new_in_degree)
return new_res, new_idx, new_in_degree
def miss():
return res_, idx_, in_degree_
return jax.lax.cond(i == I_INT, miss, hit), None
scan_res, _ = jax.lax.scan(scan_body, (res, idx, in_degree), None, length=in_degree.shape[0])
res, _, _ = scan_res
# decrease in_degree of all its children
children = connections_enable[i, :]
in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
return res_, idx_ + 1, in_degree_
res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
return res
@jit
@vmap
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
"""
batch version of topological_sort
:param pop_nodes:
:param pop_connections:
:return:
"""
return topological_sort(pop_nodes, pop_connections)
@jit
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
"""
@@ -131,22 +113,26 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
check_cycles(nodes, connections, 1, 0) -> False
"""
connections = unflatten_connections(nodes, connections)
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
nodes_visited = jnp.full(nodes.shape[0], False)
nodes_visited = nodes_visited.at[to_idx].set(True)
def scan_body(visited, _):
new_visited = jnp.dot(visited, connections_enable)
new_visited = jnp.logical_or(visited, new_visited)
return new_visited, None
visited = jnp.full(nodes.shape[0], False)
new_visited = visited.at[to_idx].set(True)
nodes_visited, _ = jax.lax.scan(scan_body, nodes_visited, None, length=nodes_visited.shape[0])
def cond_func(carry):
visited_, new_visited_ = carry
end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited
end_cond2 = new_visited_[from_idx] # the starting node has been visited
return jnp.logical_not(end_cond1 | end_cond2)
return nodes_visited[from_idx]
def body_func(carry):
_, visited_ = carry
new_visited_ = jnp.dot(visited_, connections_enable)
new_visited_ = jnp.logical_or(visited_, new_visited_)
return visited_, new_visited_
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx]
if __name__ == '__main__':

View File

@@ -1,155 +1,64 @@
from typing import Tuple
"""
Mutate a genome.
The calculation method is the same as the mutation operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
"""
from typing import Tuple, Dict
from functools import partial
import jax
import numpy as np
from jax import numpy as jnp
from jax import jit, vmap, Array
from jax import jit, Array
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .utils import fetch_random, fetch_first, I_INT
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
from .graph import check_cycles
# TODO: Temporally delete single_structural_mutation, for i need to run it as soon as possible.
@jit
def mutate(rand_key: Array,
nodes: Array,
connections: Array,
new_node_key: int,
input_idx: Array,
output_idx: Array,
bias_mean: float = 0,
bias_std: float = 1,
bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_mean: float = 1.,
response_std: float = 0.,
response_mutate_strength: float = 0.,
response_mutate_rate: float = 0.,
response_replace_rate: float = 0.,
weight_mean: float = 0.,
weight_std: float = 1.,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_default: int = 0,
act_list: Array = None,
act_replace_rate: float = 0.1,
agg_default: int = 0,
agg_list: Array = None,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1,
add_node_rate: float = 0.2,
delete_node_rate: float = 0.2,
add_connection_rate: float = 0.4,
delete_connection_rate: float = 0.4,
):
def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict):
"""
:param output_idx:
:param input_idx:
:param agg_default:
:param act_default:
:param rand_key:
:param nodes: (N, 5)
:param connections: (2, N, N)
:param new_node_key:
:param bias_mean:
:param bias_std:
:param bias_mutate_strength:
:param bias_mutate_rate:
:param bias_replace_rate:
:param response_mean:
:param response_std:
:param response_mutate_strength:
:param response_mutate_rate:
:param response_replace_rate:
:param weight_mean:
:param weight_std:
:param weight_mutate_strength:
:param weight_mutate_rate:
:param weight_replace_rate:
:param act_list:
:param act_replace_rate:
:param agg_list:
:param agg_replace_rate:
:param enabled_reverse_rate:
:param add_node_rate:
:param delete_node_rate:
:param add_connection_rate:
:param delete_connection_rate:
:param jit_config:
:return:
"""
def m_add_node(rk, n, c):
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_idx, output_idx)
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_idx, output_idx)
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# structural mutations
# mutate add node
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
r = rand(r1)
aux_nodes, aux_connections = mutate_add_node(r1, nodes, connections, new_node_key, jit_config)
nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections)
# mutate add connection
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
r = rand(r2)
aux_nodes, aux_connections = mutate_add_connection(r3, nodes, connections, jit_config)
nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections)
# mutate delete node
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
r = rand(r3)
aux_nodes, aux_connections = mutate_delete_node(r2, nodes, connections, jit_config)
nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections)
# mutate delete connection
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
r = rand(r4)
aux_nodes, aux_connections = mutate_delete_connection(r4, nodes, connections)
nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections)
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength,
weight_mutate_rate, weight_replace_rate, act_list, act_replace_rate, agg_list,
agg_replace_rate, enabled_reverse_rate)
# value mutations
nodes, connections = mutate_values(rand_key, nodes, connections, jit_config)
return nodes, connections
@jit
def mutate_values(rand_key: Array,
nodes: Array,
cons: Array,
bias_mean: float = 0,
bias_std: float = 1,
bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_mean: float = 1.,
response_std: float = 0.,
response_mutate_strength: float = 0.,
response_mutate_rate: float = 0.,
response_replace_rate: float = 0.,
weight_mean: float = 0.,
weight_std: float = 1.,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_list: Array = None,
act_replace_rate: float = 0.1,
agg_list: Array = None,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]:
def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
"""
Mutate values of nodes and connections.
@@ -157,56 +66,48 @@ def mutate_values(rand_key: Array,
rand_key: A random key for generating random values.
nodes: A 2D array representing nodes.
cons: A 3D array representing connections.
bias_mean: Mean of the bias values.
bias_std: Standard deviation of the bias values.
bias_mutate_strength: Strength of the bias mutation.
bias_mutate_rate: Rate of the bias mutation.
bias_replace_rate: Rate of the bias replacement.
response_mean: Mean of the response values.
response_std: Standard deviation of the response values.
response_mutate_strength: Strength of the response mutation.
response_mutate_rate: Rate of the response mutation.
response_replace_rate: Rate of the response replacement.
weight_mean: Mean of the weight values.
weight_std: Standard deviation of the weight values.
weight_mutate_strength: Strength of the weight mutation.
weight_mutate_rate: Rate of the weight mutation.
weight_replace_rate: Rate of the weight replacement.
act_list: List of the activation function values.
act_replace_rate: Rate of the activation function replacement.
agg_list: List of the aggregation function values.
agg_replace_rate: Rate of the aggregation function replacement.
enabled_reverse_rate: Rate of reversing enabled state of connections.
jit_config: A dict containing configuration for jit-able functions.
Returns:
A tuple containing mutated nodes and connections.
"""
k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6)
bias_new = mutate_float_values(k1, nodes[:, 1], bias_mean, bias_std,
bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate)
weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std,
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
# mutate enabled
# bias
bias_new = mutate_float_values(k1, nodes[:, 1], jit_config['bias_init_mean'], jit_config['bias_init_std'],
jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'],
jit_config['bias_replace_rate'])
# response
response_new = mutate_float_values(k2, nodes[:, 2], jit_config['response_init_mean'],
jit_config['response_init_std'], jit_config['response_mutate_power'],
jit_config['response_mutate_rate'], jit_config['response_replace_rate'])
# weight
weight_new = mutate_float_values(k3, cons[:, 2], jit_config['weight_init_mean'], jit_config['weight_init_std'],
jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'],
jit_config['weight_replace_rate'])
# activation
act_new = mutate_int_values(k4, nodes[:, 3], jit_config['activation_options'],
jit_config['activation_replace_rate'])
# aggregation
agg_new = mutate_int_values(k5, nodes[:, 4], jit_config['aggregation_options'],
jit_config['aggregation_replace_rate'])
# enabled
r = jax.random.uniform(rand_key, cons[:, 3].shape)
enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3])
enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan)
enabled_new = jnp.where(r < jit_config['enable_mutate_rate'], 1 - cons[:, 3], cons[:, 3])
# merge
nodes = jnp.column_stack([nodes[:, 0], bias_new, response_new, act_new, agg_new])
cons = jnp.column_stack([cons[:, 0], cons[:, 1], weight_new, enabled_new])
nodes = nodes.at[:, 1].set(bias_new)
nodes = nodes.at[:, 2].set(response_new)
nodes = nodes.at[:, 3].set(act_new)
nodes = nodes.at[:, 4].set(agg_new)
cons = cons.at[:, 2].set(weight_new)
cons = cons.at[:, 3].set(enabled_new)
return nodes, cons
@jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
"""
@@ -227,19 +128,26 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
k1, k2, k3, rand_key = jax.random.split(rand_key, num=4)
noise = jax.random.normal(k1, old_vals.shape) * mutate_strength
replace = jax.random.normal(k2, old_vals.shape) * std + mean
r = jax.random.uniform(k3, old_vals.shape)
# default
new_vals = old_vals
# r in [0, mutate_rate), mutate
new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals)
# r in [mutate_rate, mutate_rate + replace_rate), replace
new_vals = jnp.where(
jnp.logical_and(mutate_rate < r, r < mutate_rate + replace_rate),
replace,
(mutate_rate < r) & (r < mutate_rate + replace_rate),
replace + new_vals * 0.0, # in case of nan replace to values
new_vals
)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
return new_vals
@jit
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
"""
Mutate integer values (act, agg) of a given array.
@@ -256,26 +164,20 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
k1, k2, rand_key = jax.random.split(rand_key, num=3)
replace_val = jax.random.choice(k1, val_list, old_vals.shape)
r = jax.random.uniform(k2, old_vals.shape)
new_vals = old_vals
new_vals = jnp.where(r < replace_rate, replace_val, new_vals)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
new_vals = jnp.where(r < replace_rate, replace_val + old_vals * 0.0, old_vals) # in case of nan replace to values
return new_vals
@jit
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
jit_config: Dict) -> Tuple[Array, Array]:
"""
Randomly add a new node from splitting a connection.
:param rand_key:
:param new_node_key:
:param nodes:
:param cons:
:param default_bias:
:param default_response:
:param default_act:
:param default_agg:
:param jit_config:
:return:
"""
# randomly choose a connection
@@ -287,12 +189,13 @@ def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: in
def successful_add_node():
# disable the connection
new_nodes, new_cons = nodes, cons
# set enable to false
new_cons = new_cons.at[idx, 3].set(False)
# add a new node
new_nodes, new_cons = \
add_node(new_nodes, new_cons, new_node_key,
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
new_nodes, new_cons = add_node(new_nodes, new_cons, new_node_key, bias=0, response=1,
act=jit_config['activation_default'], agg=jit_config['aggregation_default'])
# add two new connections
w = new_cons[idx, 2]
@@ -306,59 +209,53 @@ def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: in
return nodes, cons
# TODO: Need we really need to delete a node?
@jit
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
# TODO: Do we really need to delete a node?
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
"""
Randomly delete a node. Input and output nodes are not allowed to be deleted.
:param rand_key:
:param nodes:
:param cons:
:param input_keys:
:param output_keys:
:param jit_config:
:return:
"""
# randomly choose a node
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'],
allow_input_keys=False, allow_output_keys=False)
def nothing():
return nodes, cons
def successful_delete_node():
# delete the node
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx)
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx)
# delete all connections
aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis],
aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None],
jnp.nan, aux_cons)
return aux_nodes, aux_cons
nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
return nodes, cons
@jit
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
"""
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
cycles are not allowed.
:param rand_key:
:param nodes:
:param cons:
:param input_keys:
:param output_keys:
:param jit_config:
:return:
"""
# randomly choose two nodes
k1, k2 = jax.random.split(rand_key, num=2)
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
i_key, from_idx = choice_node_key(k1, nodes, jit_config['input_idx'], jit_config['output_idx'],
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
o_key, to_idx = choice_node_key(k2, nodes, jit_config['input_idx'], jit_config['output_idx'],
allow_input_keys=False, allow_output_keys=True)
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
@@ -375,15 +272,14 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
return nodes, cons
is_already_exist = con_idx != I_INT
unflattened = unflatten_connections(nodes, cons)
is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx)
is_cycle = check_cycles(nodes, cons, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
return nodes, cons
@jit
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
"""
Randomly delete a connection.
@@ -406,7 +302,6 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
return nodes, cons
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
@@ -435,7 +330,6 @@ def choice_node_key(rand_key: Array, nodes: Array,
return key, idx
@jit
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
"""
Randomly choose a connection key from the given connections.
@@ -452,6 +346,5 @@ def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[A
return i_key, o_key, idx
@jit
def rand(rand_key):
return jax.random.uniform(rand_key, ())

View File

@@ -1,355 +0,0 @@
"""
Mutate a genome.
The calculation method is the same as the mutation operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
"""
from typing import Tuple, Dict
from functools import partial
import jax
from jax import numpy as jnp
from jax import jit, Array
from .utils import fetch_random, fetch_first, I_INT
from .genome_ import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
from .graph import check_cycles
@jit
def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict):
    """
    Apply structural and value mutations to one genome.

    :param rand_key: jax PRNG key consumed by this call
    :param nodes: (N, 5) node array: key, bias, response, act, agg
    :param connections: (C, 4) connection array: in_key, out_key, weight, enabled
    :param new_node_key: node key to assign if a node is added
    :param jit_config: dict of jit-able hyper-parameters (mutation probabilities, defaults)
    :return: (nodes, connections) after mutation
    """
    # One independent key per probability draw (d*) and per operation (k*).
    # The previous code reused the same key for the coin flip and the mutation
    # itself, which correlates the decision with the mutation's randomness —
    # a classic JAX PRNG key-reuse bug.
    d1, d2, d3, d4, k1, k2, k3, k4, rand_key = jax.random.split(rand_key, 9)

    # mutate add node
    r = rand(d1)
    aux_nodes, aux_connections = mutate_add_node(k1, nodes, connections, new_node_key, jit_config)
    nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections)

    # mutate add connection
    r = rand(d2)
    aux_nodes, aux_connections = mutate_add_connection(k2, nodes, connections, jit_config)
    nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections)

    # mutate delete node
    r = rand(d3)
    aux_nodes, aux_connections = mutate_delete_node(k3, nodes, connections, jit_config)
    nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections)

    # mutate delete connection
    r = rand(d4)
    aux_nodes, aux_connections = mutate_delete_connection(k4, nodes, connections)
    nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections)

    # value mutations (bias, response, weight, act, agg, enabled)
    nodes, connections = mutate_values(rand_key, nodes, connections, jit_config)
    return nodes, connections
def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Mutate the numeric attributes of a genome: node bias/response/activation/aggregation
    and connection weight/enabled flag.

    :param rand_key: jax PRNG key
    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array
    :param jit_config: dict of jit-able mutation hyper-parameters
    :return: (nodes, cons) with mutated values
    """
    kb, kr, kw, ka, kg, rand_key = jax.random.split(rand_key, num=6)
    cfg = jit_config

    # float-valued attributes: perturb or re-sample each entry
    new_bias = mutate_float_values(kb, nodes[:, 1], cfg['bias_init_mean'], cfg['bias_init_std'],
                                   cfg['bias_mutate_power'], cfg['bias_mutate_rate'],
                                   cfg['bias_replace_rate'])
    new_response = mutate_float_values(kr, nodes[:, 2], cfg['response_init_mean'],
                                       cfg['response_init_std'], cfg['response_mutate_power'],
                                       cfg['response_mutate_rate'], cfg['response_replace_rate'])
    new_weight = mutate_float_values(kw, cons[:, 2], cfg['weight_init_mean'], cfg['weight_init_std'],
                                     cfg['weight_mutate_power'], cfg['weight_mutate_rate'],
                                     cfg['weight_replace_rate'])

    # integer-valued attributes: re-sample from the configured option lists
    new_act = mutate_int_values(ka, nodes[:, 3], cfg['activation_options'], cfg['activation_replace_rate'])
    new_agg = mutate_int_values(kg, nodes[:, 4], cfg['aggregation_options'], cfg['aggregation_replace_rate'])

    # enabled flag: flip 0 <-> 1 with probability enable_mutate_rate (nan stays nan)
    roll = jax.random.uniform(rand_key, cons[:, 3].shape)
    new_enabled = jnp.where(roll < cfg['enable_mutate_rate'], 1 - cons[:, 3], cons[:, 3])

    # reassemble the arrays, keeping the key columns untouched
    nodes = jnp.column_stack([nodes[:, 0], new_bias, new_response, new_act, new_agg])
    cons = jnp.column_stack([cons[:, 0], cons[:, 1], new_weight, new_enabled])
    return nodes, cons
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
                        mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
    """
    Mutate float values of a given array.

    Each entry is, independently: perturbed with gaussian noise with probability
    `mutate_rate`, replaced by a fresh N(mean, std) sample with probability
    `replace_rate`, or left unchanged. nan placeholder entries stay nan.

    Args:
        rand_key: A random key for generating random values.
        old_vals: A 1D array of float values to be mutated (nan marks unused slots).
        mean: Mean of the replacement distribution.
        std: Standard deviation of the replacement distribution.
        mutate_strength: Scale of the gaussian perturbation.
        mutate_rate: Probability of perturbation.
        replace_rate: Probability of replacement.

    Returns:
        A mutated 1D array of float values.
    """
    k1, k2, k3, rand_key = jax.random.split(rand_key, num=4)
    noise = jax.random.normal(k1, old_vals.shape) * mutate_strength
    replace = jax.random.normal(k2, old_vals.shape) * std + mean
    r = jax.random.uniform(k3, old_vals.shape)

    # default
    new_vals = old_vals
    # r in [0, mutate_rate), mutate
    new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals)
    # r in [mutate_rate, mutate_rate + replace_rate), replace.
    # bug fix: was `mutate_rate < r`, which wrongly excluded r == mutate_rate
    # from both the mutate and the replace branch
    new_vals = jnp.where(
        (mutate_rate <= r) & (r < mutate_rate + replace_rate),
        replace + new_vals * 0.0,  # in case of nan replace to values
        new_vals
    )
    # placeholder (nan) entries must remain nan
    new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
    return new_vals
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
    """
    Mutate integer-coded attributes (activation / aggregation indices).

    Args:
        rand_key: A random key for generating random values.
        old_vals: A 1D array of integer-coded values (nan marks unused slots).
        val_list: Array of the allowed values to draw replacements from.
        replace_rate: Probability of replacing each entry.

    Returns:
        A mutated 1D array of integer-coded values.
    """
    k1, k2, rand_key = jax.random.split(rand_key, num=3)
    candidates = jax.random.choice(k1, val_list, old_vals.shape)
    roll = jax.random.uniform(k2, old_vals.shape)
    # `+ old_vals * 0.0` propagates nan placeholders into the replacement,
    # so unused slots stay nan even when selected for replacement
    return jnp.where(roll < replace_rate, candidates + old_vals * 0.0, old_vals)
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
                    jit_config: Dict) -> Tuple[Array, Array]:
    """
    Randomly add a new node from splitting a connection.

    The chosen connection i->o is disabled and bridged by i->new (weight 1) and
    new->o (inheriting the old weight), as in standard NEAT.
    :param rand_key: jax PRNG key
    :param new_node_key: key assigned to the node created by the split
    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array
    :param jit_config: dict providing 'activation_default' / 'aggregation_default'
    :return: (nodes, cons); unchanged when the genome has no connection to split
    """
    # randomly choose a connection
    i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)

    def nothing():  # there is no connection to split
        return nodes, cons

    def successful_add_node():
        # disable the connection
        new_nodes, new_cons = nodes, cons
        # set enable to false
        new_cons = new_cons.at[idx, 3].set(False)
        # add a new node
        new_nodes, new_cons = add_node(new_nodes, new_cons, new_node_key, bias=0, response=1,
                                       act=jit_config['activation_default'], agg=jit_config['aggregation_default'])
        # add two new connections; the split connection's weight moves to the
        # outgoing half so the initial behavior of the network is preserved
        w = new_cons[idx, 2]
        new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True)
        new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True)
        return new_nodes, new_cons

    # if from_idx == I_INT, that means no connection exist, do nothing
    nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node)
    return nodes, cons
# TODO: Do we really need to delete a node?
@jit
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Randomly delete a node. Input and output nodes are not allowed to be deleted.

    Every connection touching the deleted node is cleared (set to nan) as well.
    :param rand_key: jax PRNG key
    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array
    :param jit_config: dict providing 'input_idx' / 'output_idx'
                       (NOTE(review): named *_idx but used here as node keys — confirm)
    :return: (nodes, cons); unchanged when no deletable node exists
    """
    # randomly choose a node, excluding inputs and outputs
    key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'],
                               allow_input_keys=False, allow_output_keys=False)

    def nothing():
        return nodes, cons

    def successful_delete_node():
        # delete the node
        aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx)
        # delete all connections that reference the deleted node's key
        aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None],
                             jnp.nan, aux_cons)
        return aux_nodes, aux_cons

    # idx == I_INT means no valid candidate was found
    nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
    return nodes, cons
@jit
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
    cycles are not allowed.

    Three outcomes: the connection already exists (re-enable it), it would create
    a cycle (do nothing), or it is genuinely new (add it with weight 1).
    :param rand_key: jax PRNG key
    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array
    :param jit_config: dict providing 'input_idx' / 'output_idx'
    :return: (nodes, cons)
    """
    # randomly choose two nodes
    k1, k2 = jax.random.split(rand_key, num=2)
    # source may be any node, including inputs and outputs
    i_key, from_idx = choice_node_key(k1, nodes, jit_config['input_idx'], jit_config['output_idx'],
                                      allow_input_keys=True, allow_output_keys=True)
    # target may not be an input node
    o_key, to_idx = choice_node_key(k2, nodes, jit_config['input_idx'], jit_config['output_idx'],
                                    allow_input_keys=False, allow_output_keys=True)
    # look up whether the i_key -> o_key connection already exists
    con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))

    def successful():
        new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)
        return new_nodes, new_cons

    def already_exist():
        # the connection exists (possibly disabled): just set its enabled flag
        new_cons = cons.at[con_idx, 3].set(True)
        return nodes, new_cons

    def cycle():
        return nodes, cons

    is_already_exist = con_idx != I_INT
    is_cycle = check_cycles(nodes, cons, from_idx, to_idx)
    # branch priority: already-exist > cycle > successful add
    choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
    nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
    return nodes, cons
@jit
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
    """
    Delete one randomly chosen connection; no-op when the genome has none.

    :param rand_key: jax PRNG key
    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array
    :return: (nodes, cons)
    """
    in_key, out_key, con_idx = choice_connection_key(rand_key, nodes, cons)

    def keep_unchanged():
        return nodes, cons

    def drop_connection():
        return delete_connection_by_idx(nodes, cons, con_idx)

    # con_idx == I_INT signals that no connection exists
    nodes, cons = jax.lax.cond(con_idx == I_INT, keep_unchanged, drop_connection)
    return nodes, cons
def choice_node_key(rand_key: Array, nodes: Array,
                    input_keys: Array, output_keys: Array,
                    allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
    """
    Randomly pick a node key from `nodes`, optionally excluding input and output nodes.

    :param rand_key: jax PRNG key
    :param nodes: (N, 5) node array (nan rows are unused slots)
    :param input_keys: keys of the input nodes
    :param output_keys: keys of the output nodes
    :param allow_input_keys: whether input nodes are eligible
    :param allow_output_keys: whether output nodes are eligible
    :return: (key, idx); key is nan and idx is I_INT when no candidate exists
    """
    keys = nodes[:, 0]
    # start from all real (non-placeholder) nodes, then narrow down
    candidate = ~jnp.isnan(keys)
    if not allow_input_keys:
        candidate = candidate & ~jnp.isin(keys, input_keys)
    if not allow_output_keys:
        candidate = candidate & ~jnp.isin(keys, output_keys)
    idx = fetch_random(rand_key, candidate)
    key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
    return key, idx
@jit
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
    """
    Randomly pick one existing connection.

    :param rand_key: jax PRNG key
    :param nodes: (N, 5) node array (unused here; kept for a uniform signature)
    :param cons: (C, 4) connection array (nan rows are unused slots)
    :return: (i_key, o_key, idx); keys are nan and idx is I_INT when no connection exists
    """
    valid = ~jnp.isnan(cons[:, 0])
    idx = fetch_random(rand_key, valid)
    found = idx != I_INT
    i_key = jnp.where(found, cons[idx, 0], jnp.nan)
    o_key = jnp.where(found, cons[idx, 1], jnp.nan)
    return i_key, o_key, idx
@jit
def rand(rand_key):
    """Draw one uniform sample in [0, 1) from the given PRNG key."""
    return jax.random.uniform(rand_key, shape=())

View File

@@ -1,5 +1,4 @@
from functools import partial
from typing import Tuple
import jax
from jax import numpy as jnp, Array
@@ -11,20 +10,18 @@ EMPTY_CON = jnp.full((1, 4), jnp.nan)
@jit
def unflatten_connections(nodes, cons):
def unflatten_connections(nodes: Array, cons: Array):
"""
transform the (C, 4) connections to (2, N, N)
this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled
connections to nan, that means we dont consider such connection when forward;
:param cons:
:param nodes:
:param nodes: (N, 5)
:param cons: (C, 4)
:return:
"""
N = nodes.shape[0]
node_keys = nodes[:, 0]
i_keys, o_keys = cons[:, 0], cons[:, 1]
i_idxs = key_to_indices(i_keys, node_keys)
o_idxs = key_to_indices(o_keys, node_keys)
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
res = jnp.full((2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array
@@ -34,8 +31,6 @@ def unflatten_connections(nodes, cons):
return res
@partial(vmap, in_axes=(0, None))
def key_to_indices(key, keys):
return fetch_first(key == keys)
@@ -46,27 +41,12 @@ def fetch_first(mask, default=I_INT) -> Array:
fetch the first True index
:param mask: array of bool
:param default: the default value if no element satisfying the condition
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT
example:
>>> a = jnp.array([1, 2, 3, 4, 5])
>>> fetch_first(a > 3)
3
>>> fetch_first(a > 30)
I_INT
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value
"""
idx = jnp.argmax(mask)
return jnp.where(mask[idx], idx, default)
@jit
def fetch_last(mask, default=I_INT) -> Array:
"""
similar to fetch_first, but fetch the last True index
"""
reversed_idx = fetch_first(mask[::-1], default)
return jnp.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
"""
@@ -78,27 +58,8 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx
if __name__ == '__main__':
a = jnp.array([1, 2, 3, 4, 5])
print(fetch_first(a > 3))
print(fetch_first(a > 30))
print(fetch_last(a > 3))
print(fetch_last(a > 30))
rand_key = jax.random.PRNGKey(0)
for t in [-1, 0, 1, 2, 3, 4, 5]:
for _ in range(10):
rand_key, _ = jax.random.split(rand_key)
print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
print(t, fetch_random(rand_key, a > t))
return min_idx

View File

@@ -1,102 +0,0 @@
from functools import partial
import jax
from jax import numpy as jnp, Array
from jax import jit, vmap
I_INT = jnp.iinfo(jnp.int32).max  # infinite int: "not found" sentinel returned by the fetch_* helpers
EMPTY_NODE = jnp.full((1, 5), jnp.nan)  # nan placeholder row for an unused node slot
EMPTY_CON = jnp.full((1, 4), jnp.nan)  # nan placeholder row for an unused connection slot
@jit
def unflatten_connections(nodes: Array, cons: Array):
    """
    Transform flat (C, 4) connections into a dense (2, N, N) representation.

    res[0, i, j] holds the weight of connection i->j and res[1, i, j] its enabled
    flag; positions with no connection stay nan.
    :param nodes: (N, 5)
    :param cons: (C, 4)
    :return: (2, N, N) array
    """
    N = nodes.shape[0]
    node_keys = nodes[:, 0]
    i_keys, o_keys = cons[:, 0], cons[:, 1]
    # map connection endpoint keys to node-array row indices
    # (removed a stray dead `i_idxs = vmap(fetch_first, ...)` assignment that
    # passed the wrong arguments and was immediately overwritten)
    i_idxs = key_to_indices(i_keys, node_keys)
    o_idxs = key_to_indices(o_keys, node_keys)
    res = jnp.full((2, N, N), jnp.nan)
    # Is interesting that jax use clip when attach data in array
    # however, it will do nothing set values in an array
    res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
    res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
    return res
@partial(vmap, in_axes=(0, None))
def key_to_indices(key, keys):
    # vmapped over `key`: position of each key inside `keys` (I_INT when absent)
    return fetch_first(keys == key)
@jit
def fetch_first(mask, default=I_INT) -> Array:
    """
    Return the index of the first True entry in `mask`, or `default` when none is True.

    example:
    >>> a = jnp.array([1, 2, 3, 4, 5])
    >>> fetch_first(a > 3)
    3
    >>> fetch_first(a > 30)
    I_INT
    """
    # argmax of a boolean array is the first True position, but it is also 0
    # for an all-False mask, so verify the hit is genuine
    first = jnp.argmax(mask)
    return jnp.where(mask[first], first, default)
@jit
def fetch_last(mask, default=I_INT) -> Array:
    """
    Like fetch_first, but return the LAST True index (or `default` when mask has no True).
    """
    reversed_idx = fetch_first(mask[::-1], default)
    # bug fix: fetch_first signals "not found" with `default` (I_INT by default),
    # not with -1 — the old `== -1` check let the sentinel leak into the
    # index arithmetic and return garbage for an all-False mask
    return jnp.where(reversed_idx == default, default, mask.shape[0] - reversed_idx - 1)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
    """
    Pick a uniformly random True index from `mask`, or `default` when none is True.
    """
    n_true = jnp.sum(mask)
    # choose the k-th True element, k ~ Uniform{1..n_true}
    k = jax.random.randint(rand_key, shape=(), minval=1, maxval=n_true + 1)
    # the first position where the running count of Trues reaches k
    hit = jnp.where(n_true == 0, False, jnp.cumsum(mask) >= k)
    return fetch_first(hit, default)
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
    """Index of the smallest entry of `arr` among positions where `mask` is True."""
    # masked-out entries become +inf so they can never win the argmin
    return jnp.argmin(jnp.where(mask, arr, jnp.inf))
if __name__ == '__main__':
    # manual smoke tests for the fetch_* helpers
    a = jnp.array([1, 2, 3, 4, 5])
    print(fetch_first(a > 3))
    print(fetch_first(a > 30))
    print(fetch_last(a > 3))
    print(fetch_last(a > 30))
    rand_key = jax.random.PRNGKey(0)
    # sample random True indices for masks of varying density (t = -1 -> all True,
    # t = 5 -> all False)
    for t in [-1, 0, 1, 2, 3, 4, 5]:
        for _ in range(10):
            rand_key, _ = jax.random.split(rand_key)
            print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
            print(t, fetch_random(rand_key, a > t))