debug-branch

This commit is contained in:
wls2002
2023-05-06 21:04:28 +08:00
parent 14fed83193
commit a85e6eba78
20 changed files with 1719 additions and 233 deletions

View File

@@ -134,5 +134,3 @@ def act(idx, z):
# change idx from float to int
return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
vectorized_act = jax.vmap(act, in_axes=(0, 0))

View File

@@ -48,7 +48,7 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
@@ -64,7 +64,17 @@ def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
return jnp.where(max_cnt == 0, 0, val / max_cnt) # handle the case where both genomes have no genes
@partial(vmap, in_axes=(0, 0))
@vmap
def batch_homologous_node_distance(b_n1, b_n2):
return homologous_node_distance(b_n1, b_n2)
@vmap
def batch_homologous_connection_distance(b_c1, b_c2):
return homologous_connection_distance(b_c1, b_c2)
@jit
def homologous_node_distance(n1, n2):
d = 0
d += jnp.abs(n1[1] - n2[1]) # bias
@@ -74,7 +84,7 @@ def homologous_node_distance(n1, n2):
return d
@partial(vmap, in_axes=(0, 0))
@jit
def homologous_connection_distance(c1, c2):
d = 0
d += jnp.abs(c1[2] - c2[2]) # weight
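# Note (illustrative, not part of this commit): with array arguments, the bare @vmap
# decorator defaults to in_axes=0 for every argument, so it is equivalent to the
# @partial(vmap, in_axes=(0, 0)) form it replaces above.
import jax
import jax.numpy as jnp
f = lambda a, b: jnp.abs(a - b)
x, y = jnp.arange(3.0), jnp.ones(3)
assert jnp.allclose(jax.vmap(f)(x, y), jax.vmap(f, in_axes=(0, 0))(x, y))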

View File

@@ -95,11 +95,11 @@ def topological_sort_debug(nodes: Array, connections: Array) -> Array:
@vmap
def batch_topological_sort(nodes: Array, connections: Array) -> Array:
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
"""
batch version of topological_sort
:param nodes:
:param connections:
:param pop_nodes:
:param pop_connections:
:return:
"""
return topological_sort(nodes, connections)
@@ -175,17 +175,17 @@ if __name__ == '__main__':
])
connections = jnp.array([
[
[0, 0, 1, 0, jnp.nan],
[0, 0, 1, 1, jnp.nan],
[0, 0, 0, 1, jnp.nan],
[0, 0, 0, 0, jnp.nan],
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
],
[
[0, 0, 1, 0, jnp.nan],
[0, 0, 1, 1, jnp.nan],
[0, 0, 0, 1, jnp.nan],
[0, 0, 0, 0, jnp.nan],
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
]
]

View File

@@ -386,18 +386,30 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
# randomly choose a connection
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
def nothing():
return nodes, connections
def successful_add_node():
# disable the connection
connections = connections.at[1, from_idx, to_idx].set(False)
new_nodes, new_connections = nodes, connections
new_connections = new_connections.at[1, from_idx, to_idx].set(False)
# add a new node
nodes, connections = add_node(new_node_key, nodes, connections,
new_nodes, new_connections = \
add_node(new_node_key, new_nodes, new_connections,
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
new_idx = fetch_first(nodes[:, 0] == new_node_key)
new_idx = fetch_first(new_nodes[:, 0] == new_node_key)
# add two new connections
weight = connections[0, from_idx, to_idx]
nodes, connections = add_connection_by_idx(from_idx, new_idx, nodes, connections, weight=0, enabled=True)
nodes, connections = add_connection_by_idx(new_idx, to_idx, nodes, connections, weight=weight, enabled=True)
weight = new_connections[0, from_idx, to_idx]
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
new_nodes, new_connections, weight=0, enabled=True)
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
new_nodes, new_connections, weight=weight, enabled=True)
return new_nodes, new_connections
# if from_idx == I_INT, no connection exists, so do nothing
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node)  # branches are callables, so use lax.cond
return nodes, connections
@@ -482,7 +494,15 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
"""
# randomly choose a connection
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
nodes, connections = delete_connection_by_idx(from_idx, to_idx, nodes, connections)
def nothing():
return nodes, connections
def successfully_delete_connection():
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection)  # branches are callables, so use lax.cond
return nodes, connections
@@ -530,6 +550,10 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
col = connection[0, from_idx, :]
to_idx = fetch_random(k2, ~jnp.isnan(col))
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan)
to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan)
return from_key, to_key, from_idx, to_idx

View File

@@ -0,0 +1,5 @@
from .genome import create_initialize_function, expand, expand_single
from .distance import distance
from .mutate import create_mutate_function
from .forward import create_forward_function
from .crossover import batch_crossover

View File

@@ -0,0 +1,113 @@
import numpy as np
def sigmoid_act(z):
z = np.clip(z * 5, -60, 60)
return 1 / (1 + np.exp(-z))
def tanh_act(z):
z = np.clip(z * 2.5, -60, 60)
return np.tanh(z)
def sin_act(z):
z = np.clip(z * 5, -60, 60)
return np.sin(z)
def gauss_act(z):
z = np.clip(z, -3.4, 3.4)
return np.exp(-5 * z ** 2)
def relu_act(z):
return np.maximum(z, 0)
def elu_act(z):
return np.where(z > 0, z, np.exp(z) - 1)
def lelu_act(z):
leaky = 0.005
return np.where(z > 0, z, leaky * z)
def selu_act(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return np.where(z > 0, lam * z, lam * alpha * (np.exp(z) - 1))
def softplus_act(z):
z = np.clip(z * 5, -60, 60)
return 0.2 * np.log(1 + np.exp(z))
def identity_act(z):
return z
def clamped_act(z):
return np.clip(z, -1, 1)
def inv_act(z):
return 1 / z
def log_act(z):
z = np.maximum(z, 1e-7)
return np.log(z)
def exp_act(z):
z = np.clip(z, -60, 60)
return np.exp(z)
def abs_act(z):
return np.abs(z)
def hat_act(z):
return np.maximum(0, 1 - np.abs(z))
def square_act(z):
return z ** 2
def cube_act(z):
return z ** 3
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
act_name2key = {
'sigmoid': 0,
'tanh': 1,
'sin': 2,
'gauss': 3,
'relu': 4,
'elu': 5,
'lelu': 6,
'selu': 7,
'softplus': 8,
'identity': 9,
'clamped': 10,
'inv': 11,
'log': 12,
'exp': 13,
'abs': 14,
'hat': 15,
'square': 16,
'cube': 17,
}
def act(idx, z):
idx = np.asarray(idx, dtype=np.int32)
return ACT_TOTAL_LIST[idx](z)
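# Usage sketch (illustrative, not part of this commit): act() dispatches to one of the
# NumPy activations above by integer index; act_name2key maps names to those indices.
z = np.array([-1.0, 0.0, 2.0])
print(act(act_name2key['relu'], z))      # [0. 0. 2.]
print(act(act_name2key['identity'], z))  # [-1.  0.  2.]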

View File

@@ -0,0 +1,86 @@
"""
Aggregation functions. Two special cases need to be considered:
1. extra 0s
2. full of 0s
"""
import numpy as np
def sum_agg(z):
z = np.where(np.isnan(z), 0, z)
return np.sum(z, axis=0)
def product_agg(z):
z = np.where(np.isnan(z), 1, z)
return np.prod(z, axis=0)
def max_agg(z):
z = np.where(np.isnan(z), -np.inf, z)
return np.max(z, axis=0)
def min_agg(z):
z = np.where(np.isnan(z), np.inf, z)
return np.min(z, axis=0)
def maxabs_agg(z):
z = np.where(np.isnan(z), 0, z)
abs_z = np.abs(z)
max_abs_index = np.argmax(abs_z)
return z[max_abs_index]
def median_agg(z):
non_zero_mask = ~np.isnan(z)
n = np.sum(non_zero_mask, axis=0)
z = np.where(np.isnan(z), np.inf, z)
sorted_valid_values = np.sort(z)
if n % 2 == 0:
return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2
else:
return sorted_valid_values[n // 2]
def mean_agg(z):
non_zero_mask = ~np.isnan(z)
valid_values_sum = sum_agg(z)
valid_values_count = np.sum(non_zero_mask, axis=0)
mean_without_zeros = valid_values_sum / valid_values_count
return mean_without_zeros
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
agg_name2key = {
'sum': 0,
'product': 1,
'max': 2,
'min': 3,
'maxabs': 4,
'median': 5,
'mean': 6,
}
def agg(idx, z):
idx = np.asarray(idx, dtype=np.int32)
if np.all(z == 0.):
return 0
else:
return AGG_TOTAL_LIST[idx](z)
if __name__ == '__main__':
array = np.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=np.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array))
array2 = np.asarray([0, 0, 0, 0], dtype=np.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array2))

View File

@@ -0,0 +1,90 @@
from typing import Tuple
import numpy as np
from numpy.typing import NDArray
from .utils import flatten_connections, unflatten_connections
def batch_crossover(batch_nodes1: NDArray, batch_connections1: NDArray, batch_nodes2: NDArray,
batch_connections2: NDArray) -> Tuple[NDArray, NDArray]:
"""
crossover a batch of genomes
:param batch_nodes1:
:param batch_connections1:
:param batch_nodes2:
:param batch_connections2:
:return:
"""
res_nodes, res_cons = [], []
for (n1, c1, n2, c2) in zip(batch_nodes1, batch_connections1, batch_nodes2, batch_connections2):
new_nodes, new_cons = crossover(n1, c1, n2, c2)
res_nodes.append(new_nodes)
res_cons.append(new_cons)
return np.stack(res_nodes, axis=0), np.stack(res_cons, axis=0)
def crossover(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connections2: NDArray) \
-> Tuple[NDArray, NDArray]:
"""
use genome1 and genome2 to generate a new genome
note that genome1 should have higher fitness than genome2 (genome1 is the winner)
:param nodes1:
:param connections1:
:param nodes2:
:param connections2:
:return:
"""
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
nodes2 = align_array(keys1, keys2, nodes2, 'node')
new_nodes = np.where(np.isnan(nodes1) | np.isnan(nodes2), nodes1, crossover_gene(nodes1, nodes2))
# crossover connections
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
new_cons = np.where(np.isnan(cons1) | np.isnan(cons2), cons1, crossover_gene(cons1, cons2))
new_cons = unflatten_connections(len(keys1), new_cons)
return new_nodes, new_cons
def align_array(seq1: NDArray, seq2: NDArray, ar2: NDArray, gene_type: str) -> NDArray:
"""
make ar2 align with ar1.
:param seq1:
:param seq2:
:param ar2:
:param gene_type:
:return:
aligning means that the rows of ar2 whose keys intersect ar1 are placed at the same positions as in ar1,
and the non-intersecting rows of ar2 are set to NaN
"""
seq1, seq2 = seq1[:, np.newaxis], seq2[np.newaxis, :]
mask = (seq1 == seq2) & (~np.isnan(seq1))
if gene_type == 'connection':
mask = np.all(mask, axis=2)
intersect_mask = mask.any(axis=1)
idx = np.arange(0, len(seq1))
idx_fixed = np.dot(mask, idx)
refactor_ar2 = np.where(intersect_mask[:, np.newaxis], ar2[idx_fixed], np.nan)
return refactor_ar2
def crossover_gene(g1: NDArray, g2: NDArray) -> NDArray:
"""
crossover two genes
:param g1:
:param g2:
:return:
only genes with the same key are crossed over, so keys never need to change
"""
r = np.random.rand()
return np.where(r > 0.5, g1, g2)
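# Usage sketch (illustrative, not part of this commit), using toy two-column genes [key, value]:
# align_array moves the rows of ar2 whose keys also appear in keys1 to keys1's positions and
# fills the remaining positions with NaN, so the two arrays can be compared element-wise.
keys1 = np.array([0.0, 1.0, 2.0])
keys2 = np.array([2.0, 0.0, 5.0])
ar2 = np.array([[2.0, 0.1], [0.0, 0.2], [5.0, 0.3]])
print(align_array(keys1, keys2, ar2, 'node'))
# -> [[ 0.   0.2]
#     [ nan  nan]
#     [ 2.   0.1]]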

View File

@@ -0,0 +1,94 @@
from functools import partial
import numpy as np
from numpy.typing import NDArray
from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis
def distance(nodes1: NDArray, connections1: NDArray, nodes2: NDArray, connections2: NDArray) -> NDArray:
"""
Calculate the distance between two genomes.
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
"""
node_distance = gene_distance(nodes1, nodes2, 'node')
# refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
connection_distance = gene_distance(cons1, cons2, 'connection')
return node_distance + connection_distance
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
if gene_type == 'node':
keys1, keys2 = ar1[:, :1], ar2[:, :1]
else: # connection
keys1, keys2 = ar1[:, :2], ar2[:, :2]
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
nodes = np.concatenate((ar1, ar2), axis=0)
sorted_nodes = nodes[n_sorted_indices]
if gene_type == 'node':
node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 1:]), axis=1)
else:
node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 2:]), axis=1)
n_intersect_mask = n_intersect_mask & node_exist_mask
n_union_mask = n_union_mask & node_exist_mask
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = batch_homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = np.where(np.isnan(node_distance), 0, node_distance)
homologous_distance = np.sum(node_distance * n_intersect_mask[:-1])
gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1))
gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1))
max_cnt = np.maximum(gene_cnt1, gene_cnt2)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return np.where(max_cnt == 0, 0, val / max_cnt) # handle the case where both genomes have no genes
def batch_homologous_node_distance(b_n1, b_n2):
res = []
for n1, n2 in zip(b_n1, b_n2):
d = homologous_node_distance(n1, n2)
res.append(d)
return np.stack(res, axis=0)
def batch_homologous_connection_distance(b_c1, b_c2):
res = []
for c1, c2 in zip(b_c1, b_c2):
d = homologous_connection_distance(c1, c2)
res.append(d)
return np.stack(res, axis=0)
def homologous_node_distance(n1, n2):
d = 0
d += np.abs(n1[1] - n2[1]) # bias
d += np.abs(n1[2] - n2[2]) # response
d += n1[3] != n2[3] # activation
d += n1[4] != n2[4]
return d
def homologous_connection_distance(c1, c2):
d = 0
d += np.abs(c1[2] - c2[2]) # weight
d += c1[3] != c2[3] # enable
return d
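# Usage sketch (illustrative, not part of this commit): homologous node distance compares two
# aligned node rows [key, bias, response, act, agg]; float fields contribute their absolute
# difference, categorical fields contribute 1 when they differ.
n1 = np.array([0.0, 0.5, 1.0, 0.0, 0.0])
n2 = np.array([0.0, 0.0, 1.0, 1.0, 0.0])
print(homologous_node_distance(n1, n2))  # 0.5 (bias) + 0.0 (response) + 1 (act) + 0 (agg) = 1.5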

View File

@@ -0,0 +1,151 @@
from functools import partial
import numpy as np
from numpy.typing import NDArray
from .aggregations import agg
from .activations import act
from .graph import topological_sort, batch_topological_sort
from .utils import I_INT
def create_forward_function(nodes: NDArray, connections: NDArray,
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
"""
create forward function for different situations
:param nodes: shape (N, 5) or (pop_size, N, 5)
:param connections: shape (2, N, N) or (pop_size, 2, N, N)
:param N:
:param input_idx:
:param output_idx:
:param batch: using batch or not
:return:
"""
if nodes.ndim == 2: # single genome
cal_seqs = topological_sort(nodes, connections)
if not batch:
return lambda inputs: forward_single(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
else:
return lambda batch_inputs: forward_batch(batch_inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
elif nodes.ndim == 3: # pop genome
pop_cal_seqs = batch_topological_sort(nodes, connections)
if not batch:
return lambda inputs: pop_forward_single(inputs, N, input_idx, output_idx,
pop_cal_seqs, nodes, connections)
else:
return lambda batch_inputs: pop_forward_batch(batch_inputs, N, input_idx, output_idx,
pop_cal_seqs, nodes, connections)
else:
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
def forward_single(inputs: NDArray, N: int, input_idx: NDArray, output_idx: NDArray,
cal_seqs: NDArray, nodes: NDArray, connections: NDArray) -> NDArray:
"""
forward pass for a single input shaped (input_num, )
nodes, connections are single genome
:argument inputs: (input_num, )
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (output_num, )
"""
ini_vals = np.full((N,), np.nan)
ini_vals[input_idx] = inputs
for i in cal_seqs:
if i in input_idx:
continue
if i == I_INT:
break
ins = ini_vals * connections[0, :, i]
z = agg(nodes[i, 4], ins)
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
# input nodes are skipped above, so their initial values are never overwritten here
ini_vals[i] = z
return ini_vals[output_idx]
def forward_batch(batch_inputs: NDArray, N: int, input_idx: NDArray, output_idx: NDArray,
cal_seqs: NDArray, nodes: NDArray, connections: NDArray) -> NDArray:
"""
forward pass for batch_inputs shaped (batch_size, input_num)
nodes, connections are single genome
:argument batch_inputs: (batch_size, input_num)
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (batch_size, output_num)
"""
res = []
for inputs in batch_inputs:
out = forward_single(inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
res.append(out)
return np.stack(res, axis=0)
def pop_forward_single(inputs: NDArray, N: int, input_idx: NDArray, output_idx: NDArray,
pop_cal_seqs: NDArray, pop_nodes: NDArray, pop_connections: NDArray) -> NDArray:
"""
forward pass for a single input shaped (input_num, )
pop_nodes, pop_connections are population of genomes
:argument inputs: (input_num, )
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument pop_cal_seqs: (pop_size, N)
:argument pop_nodes: (pop_size, N, 5)
:argument pop_connections: (pop_size, 2, N, N)
:return (pop_size, output_num)
"""
res = []
for cal_seqs, nodes, connections in zip(pop_cal_seqs, pop_nodes, pop_connections):
out = forward_single(inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
res.append(out)
return np.stack(res, axis=0)
def pop_forward_batch(batch_inputs: NDArray, N: int, input_idx: NDArray, output_idx: NDArray,
pop_cal_seqs: NDArray, pop_nodes: NDArray, pop_connections: NDArray) -> NDArray:
"""
forward pass for batch inputs shaped (batch_size, input_num)
pop_nodes, pop_connections are population of genomes
:argument batch_inputs: (batch_size, input_num)
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument pop_cal_seqs: (pop_size, N)
:argument pop_nodes: (pop_size, N, 5)
:argument pop_connections: (pop_size, 2, N, N)
:return (pop_size, batch_size, output_num)
"""
res = []
for cal_seqs, nodes, connections in zip(pop_cal_seqs, pop_nodes, pop_connections):
out = forward_batch(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
res.append(out)
return np.stack(res, axis=0)
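# Usage sketch (illustrative, not part of this commit): wiring up a forward pass for one genome
# taken from initialize_genomes in the sibling genome module (the import below is an assumption).
from .genome import initialize_genomes
demo_nodes, demo_cons, in_idx, out_idx = initialize_genomes(pop_size=1, N=5, num_inputs=2, num_outputs=1)
forward = create_forward_function(demo_nodes[0], demo_cons[0], N=5, input_idx=in_idx, output_idx=out_idx, batch=False)
print(forward(np.array([0.3, 0.7])))  # shape (1,): the single output node's activation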

View File

@@ -0,0 +1,270 @@
"""
Vectorization of genome representation.
Utilizes Tuple[nodes: NDArray, connections: NDArray] to encode the genome, where:
1. N is a pre-set value that determines the maximum number of nodes in the network, and will increase if the genome becomes
too large to be represented by the current value of N.
2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function
(act), and aggregation function (agg).
3. connections is an array of shape (2, N, N), dtype=float, with the first axis representing weight and connection enabled
status.
Empty nodes or connections are represented using np.nan.
"""
from typing import Tuple, Dict
from functools import partial
import numpy as np
from numpy.typing import NDArray
from algorithms.neat.genome.utils import fetch_first
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
def create_initialize_function(config):
pop_size = config.neat.population.pop_size
N = config.basic.init_maximum_nodes
num_inputs = config.basic.num_inputs
num_outputs = config.basic.num_outputs
default_bias = config.neat.gene.bias.init_mean
default_response = config.neat.gene.response.init_mean
# default_act = config.neat.gene.activation.default
# default_agg = config.neat.gene.aggregation.default
default_act = 0
default_agg = 0
default_weight = config.neat.gene.weight.init_mean
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
default_act, default_agg, default_weight)
def initialize_genomes(pop_size: int,
N: int,
num_inputs: int, num_outputs: int,
default_bias: float = 0.0,
default_response: float = 1.0,
default_act: int = 0,
default_agg: int = 0,
default_weight: float = 1.0) \
-> Tuple[NDArray, NDArray, NDArray, NDArray]:
"""
Initialize genomes with default values.
Args:
pop_size (int): Number of genomes to initialize.
N (int): Maximum number of nodes in the network.
num_inputs (int): Number of input nodes.
num_outputs (int): Number of output nodes.
default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0.
default_response (float, optional): Default response value for output nodes. Defaults to 1.0.
default_act (int, optional): Default activation function index for output nodes. Defaults to 0.
default_agg (int, optional): Default aggregation function index for output nodes. Defaults to 0.
default_weight (float, optional): Default weight value for connections. Defaults to 1.0.
Raises:
AssertionError: If the sum of num_inputs, num_outputs, and 1 is greater than N.
Returns:
Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays.
"""
# Reserve one row for potential mutation adding an extra node
assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \
f"{num_inputs} and output_size: {num_outputs}!"
pop_nodes = np.full((pop_size, N, 5), np.nan)
pop_connections = np.full((pop_size, 2, N, N), np.nan)
input_idx = np.arange(num_inputs)
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
pop_nodes[:, input_idx, 0] = input_idx
pop_nodes[:, output_idx, 0] = output_idx
pop_nodes[:, output_idx, 1] = default_bias
pop_nodes[:, output_idx, 2] = default_response
pop_nodes[:, output_idx, 3] = default_act
pop_nodes[:, output_idx, 4] = default_agg
for i in input_idx:
for j in output_idx:
pop_connections[:, 0, i, j] = default_weight
pop_connections[:, 1, i, j] = 1
return pop_nodes, pop_connections, input_idx, output_idx
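# Usage sketch (illustrative, not part of this commit): a population of 3 genomes with 2 inputs,
# 1 output and room for N=5 nodes; unused rows and connection entries stay np.nan as described
# in the module docstring, and every input is wired to every output.
demo_nodes, demo_cons, demo_in, demo_out = initialize_genomes(pop_size=3, N=5, num_inputs=2, num_outputs=1)
print(demo_nodes.shape, demo_cons.shape)  # (3, 5, 5) (3, 2, 5, 5)
print(demo_in, demo_out)                  # [0 1] [2]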
def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]:
"""
Expand the genome to accommodate more nodes.
:param pop_nodes: (pop_size, N, 5)
:param pop_connections: (pop_size, 2, N, N)
:param new_N:
:return:
"""
pop_size, old_N = pop_nodes.shape[0], pop_nodes.shape[1]
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
new_pop_nodes[:, :old_N, :] = pop_nodes
new_pop_connections = np.full((pop_size, 2, new_N, new_N), np.nan)
new_pop_connections[:, :, :old_N, :old_N] = pop_connections
return new_pop_nodes, new_pop_connections
def expand_single(nodes: NDArray, connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]:
"""
Expand a single genome to accommodate more nodes.
:param nodes: (N, 5)
:param connections: (2, N, N)
:param new_N:
:return:
"""
old_N = nodes.shape[0]
new_nodes = np.full((new_N, 5), np.nan)
new_nodes[:old_N, :] = nodes
new_connections = np.full((2, new_N, new_N), np.nan)
new_connections[:, :old_N, :old_N] = connections
return new_nodes, new_connections
def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
"""
Convert a genome from array to dict.
:param nodes: (N, 5)
:param connections: (2, N, N)
:param output_keys:
:param input_keys:
:return: nodes_dict[key: (bias, response, act, agg)], connections_dict[(f_key, t_key): (weight, enabled)]
"""
# update nodes_dict
try:
nodes_dict = {}
idx2key = {}
for i, node in enumerate(nodes):
if np.isnan(node[0]):
continue
key = int(node[0])
assert key not in nodes_dict, f"Duplicate node key: {key}!"
bias = node[1] if not np.isnan(node[1]) else None
response = node[2] if not np.isnan(node[2]) else None
act = node[3] if not np.isnan(node[3]) else None
agg = node[4] if not np.isnan(node[4]) else None
nodes_dict[key] = (bias, response, act, agg)
idx2key[i] = key
# check nodes_dict
for i in input_keys:
assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
bias, response, act, agg = nodes_dict[i]
assert bias is None and response is None and act is None and agg is None, \
f"Input node {i} must has None bias, response, act, or agg!"
for o in output_keys:
assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
for k, v in nodes_dict.items():
if k not in input_keys:
bias, response, act, agg = v
assert bias is not None and response is not None and act is not None and agg is not None, \
f"Normal node {k} must has non-None bias, response, act, or agg!"
# update connections
connections_dict = {}
for i in range(connections.shape[1]):
for j in range(connections.shape[2]):
if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
continue
assert i in idx2key, f"Node index {i} not found in idx2key:{idx2key}!"
assert j in idx2key, f"Node index {j} not found in idx2key:{idx2key}!"
key = (idx2key[i], idx2key[j])
weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None
enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None
assert weight is not None, f"Connection {key} must have a non-None weight!"
assert enabled is not None, f"Connection {key} must have a non-None enabled status!"
connections_dict[key] = (weight, enabled)
return nodes_dict, connections_dict
except AssertionError:
print(nodes)
print(connections)
raise  # re-raise with the original assertion message
def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
res = []
for nodes, connections in zip(pop_nodes, pop_connections):
res.append(analysis(nodes, connections, input_keys, output_keys))
return res
def add_node(new_node_key: int, nodes: NDArray, connections: NDArray,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]:
"""
add a new node to the genome.
"""
exist_keys = nodes[:, 0]
idx = fetch_first(np.isnan(exist_keys))
nodes[idx] = np.array([new_node_key, bias, response, act, agg])
return nodes, connections
def delete_node(node_key: int, nodes: NDArray, connections: NDArray) -> Tuple[NDArray, NDArray]:
"""
delete a node from the genome. only delete the node, regardless of connections.
"""
node_keys = nodes[:, 0]
idx = fetch_first(node_keys == node_key)
return delete_node_by_idx(idx, nodes, connections)
def delete_node_by_idx(idx: int, nodes: NDArray, connections: NDArray) -> Tuple[NDArray, NDArray]:
"""
delete a node from the genome. only delete the node, regardless of connections.
"""
nodes[idx] = EMPTY_NODE
return nodes, connections
def add_connection(from_node: int, to_node: int, nodes: NDArray, connections: NDArray,
weight: float = 0.0, enabled: bool = True) -> Tuple[NDArray, NDArray]:
"""
add a new connection to the genome.
"""
node_keys = nodes[:, 0]
from_idx = fetch_first(node_keys == from_node)
to_idx = fetch_first(node_keys == to_node)
return add_connection_by_idx(from_idx, to_idx, nodes, connections, weight, enabled)
def add_connection_by_idx(from_idx: int, to_idx: int, nodes: NDArray, connections: NDArray,
weight: float = 0.0, enabled: bool = True) -> Tuple[NDArray, NDArray]:
"""
add a new connection to the genome.
"""
connections[:, from_idx, to_idx] = np.array([weight, enabled])
return nodes, connections
def delete_connection(from_node: int, to_node: int, nodes: NDArray, connections: NDArray) -> Tuple[NDArray, NDArray]:
"""
delete a connection from the genome.
"""
node_keys = nodes[:, 0]
from_idx = fetch_first(node_keys == from_node)
to_idx = fetch_first(node_keys == to_node)
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: NDArray, connections: NDArray) -> Tuple[
NDArray, NDArray]:
"""
delete a connection from the genome.
"""
connections[:, from_idx, to_idx] = np.nan
return nodes, connections
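# Usage sketch (illustrative, not part of this commit): growing a single genome with the helpers
# above. add_node fills the first free (NaN-keyed) row; add_connection addresses nodes by key.
pop_nodes, pop_cons, _, _ = initialize_genomes(pop_size=1, N=6, num_inputs=2, num_outputs=1)
g_nodes, g_cons = pop_nodes[0], pop_cons[0]
g_nodes, g_cons = add_node(3, g_nodes, g_cons)                       # hidden node with key 3
g_nodes, g_cons = add_connection(0, 3, g_nodes, g_cons, weight=0.5)  # input 0 -> hidden 3
g_nodes, g_cons = add_connection(3, 2, g_nodes, g_cons, weight=1.0)  # hidden 3 -> output 2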

View File

@@ -0,0 +1,163 @@
"""
Some graph algorithms, ported from the JAX implementation to NumPy.
Only used in feed-forward networks.
"""
import numpy as np
from numpy.typing import NDArray
# from .utils import fetch_first, I_INT
from algorithms.neat.genome.utils import fetch_first, I_INT
def topological_sort(nodes: NDArray, connections: NDArray) -> NDArray:
"""
an array-based topological sort (ported from the jit-able JAX version)
:param nodes: nodes array
:param connections: connections array
:return: topological sorted sequence
Example:
nodes = np.array([
[0],
[1],
[2],
[3]
])
connections = np.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
topological_sort(nodes, connections) -> [0, 1, 2, 3]
"""
connections_enable = connections[1, :, :] == 1
in_degree = np.where(np.isnan(nodes[:, 0]), np.nan, np.sum(connections_enable, axis=0))
res = np.full(in_degree.shape, I_INT)
idx = 0
for _ in range(in_degree.shape[0]):
i = fetch_first(in_degree == 0.)
if i == I_INT:
break
res[idx] = i
idx += 1
in_degree[i] = -1
children = connections_enable[i, :]
in_degree = np.where(children, in_degree - 1, in_degree)
return res
def batch_topological_sort(pop_nodes: NDArray, pop_connections: NDArray) -> NDArray:
"""
batch version of topological_sort
:param pop_nodes:
:param pop_connections:
:return:
"""
res = []
for nodes, connections in zip(pop_nodes, pop_connections):
seq = topological_sort(nodes, connections)
res.append(seq)
return np.stack(res, axis=0)
def check_cycles(nodes: NDArray, connections: NDArray, from_idx: NDArray, to_idx: NDArray) -> NDArray:
"""
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
:param nodes: NDArray
The array of nodes.
:param connections: NDArray
The array of connections.
:param from_idx: int
The index of the starting node.
:param to_idx: int
The index of the ending node.
:return: NDArray
A boolean (0-d array) indicating whether the new connection would create a cycle.
Example:
nodes = np.array([
[0],
[1],
[2],
[3]
])
connections = np.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
check_cycles(nodes, connections, 3, 2) -> True
check_cycles(nodes, connections, 2, 3) -> False
check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False
"""
connections_enable = ~np.isnan(connections[0, :, :])
connections_enable[from_idx, to_idx] = True
nodes_visited = np.full(nodes.shape[0], False)
nodes_visited[to_idx] = True
for _ in range(nodes_visited.shape[0]):
new_visited = np.dot(nodes_visited, connections_enable)
nodes_visited = np.logical_or(nodes_visited, new_visited)
return nodes_visited[from_idx]
if __name__ == '__main__':
nodes = np.array([
[0],
[1],
[2],
[3],
[np.nan]
])
connections = np.array([
[
[np.nan, np.nan, 1, np.nan, np.nan],
[np.nan, np.nan, 1, 1, np.nan],
[np.nan, np.nan, np.nan, 1, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
],
[
[np.nan, np.nan, 1, np.nan, np.nan],
[np.nan, np.nan, 1, 1, np.nan],
[np.nan, np.nan, np.nan, 1, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
]
]
)
print(topological_sort(nodes, connections))
print(topological_sort(nodes, connections))
print(check_cycles(nodes, connections, 3, 2))
print(check_cycles(nodes, connections, 2, 3))
print(check_cycles(nodes, connections, 0, 3))
print(check_cycles(nodes, connections, 1, 0))

View File

@@ -0,0 +1,531 @@
from typing import Tuple
from functools import partial
import numpy as np
from numpy.typing import NDArray
from numpy.random import rand
from .utils import fetch_random, fetch_first, I_INT
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx
from .graph import check_cycles
def create_mutate_function(config, input_keys, output_keys, batch: bool):
"""
create mutate function for different situations
:param output_keys:
:param input_keys:
:param config:
:param batch: mutate for population or not
:return:
"""
bias = config.neat.gene.bias
bias_default = bias.init_mean
bias_mean = bias.init_mean
bias_std = bias.init_stdev
bias_mutate_strength = bias.mutate_power
bias_mutate_rate = bias.mutate_rate
bias_replace_rate = bias.replace_rate
response = config.neat.gene.response
response_default = response.init_mean
response_mean = response.init_mean
response_std = response.init_stdev
response_mutate_strength = response.mutate_power
response_mutate_rate = response.mutate_rate
response_replace_rate = response.replace_rate
weight = config.neat.gene.weight
weight_mean = weight.init_mean
weight_std = weight.init_stdev
weight_mutate_strength = weight.mutate_power
weight_mutate_rate = weight.mutate_rate
weight_replace_rate = weight.replace_rate
activation = config.neat.gene.activation
# act_default = activation.default
act_default = 0
act_range = len(activation.options)
act_replace_rate = activation.mutate_rate
aggregation = config.neat.gene.aggregation
# agg_default = aggregation.default
agg_default = 0
agg_range = len(aggregation.options)
agg_replace_rate = aggregation.mutate_rate
enabled = config.neat.gene.enabled
enabled_reverse_rate = enabled.mutate_rate
genome = config.neat.genome
add_node_rate = genome.node_add_prob
delete_node_rate = genome.node_delete_prob
add_connection_rate = genome.conn_add_prob
delete_connection_rate = genome.conn_delete_prob
single_structure_mutate = genome.single_structural_mutation
mutate_func = lambda nodes, connections, new_node_key: \
mutate(nodes, connections, new_node_key, input_keys, output_keys,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_range, act_replace_rate,
agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
if not batch:
return mutate_func
else:
def batch_mutate_func(pop_nodes, pop_connections, new_node_keys):
res_nodes, res_connections = [], []
for nodes, connections, new_node_key in zip(pop_nodes, pop_connections, new_node_keys):
nodes, connections = mutate_func(nodes, connections, new_node_key)
res_nodes.append(nodes)
res_connections.append(connections)
return np.stack(res_nodes, axis=0), np.stack(res_connections, axis=0)
return batch_mutate_func
def mutate(nodes: NDArray,
connections: NDArray,
new_node_key: int,
input_keys: NDArray,
output_keys: NDArray,
bias_default: float = 0,
bias_mean: float = 0,
bias_std: float = 1,
bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_default: float = 1,
response_mean: float = 1.,
response_std: float = 0.,
response_mutate_strength: float = 0.,
response_mutate_rate: float = 0.,
response_replace_rate: float = 0.,
weight_mean: float = 0.,
weight_std: float = 1.,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_default: int = 0,
act_range: int = 5,
act_replace_rate: float = 0.1,
agg_default: int = 0,
agg_range: int = 5,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1,
add_node_rate: float = 0.2,
delete_node_rate: float = 0.2,
add_connection_rate: float = 0.4,
delete_connection_rate: float = 0.4,
single_structure_mutate: bool = True):
"""
:param output_keys:
:param input_keys:
:param agg_default:
:param act_default:
:param response_default:
:param bias_default:
:param nodes: (N, 5)
:param connections: (2, N, N)
:param new_node_key:
:param bias_mean:
:param bias_std:
:param bias_mutate_strength:
:param bias_mutate_rate:
:param bias_replace_rate:
:param response_mean:
:param response_std:
:param response_mutate_strength:
:param response_mutate_rate:
:param response_replace_rate:
:param weight_mean:
:param weight_std:
:param weight_mutate_strength:
:param weight_mutate_rate:
:param weight_replace_rate:
:param act_range:
:param act_replace_rate:
:param agg_range:
:param agg_replace_rate:
:param enabled_reverse_rate:
:param add_node_rate:
:param delete_node_rate:
:param add_connection_rate:
:param delete_connection_rate:
:param single_structure_mutate: a genome is structurally mutated at most once
:return:
"""
# mutate_structure
def nothing(n, c):
return n, c
def m_add_node(n, c):
return mutate_add_node(new_node_key, n, c, bias_default, response_default, act_default, agg_default)
def m_delete_node(n, c):
return mutate_delete_node(n, c, input_keys, output_keys)
def m_add_connection(n, c):
return mutate_add_connection(n, c, input_keys, output_keys)
def m_delete_connection(n, c):
return mutate_delete_connection(n, c)
if single_structure_mutate:
d = np.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate)
# shorten variable names for beauty
anr, dnr = add_node_rate / d, delete_node_rate / d
acr, dcr = add_connection_rate / d, delete_connection_rate / d
r = rand()
if r <= anr:
nodes, connections = m_add_node(nodes, connections)
elif r <= anr + dnr:
nodes, connections = m_delete_node(nodes, connections)
elif r <= anr + dnr + acr:
nodes, connections = m_add_connection(nodes, connections)
elif r <= anr + dnr + acr + dcr:
nodes, connections = m_delete_connection(nodes, connections)
else:
pass # do nothing
else:
# mutate add node
if rand() < add_node_rate:
nodes, connections = m_add_node(nodes, connections)
# mutate delete node
if rand() < delete_node_rate:
nodes, connections = m_delete_node(nodes, connections)
# mutate add connection
if rand() < add_connection_rate:
nodes, connections = m_add_connection(nodes, connections)
# mutate delete connection
if rand() < delete_connection_rate:
nodes, connections = m_delete_connection(nodes, connections)
nodes, connections = mutate_values(nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength,
weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range,
agg_replace_rate, enabled_reverse_rate)
return nodes, connections
def mutate_values(nodes: NDArray,
connections: NDArray,
bias_mean: float = 0,
bias_std: float = 1,
bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_mean: float = 1.,
response_std: float = 0.,
response_mutate_strength: float = 0.,
response_mutate_rate: float = 0.,
response_replace_rate: float = 0.,
weight_mean: float = 0.,
weight_std: float = 1.,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_range: int = 5,
act_replace_rate: float = 0.1,
agg_range: int = 5,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1) -> Tuple[NDArray, NDArray]:
"""
Mutate values of nodes and connections.
Args:
nodes: A 2D array representing nodes.
connections: A 3D array representing connections.
bias_mean: Mean of the bias values.
bias_std: Standard deviation of the bias values.
bias_mutate_strength: Strength of the bias mutation.
bias_mutate_rate: Rate of the bias mutation.
bias_replace_rate: Rate of the bias replacement.
response_mean: Mean of the response values.
response_std: Standard deviation of the response values.
response_mutate_strength: Strength of the response mutation.
response_mutate_rate: Rate of the response mutation.
response_replace_rate: Rate of the response replacement.
weight_mean: Mean of the weight values.
weight_std: Standard deviation of the weight values.
weight_mutate_strength: Strength of the weight mutation.
weight_mutate_rate: Rate of the weight mutation.
weight_replace_rate: Rate of the weight replacement.
act_range: Range of the activation function values.
act_replace_rate: Rate of the activation function replacement.
agg_range: Range of the aggregation function values.
agg_replace_rate: Rate of the aggregation function replacement.
enabled_reverse_rate: Rate of reversing enabled state of connections.
Returns:
A tuple containing mutated nodes and connections.
"""
bias_new = mutate_float_values(nodes[:, 1], bias_mean, bias_std,
bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
response_new = mutate_float_values(nodes[:, 2], response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate)
weight_new = mutate_float_values(connections[0, :, :], weight_mean, weight_std,
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
act_new = mutate_int_values(nodes[:, 3], act_range, act_replace_rate)
agg_new = mutate_int_values(nodes[:, 4], agg_range, agg_replace_rate)
# refactor enabled
r = np.random.rand(*connections[1, :, :].shape)
enabled_new = connections[1, :, :] == 1
enabled_new = np.where(r < enabled_reverse_rate, ~enabled_new, enabled_new)
enabled_new = np.where(~np.isnan(connections[0, :, :]), enabled_new, np.nan)
nodes[:, 1] = bias_new
nodes[:, 2] = response_new
nodes[:, 3] = act_new
nodes[:, 4] = agg_new
connections[0, :, :] = weight_new
connections[1, :, :] = enabled_new
return nodes, connections
def mutate_float_values(old_vals: NDArray, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> NDArray:
"""
Mutate float values of a given array.
Args:
old_vals: A 1D array of float values to be mutated.
mean: Mean of the values.
std: Standard deviation of the values.
mutate_strength: Strength of the mutation.
mutate_rate: Rate of the mutation.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of float values.
"""
noise = np.random.normal(size=old_vals.shape) * mutate_strength
replace = np.random.normal(size=old_vals.shape) * std + mean
r = rand(*old_vals.shape)
new_vals = old_vals
new_vals = np.where(r < mutate_rate, new_vals + noise, new_vals)
new_vals = np.where(
np.logical_and(mutate_rate < r, r < mutate_rate + replace_rate),
replace,
new_vals
)
new_vals = np.where(~np.isnan(old_vals), new_vals, np.nan)
return new_vals
def mutate_int_values(old_vals: NDArray, range: int, replace_rate: float) -> NDArray:
"""
Mutate integer values (act, agg) of a given array.
Args:
old_vals: A 1D array of integer values to be mutated.
range: Range of the integer values.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of integer values.
"""
replace_val = np.random.randint(low=0, high=range, size=old_vals.shape)
r = np.random.rand(*old_vals.shape)
new_vals = old_vals
new_vals = np.where(r < replace_rate, replace_val, new_vals)
new_vals = np.where(~np.isnan(old_vals), new_vals, np.nan)
return new_vals
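# Usage sketch (illustrative, not part of this commit): mutate_float_values perturbs each real
# entry with probability mutate_rate, re-draws it with probability replace_rate, and leaves the
# NaN padding untouched; mutate_int_values only performs the re-draw step.
np.random.seed(0)
old = np.array([0.5, np.nan, -1.0])
print(mutate_float_values(old, mean=0.0, std=1.0, mutate_strength=0.5, mutate_rate=0.7, replace_rate=0.1))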
def mutate_add_node(new_node_key: int, nodes: NDArray, connections: NDArray,
default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[NDArray, NDArray]:
"""
Randomly add a new node from splitting a connection.
:param new_node_key:
:param nodes:
:param connections:
:param default_bias:
:param default_response:
:param default_act:
:param default_agg:
:return:
"""
# randomly choose a connection
from_key, to_key, from_idx, to_idx = choice_connection_key(nodes, connections)
def nothing():
return nodes, connections
def successful_add_node():
# disable the connection
new_nodes, new_connections = nodes, connections
new_connections[1, from_idx, to_idx] = False
# add a new node
new_nodes, new_connections = \
add_node(new_node_key, new_nodes, new_connections,
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
new_idx = fetch_first(new_nodes[:, 0] == new_node_key)
# add two new connections
weight = new_connections[0, from_idx, to_idx]
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
new_nodes, new_connections, weight=0, enabled=True)
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
new_nodes, new_connections, weight=weight, enabled=True)
return new_nodes, new_connections
# if from_idx == I_INT, no connection exists, so do nothing
if from_idx == I_INT:
nodes, connections = nothing()
else:
nodes, connections = successful_add_node()
return nodes, connections
def mutate_delete_node(nodes: NDArray, connections: NDArray,
input_keys: NDArray, output_keys: NDArray) -> Tuple[NDArray, NDArray]:
"""
Randomly delete a node. Input and output nodes are not allowed to be deleted.
:param nodes:
:param connections:
:param input_keys:
:param output_keys:
:return:
"""
# randomly choose a node
node_key, node_idx = choice_node_key(nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
if np.isnan(node_key):
return nodes, connections
# delete the node
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
# delete connections
aux_connections[:, node_idx, :] = np.nan
aux_connections[:, :, node_idx] = np.nan
return aux_nodes, aux_connections
def mutate_add_connection(nodes: NDArray, connections: NDArray,
input_keys: NDArray, output_keys: NDArray) -> Tuple[NDArray, NDArray]:
"""
Randomly add a new connection. The target node is not allowed to be an input node.
In feed-forward networks, connections that would create a cycle are not added.
:param nodes:
:param connections:
:param input_keys:
:param output_keys:
:return:
"""
# randomly choose two nodes
from_key, from_idx = choice_node_key(nodes, input_keys, output_keys,
allow_input_keys=True, allow_output_keys=True)
to_key, to_idx = choice_node_key(nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=True)
is_already_exist = ~np.isnan(connections[0, from_idx, to_idx])
if is_already_exist:
connections[1, from_idx, to_idx] = True
return nodes, connections
elif check_cycles(nodes, connections, from_idx, to_idx):
return nodes, connections
else:
new_nodes, new_connections = add_connection_by_idx(from_idx, to_idx, nodes, connections)
return new_nodes, new_connections
def mutate_delete_connection(nodes: NDArray, connections: NDArray):
"""
Randomly delete a connection.
:param nodes:
:param connections:
:return:
"""
from_key, to_key, from_idx, to_idx = choice_connection_key(nodes, connections)
def nothing():
return nodes, connections
def successfully_delete_connection():
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
if from_idx == I_INT:
nodes, connections = nothing()
else:
nodes, connections = successfully_delete_connection()
return nodes, connections
def choice_node_key(nodes: NDArray,
input_keys: NDArray, output_keys: NDArray,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[NDArray, NDArray]:
"""
Randomly choose a node key from the given nodes. Input and output nodes can be excluded via allow_input_keys / allow_output_keys.
:param nodes:
:param input_keys:
:param output_keys:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~np.isnan(node_keys)
if not allow_input_keys:
mask = np.logical_and(mask, ~np.isin(node_keys, input_keys))
if not allow_output_keys:
mask = np.logical_and(mask, ~np.isin(node_keys, output_keys))
idx = fetch_random(mask)
if idx == I_INT:
return np.nan, idx
else:
return node_keys[idx], idx
def choice_connection_key(nodes: NDArray, connection: NDArray) -> Tuple[NDArray, NDArray, NDArray, NDArray]:
"""
Randomly choose a connection key from the given connections.
:param nodes:
:param connection:
:return: from_key, to_key, from_idx, to_idx
"""
has_connections_row = np.any(~np.isnan(connection[0, :, :]), axis=1)
from_idx = fetch_random(has_connections_row)
if from_idx == I_INT:
return np.nan, np.nan, from_idx, I_INT
col = connection[0, from_idx, :]
to_idx = fetch_random(~np.isnan(col))
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
from_key = np.where(from_idx != I_INT, from_key, np.nan)
to_key = np.where(to_idx != I_INT, to_key, np.nan)
return from_key, to_key, from_idx, to_idx
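# Usage sketch (illustrative, not part of this commit): one mutation step on a single genome,
# using initialize_genomes from the sibling genome module (the import below is an assumption)
# and the default rate arguments of mutate().
from .genome import initialize_genomes
np.random.seed(0)
demo_nodes, demo_cons, demo_in, demo_out = initialize_genomes(pop_size=1, N=6, num_inputs=2, num_outputs=1)
m_nodes, m_cons = mutate(demo_nodes[0], demo_cons[0], new_node_key=3, input_keys=demo_in, output_keys=demo_out)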

View File

@@ -0,0 +1,128 @@
from functools import partial
from typing import Tuple
import numpy as np
from numpy.typing import NDArray
I_INT = np.iinfo(np.int32).max  # "infinite" int, used as a sentinel index meaning "not found"
def flatten_connections(keys, connections):
"""
flatten the (2, N, N) connections to (N * N, 4)
:param keys:
:param connections:
:return:
the first two columns are the (from, to) node keys,
the 3rd column is the weight, and the 4th column is the enabled status
"""
indices_x, indices_y = np.meshgrid(keys, keys, indexing='ij')
indices = np.stack((indices_x, indices_y), axis=-1).reshape(-1, 2)
# make (2, N, N) to (N, N, 2)
con = np.transpose(connections, (1, 2, 0))
# make (N, N, 2) to (N * N, 2)
con = np.reshape(con, (-1, 2))
con = np.concatenate((indices, con), axis=1)
return con
def unflatten_connections(N, cons):
"""
restore the (N * N, 4) connections to (2, N, N)
:param N:
:param cons:
:return:
"""
cons = cons[:, 2:] # remove the indices
unflatten_cons = np.moveaxis(cons.reshape(N, N, 2), -1, 0)
return unflatten_cons
def set_operation_analysis(ar1: NDArray, ar2: NDArray) -> Tuple[NDArray, NDArray, NDArray]:
"""
Analyze the intersection and union of two arrays by returning their sorted concatenation indices,
intersection mask, and union mask.
:param ar1: array of shape (N, M)
First input array. Should have the same shape as ar2.
:param ar2: array of shape (N, M)
Second input array. Should have the same shape as ar1.
:return: tuple of 3 arrays
- sorted_indices: Indices that would sort the concatenation of ar1 and ar2.
- intersect_mask: A boolean array indicating the positions of the common elements between ar1 and ar2
in the sorted concatenation.
- union_mask: A boolean array indicating the positions of the unique elements in the union of ar1 and ar2
in the sorted concatenation.
Examples:
a = np.array([[1, 2], [3, 4], [5, 6]])
b = np.array([[1, 2], [7, 8], [9, 10]])
sorted_indices, intersect_mask, union_mask = set_operation_analysis(a, b)
sorted_indices -> array([0, 1, 2, 3, 4, 5])
intersect_mask -> array([True, False, False, False, False, False])
union_mask -> array([False, True, True, True, True, True])
"""
ar = np.concatenate((ar1, ar2), axis=0)
sorted_indices = np.lexsort(ar.T[::-1])
aux = ar[sorted_indices]
aux = np.concatenate((aux, np.full((1, ar1.shape[1]), np.nan)), axis=0)
nan_mask = np.any(np.isnan(aux), axis=1)
fr, sr = aux[:-1], aux[1:] # first row, second row
intersect_mask = np.all(fr == sr, axis=1) & ~nan_mask[:-1]
union_mask = np.any(fr != sr, axis=1) & ~nan_mask[:-1]
return sorted_indices, intersect_mask, union_mask
def fetch_first(mask, default=I_INT) -> NDArray:
"""
fetch the first True index
:param mask: array of bool
:param default: the default value if no element satisfying the condition
:return: the index of the first element satisfying the condition; if no element satisfies it, return `default` (I_INT by default)
example:
>>> a = np.array([1, 2, 3, 4, 5])
>>> fetch_first(a > 3)
3
>>> fetch_first(a > 30)
I_INT
"""
idx = np.argmax(mask)
return np.where(mask[idx], idx, default)
def fetch_last(mask, default=I_INT) -> NDArray:
"""
similar to fetch_first, but fetch the last True index
"""
reversed_idx = fetch_first(mask[::-1], default)
return np.where(reversed_idx == default, default, mask.shape[0] - reversed_idx - 1)
def fetch_random(mask, default=I_INT) -> NDArray:
"""
similar to fetch_first, but fetch a random True index
"""
true_cnt = np.sum(mask)
if true_cnt == 0:
return default
cumsum = np.cumsum(mask)
target = np.random.randint(1, true_cnt + 1, size=())
return fetch_first(cumsum >= target, default)
if __name__ == '__main__':
a = np.array([1, 2, 3, 4, 5])
print(fetch_first(a > 3))
print(fetch_first(a > 30))
print(fetch_last(a > 3))
print(fetch_last(a > 30))
for t in [-1, 0, 1, 2, 3, 4, 5]:
for _ in range(10):
print(t, fetch_random(a > t))
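# Usage sketch (illustrative, not part of this commit): flatten_connections / unflatten_connections
# round-trip the (weight, enabled) planes; the two key columns that flatten prepends are dropped
# again by unflatten.
keys = np.array([0.0, 1.0, 2.0])
cons = np.full((2, 3, 3), np.nan)
cons[:, 0, 2] = [0.5, 1.0]                 # a single enabled connection 0 -> 2
flat = flatten_connections(keys, cons)     # shape (9, 4): [from_key, to_key, weight, enabled]
restored = unflatten_connections(3, flat)  # back to shape (2, 3, 3)
assert np.allclose(restored, cons, equal_nan=True)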

View File

@@ -117,10 +117,12 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
true_cnt = jnp.sum(mask)
cumsum = jnp.cumsum(mask)
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
return fetch_first(cumsum >= target, default)
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
if __name__ == '__main__':
a = jnp.array([1, 2, 3, 4, 5])
print(fetch_first(a > 3))
print(fetch_first(a > 30))
@@ -129,6 +131,9 @@ if __name__ == '__main__':
print(fetch_last(a > 30))
rand_key = jax.random.PRNGKey(0)
for _ in range(100):
for t in [-1, 0, 1, 2, 3, 4, 5]:
for _ in range(10):
rand_key, _ = jax.random.split(rand_key)
print(fetch_random(rand_key, a > 0))
print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
print(t, fetch_random(rand_key, a > t))

View File

@@ -1,15 +1,12 @@
from typing import List, Union, Tuple, Callable
import time
import jax
import numpy as np
from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover
from .genome.crossover import crossover
from .genome import expand, expand_single
from algorithms.neat.genome.genome import pop_analysis, analysis
from .genome.numpy import create_initialize_function, create_mutate_function, create_forward_function
from .genome.numpy import batch_crossover
from .genome.numpy import expand, expand_single
class Pipeline:
@@ -18,7 +15,7 @@ class Pipeline:
"""
def __init__(self, config, seed=42):
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
self.config = config
self.N = config.basic.init_maximum_nodes
@@ -53,14 +50,6 @@ class Pipeline:
def tell(self, fitnesses):
self.generation += 1
for i, f in enumerate(fitnesses):
if np.isnan(f):
print("fuck!!!!!!!!!!!!!!")
error_nodes, error_connections = self.pop_nodes[i], self.pop_connections[i]
np.save('error_nodes.npy', error_nodes)
np.save('error_connections.npy', error_connections)
assert False
self.species_controller.update_species_fitnesses(fitnesses)
crossover_pair = self.species_controller.reproduce(self.generation)
@@ -96,8 +85,6 @@ class Pipeline:
assert self.pop_nodes.shape[0] == self.pop_size
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
# crossover
# prepare elitism mask and crossover pair
elitism_mask = np.full(self.pop_size, False)
@@ -112,18 +99,13 @@ class Pipeline:
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
crossover_rand_keys = jax.random.split(k1, self.pop_size)
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc)
npn, npc = batch_crossover(wpn, wpc, lpn, lpc)
# print(pop_analysis(npn, npc, self.input_idx, self.output_idx))
# mutate
new_node_keys = np.array(self.fetch_new_node_keys())
mutate_rand_keys = jax.random.split(k2, self.pop_size)
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
# print(pop_analysis(m_npn, m_npc, self.input_idx, self.output_idx))
m_npn, m_npc = self.mutate_func(npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1)
@@ -181,20 +163,3 @@ class Pipeline:
print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
# def crossover_wrapper(self, crossover_rand_keys, wpn, wpc, lpn, lpc):
# pop_nodes, pop_connections = [], []
# for randkey, wn, wc, ln, lc in zip(crossover_rand_keys, wpn, wpc, lpn, lpc):
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
# pop_nodes.append(new_nodes)
# pop_connections.append(new_connections)
# try:
# print(analysis(new_nodes, new_connections, self.input_idx, self.output_idx))
# except AssertionError:
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
# return np.stack(pop_nodes), np.stack(pop_connections)
# return batch_crossover(*args)
def crossover_wrapper(*args):
return batch_crossover(*args)

View File

@@ -1,10 +1,9 @@
from typing import List, Tuple, Dict, Union
from itertools import count
import jax
import numpy as np
from numpy.typing import NDArray
from .genome import distance
from .genome.numpy import distance
class Species(object):
@@ -46,10 +45,6 @@ class SpeciesController:
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
self.o2m_distance_func = jax.vmap(distance, in_axes=(None, None, 0, 0)) # one to many
# self.o2o_distance_func = np_distance # one to one
self.o2o_distance_func = distance
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
"""
:param pop_nodes:
@@ -67,8 +62,7 @@ class SpeciesController:
for sid, species in self.species.items():
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = self.o2m_distance_wrapper(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances) # fetch the data from gpu
distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx
@@ -81,9 +75,7 @@ class SpeciesController:
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
jax.device_get(
self.o2m_distance_wrapper(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
)
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list
]
@@ -110,7 +102,7 @@ class SpeciesController:
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
self.o2o_distance_wrapper(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)
@@ -267,36 +259,6 @@ class SpeciesController:
return crossover_pair
def o2m_distance_wrapper(self, r_nodes, r_connections, pop_nodes, pop_connections):
# distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
# for d in distances:
# if np.isnan(d):
# print("fuck!!!!!!!!!!!!!!")
# print(distances)
# assert False
# return distances
distances = []
for nodes, connections in zip(pop_nodes, pop_connections):
d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
if np.isnan(d) or d > 20:
np.save("too_large_distance_r_nodes.npy", r_nodes)
np.save("too_large_distance_r_connections.npy", r_connections)
np.save("too_large_distance_nodes", nodes)
np.save("too_large_distance_connections.npy", connections)
d = self.o2o_distance_func(r_nodes, r_connections, nodes, connections)
assert False
distances.append(d)
distances = np.stack(distances, axis=0)
# print(distances)
return distances
def o2o_distance_wrapper(self, *keys):
d = self.o2o_distance_func(*keys)
if np.isnan(d):
print("fuck!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
assert False
return d
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
"""
@@ -351,3 +313,12 @@ def sort_element_with_fitnesses(members: List[int], fitnesses: List[float]) \
sorted_members = [item[0] for item in sorted_combined]
sorted_fitnesses = [item[1] for item in sorted_combined]
return sorted_members, sorted_fitnesses
def o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections):
distances = []
for nodes, connections in zip(pop_nodes, pop_connections):
d = distance(r_nodes, r_connections, nodes, connections)
distances.append(d)
distances = np.stack(distances, axis=0)
return distances

View File

@@ -1,5 +0,0 @@
"""
numpy version of functions in genome
"""
from .distance import distance
from .utils import *

View File

@@ -1,58 +0,0 @@
import numpy as np
from .utils import flatten_connections, set_operation_analysis
def distance(nodes1, connections1, nodes2, connections2):
node_distance = gene_distance(nodes1, nodes2, 'node')
# refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
connection_distance = gene_distance(cons1, cons2, 'connection')
return node_distance + connection_distance
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
if gene_type == 'node':
keys1, keys2 = ar1[:, :1], ar2[:, :1]
else: # connection
keys1, keys2 = ar1[:, :2], ar2[:, :2]
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
nodes = np.concatenate((ar1, ar2), axis=0)
sorted_nodes = nodes[n_sorted_indices]
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = np.where(np.isnan(node_distance), 0, node_distance)
homologous_distance = np.sum(node_distance * n_intersect_mask[:-1])
gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1))
gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1))
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return val / np.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2)
def homologous_node_distance(n1, n2):
d = 0
d += np.abs(n1[:, 1] - n2[:, 1]) # bias
d += np.abs(n1[:, 2] - n2[:, 2]) # response
d += n1[:, 3] != n2[:, 3] # activation
d += n1[:, 4] != n2[:, 4]
return d
def homologous_connection_distance(c1, c2):
d = 0
d += np.abs(c1[:, 2] - c2[:, 2]) # weight
d += c1[:, 3] != c2[:, 3] # enable
return d

View File

@@ -1,55 +0,0 @@
import numpy as np
I_INT = np.iinfo(np.int32).max # infinite int
def flatten_connections(keys, connections):
indices_x, indices_y = np.meshgrid(keys, keys, indexing='ij')
indices = np.stack((indices_x, indices_y), axis=-1).reshape(-1, 2)
# make (2, N, N) to (N, N, 2)
con = np.transpose(connections, (1, 2, 0))
# make (N, N, 2) to (N * N, 2)
con = np.reshape(con, (-1, 2))
con = np.concatenate((indices, con), axis=1)
return con
def unflatten_connections(N, cons):
cons = cons[:, 2:] # remove the indices
unflatten_cons = np.moveaxis(cons.reshape(N, N, 2), -1, 0)
return unflatten_cons
def set_operation_analysis(ar1, ar2):
ar = np.concatenate((ar1, ar2), axis=0)
sorted_indices = np.lexsort(ar.T[::-1])
aux = ar[sorted_indices]
aux = np.concatenate((aux, np.full((1, ar1.shape[1]), np.nan)), axis=0)
nan_mask = np.any(np.isnan(aux), axis=1)
fr, sr = aux[:-1], aux[1:] # first row, second row
intersect_mask = np.all(fr == sr, axis=1) & ~nan_mask[:-1]
union_mask = np.any(fr != sr, axis=1) & ~nan_mask[:-1]
return sorted_indices, intersect_mask, union_mask
def fetch_first(mask, default=I_INT):
idx = np.argmax(mask)
return np.where(mask[idx], idx, default)
def fetch_last(mask, default=I_INT):
reversed_idx = fetch_first(mask[::-1], default)
return np.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1)
def fetch_random(rand_key, mask, default=I_INT):
"""
similar to fetch_first, but fetch a random True index
"""
true_cnt = np.sum(mask)
cumsum = np.cumsum(mask)
target = np.random.randint(rand_key, shape=(), minval=0, maxval=true_cnt + 1)
return fetch_first(cumsum >= target, default)