change project structure and using .ini as config file

This commit is contained in:
wls2002
2023-06-15 11:05:26 +08:00
parent 47fb0151f4
commit acedd67617
30 changed files with 97 additions and 301 deletions

9
neat/genome/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
from .genome import expand, expand_single, initialize_genomes
from .forward import forward_single
from .activations import act_name2key
from .aggregations import agg_name2key
from .crossover import crossover
from .mutate import mutate
from .distance import distance
from .graph import topological_sort
from .utils import unflatten_connections

140
neat/genome/activations.py Normal file
View File

@@ -0,0 +1,140 @@
import jax
import jax.numpy as jnp
from jax import jit
@jit
def sigmoid_act(z):
z = jnp.clip(z * 5, -60, 60)
return 1 / (1 + jnp.exp(-z))
@jit
def tanh_act(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z)
@jit
def sin_act(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z)
@jit
def gauss_act(z):
z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2)
@jit
def relu_act(z):
return jnp.maximum(z, 0)
@jit
def elu_act(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@jit
def lelu_act(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@jit
def selu_act(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@jit
def softplus_act(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@jit
def identity_act(z):
return z
@jit
def clamped_act(z):
return jnp.clip(z, -1, 1)
@jit
def inv_act(z):
z = jnp.maximum(z, 1e-7)
return 1 / z
@jit
def log_act(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@jit
def exp_act(z):
z = jnp.clip(z, -60, 60)
return jnp.exp(z)
@jit
def abs_act(z):
return jnp.abs(z)
@jit
def hat_act(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@jit
def square_act(z):
return z ** 2
@jit
def cube_act(z):
return z ** 3
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
act_name2key = {
'sigmoid': 0,
'tanh': 1,
'sin': 2,
'gauss': 3,
'relu': 4,
'elu': 5,
'lelu': 6,
'selu': 7,
'softplus': 8,
'identity': 9,
'clamped': 10,
'inv': 11,
'log': 12,
'exp': 13,
'abs': 14,
'hat': 15,
'square': 16,
'cube': 17,
}
@jit
def act(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
return jnp.where(jnp.isnan(res), jnp.nan, res)
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)

109
neat/genome/aggregations.py Normal file
View File

@@ -0,0 +1,109 @@
"""
aggregations, two special case need to consider:
1. extra 0s
2. full of 0s
"""
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
@jit
def sum_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)
@jit
def product_agg(z):
z = jnp.where(jnp.isnan(z), 1, z)
return jnp.prod(z, axis=0)
@jit
def max_agg(z):
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
return jnp.max(z, axis=0)
@jit
def min_agg(z):
z = jnp.where(jnp.isnan(z), jnp.inf, z)
return jnp.min(z, axis=0)
@jit
def maxabs_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
abs_z = jnp.abs(z)
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
@jit
def median_agg(z):
non_zero_mask = ~jnp.isnan(z)
n = jnp.sum(non_zero_mask, axis=0)
z = jnp.where(jnp.isnan(z), jnp.inf, z)
sorted_valid_values = jnp.sort(z)
def _even_case():
return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2
def _odd_case():
return sorted_valid_values[n // 2]
median = jax.lax.cond(n % 2 == 0, _even_case, _odd_case)
return median
@jit
def mean_agg(z):
non_zero_mask = ~jnp.isnan(z)
valid_values_sum = sum_agg(z)
valid_values_count = jnp.sum(non_zero_mask, axis=0)
mean_without_zeros = valid_values_sum / valid_values_count
return mean_without_zeros
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
agg_name2key = {
'sum': 0,
'product': 1,
'max': 2,
'min': 3,
'maxabs': 4,
'median': 5,
'mean': 6,
}
@jit
def agg(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
def full_nan():
return 0.
def not_full_nan():
return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
vectorized_agg = jax.vmap(agg, in_axes=(0, 0))
if __name__ == '__main__':
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array))
array2 = jnp.asarray([0, 0, 0, 0], dtype=jnp.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array2))

76
neat/genome/crossover.py Normal file
View File

@@ -0,0 +1,76 @@
from functools import partial
from typing import Tuple
import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
@jit
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
-> Tuple[Array, Array]:
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
:param randkey:
:param nodes1:
:param cons1:
:param nodes2:
:param cons2:
:return:
"""
randkey_1, randkey_2 = jax.random.split(randkey)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
nodes2 = align_array(keys1, keys2, nodes2, 'node')
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
return new_nodes, new_cons
# @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
make ar2 align with ar1.
:param seq1:
:param seq2:
:param ar2:
:param gene_type:
:return:
align means to intersect part of ar2 will be at the same position as ar1,
non-intersect part of ar2 will be set to Nan
"""
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
if gene_type == 'connection':
mask = jnp.all(mask, axis=2)
intersect_mask = mask.any(axis=1)
idx = jnp.arange(0, len(seq1))
idx_fixed = jnp.dot(mask, idx)
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
return refactor_ar2
# @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)

View File

View File

@@ -0,0 +1,88 @@
from collections import defaultdict
import numpy as np
def check_array_valid(nodes, cons, input_keys, output_keys):
nodes_dict, cons_dict = array2object(nodes, cons, input_keys, output_keys)
# assert is_DAG(cons_dict.keys()), "The genome is not a DAG!"
def array2object(nodes, cons, input_keys, output_keys):
"""
Convert a genome from array to dict.
:param nodes: (N, 5)
:param cons: (C, 4)
:param output_keys:
:param input_keys:
:return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
"""
# update nodes_dict
nodes_dict = {}
for i, node in enumerate(nodes):
if np.isnan(node[0]):
continue
key = int(node[0])
assert key not in nodes_dict, f"Duplicate node key: {key}!"
if key in input_keys:
assert np.all(np.isnan(node[1:])), f"Input node {key} must has None bias, response, act, or agg!"
nodes_dict[key] = (None,) * 4
else:
assert np.all(~np.isnan(node[1:])), f"Normal node {key} must has non-None bias, response, act, or agg!"
bias = node[1]
response = node[2]
act = node[3]
agg = node[4]
nodes_dict[key] = (bias, response, act, agg)
# check nodes_dict
for i in input_keys:
assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
for o in output_keys:
assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
# update connections
cons_dict = {}
for i, con in enumerate(cons):
if np.all(np.isnan(con)):
pass
elif np.all(~np.isnan(con)):
i_key = int(con[0])
o_key = int(con[1])
if (i_key, o_key) in cons_dict:
assert False, f"Duplicate connection: {(i_key, o_key)}!"
assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
weight = con[2]
enabled = (con[3] == 1)
cons_dict[(i_key, o_key)] = (weight, enabled)
else:
assert False, f"Connection {i} must has all None or all non-None!"
return nodes_dict, cons_dict
def is_DAG(edges):
all_nodes = set()
for a, b in edges:
if a == b: # cycle
return False
all_nodes.union({a, b})
for node in all_nodes:
visited = {n: False for n in all_nodes}
def dfs(n):
if visited[n]:
return False
visited[n] = True
for a, b in edges:
if a == n:
if not dfs(b):
return False
return True
if not dfs(node):
return False
return True

97
neat/genome/distance.py Normal file
View File

@@ -0,0 +1,97 @@
from jax import jit, vmap, Array
from jax import numpy as jnp
from .utils import EMPTY_NODE, EMPTY_CON
@jit
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, disjoint_coe: float = 1.,
compatibility_coe: float = 0.5) -> Array:
"""
Calculate the distance between two genomes.
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
"""
nd = node_distance(nodes1, nodes2, disjoint_coe, compatibility_coe) # node distance
cd = connection_distance(cons1, cons2, disjoint_coe, compatibility_coe) # connection distance
return nd + cd
@jit
def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
nd = batch_homologous_node_distance(fr, sr)
nd = jnp.where(jnp.isnan(nd), 0, nd)
homologous_distance = jnp.sum(nd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return jnp.where(max_cnt == 0, 0, val / max_cnt)
@jit
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((cons1, cons2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
cd = batch_homologous_connection_distance(fr, sr)
cd = jnp.where(jnp.isnan(cd), 0, cd)
homologous_distance = jnp.sum(cd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return jnp.where(max_cnt == 0, 0, val / max_cnt)
@vmap
def batch_homologous_node_distance(b_n1, b_n2):
return homologous_node_distance(b_n1, b_n2)
@vmap
def batch_homologous_connection_distance(b_c1, b_c2):
return homologous_connection_distance(b_c1, b_c2)
@jit
def homologous_node_distance(n1, n2):
d = 0
d += jnp.abs(n1[1] - n2[1]) # bias
d += jnp.abs(n1[2] - n2[2]) # response
d += n1[3] != n2[3] # activation
d += n1[4] != n2[4]
return d
@jit
def homologous_connection_distance(c1, c2):
d = 0
d += jnp.abs(c1[2] - c2[2]) # weight
d += c1[3] != c2[3] # enable
return d

47
neat/genome/forward.py Normal file
View File

@@ -0,0 +1,47 @@
import jax
from jax import Array, numpy as jnp
from jax import jit, vmap
from .aggregations import agg
from .activations import act
from .utils import I_INT
# TODO: enabled information doesn't influence forward. That is wrong!
@jit
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
input_idx: Array, output_idx: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are single genome
:argument inputs: (input_num, )
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (output_num, )
"""
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
def scan_body(carry, i):
def hit():
ins = carry * connections[0, :, i]
z = agg(nodes[i, 4], ins)
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
new_vals = carry.at[i].set(z)
return new_vals
def miss():
return carry
return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
return vals[output_idx]

201
neat/genome/genome.py Normal file
View File

@@ -0,0 +1,201 @@
"""
Vectorization of genome representation.
Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where:
1. N, C are pre-set values that determines the maximum number of nodes and connections in the network, and will increase if the genome becomes
too large to be represented by the current value of N and C.
2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function
(act), and aggregation function (agg).
3. connections is an array of shape (C, 4), dtype=float, with columns corresponding to: i_key, o_key, weight, enabled.
Empty nodes or connections are represented using np.nan.
"""
from typing import Tuple, Dict
import jax
import numpy as np
from numpy.typing import NDArray
from jax import numpy as jnp
from jax import jit
from jax import Array
from .utils import fetch_first
def initialize_genomes(pop_size: int,
N: int,
C: int,
num_inputs: int,
num_outputs: int,
default_bias: float = 0.0,
default_response: float = 1.0,
default_act: int = 0,
default_agg: int = 0,
default_weight: float = 0.0) \
-> Tuple[NDArray, NDArray, NDArray, NDArray]:
"""
Initialize genomes with default values.
Args:
pop_size (int): Number of genomes to initialize.
N (int): Maximum number of nodes in the network.
C (int): Maximum number of connections in the network.
num_inputs (int): Number of input nodes.
num_outputs (int): Number of output nodes.
default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0.
default_response (float, optional): Default response value for output nodes. Defaults to 1.0.
default_act (int, optional): Default activation function index for output nodes. Defaults to 1.
default_agg (int, optional): Default aggregation function index for output nodes. Defaults to 0.
default_weight (float, optional): Default weight value for connections. Defaults to 0.0.
Raises:
AssertionError: If the sum of num_inputs, num_outputs, and 1 is greater than N.
Returns:
Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays.
"""
# Reserve one row for potential mutation adding an extra node
assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \
f"{num_inputs} and output_size: {num_outputs}!"
assert num_inputs * num_outputs + 1 <= C, f"Too small C: {C} for input_size: " \
f"{num_inputs} and output_size: {num_outputs}!"
pop_nodes = np.full((pop_size, N, 5), np.nan)
pop_cons = np.full((pop_size, C, 4), np.nan)
input_idx = np.arange(num_inputs)
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
pop_nodes[:, input_idx, 0] = input_idx
pop_nodes[:, output_idx, 0] = output_idx
pop_nodes[:, output_idx, 1] = default_bias
pop_nodes[:, output_idx, 2] = default_response
pop_nodes[:, output_idx, 3] = default_act
pop_nodes[:, output_idx, 4] = default_agg
grid_a, grid_b = np.meshgrid(input_idx, output_idx)
grid_a, grid_b = grid_a.flatten(), grid_b.flatten()
pop_cons[:, :num_inputs * num_outputs, 0] = grid_a
pop_cons[:, :num_inputs * num_outputs, 1] = grid_b
pop_cons[:, :num_inputs * num_outputs, 2] = default_weight
pop_cons[:, :num_inputs * num_outputs, 3] = 1
return pop_nodes, pop_cons, input_idx, output_idx
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
"""
Expand the genome to accommodate more nodes.
:param pop_nodes: (pop_size, N, 5)
:param pop_cons: (pop_size, C, 4)
:param new_N:
:param new_C:
:return:
"""
pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1]
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
new_pop_nodes[:, :old_N, :] = pop_nodes
new_pop_cons = np.full((pop_size, new_C, 4), np.nan)
new_pop_cons[:, :old_C, :] = pop_cons
return new_pop_nodes, new_pop_cons
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
"""
Expand a single genome to accommodate more nodes.
:param nodes: (N, 5)
:param cons: (2, N, N)
:param new_N:
:param new_C:
:return:
"""
old_N, old_C = nodes.shape[0], cons.shape[0]
new_nodes = np.full((new_N, 5), np.nan)
new_nodes[:old_N, :] = nodes
new_cons = np.full((new_C, 4), np.nan)
new_cons[:old_C, :] = cons
return new_nodes, new_cons
@jit
def count(nodes, cons):
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
return node_cnt, cons_cnt
@jit
def add_node(nodes: Array, cons: Array, new_key: int,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
"""
add a new node to the genome.
"""
exist_keys = nodes[:, 0]
idx = fetch_first(jnp.isnan(exist_keys))
nodes = nodes.at[idx].set(jnp.array([new_key, bias, response, act, agg]))
return nodes, cons
@jit
def delete_node(nodes: Array, cons: Array, node_key: int) -> Tuple[Array, Array]:
"""
delete a node from the genome. only delete the node, regardless of connections.
"""
node_keys = nodes[:, 0]
idx = fetch_first(node_keys == node_key)
return delete_node_by_idx(nodes, cons, idx)
@jit
def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
"""
use idx to delete a node from the genome. only delete the node, regardless of connections.
"""
nodes = nodes.at[idx].set(np.nan)
return nodes, cons
@jit
def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
"""
add a new connection to the genome.
"""
con_keys = cons[:, 0]
idx = fetch_first(jnp.isnan(con_keys))
return add_connection_by_idx(nodes, cons, idx, i_key, o_key, weight, enabled)
@jit
def add_connection_by_idx(nodes: Array, cons: Array, idx: int, i_key: int, o_key: int,
weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
"""
use idx to add a new connection to the genome.
"""
cons = cons.at[idx].set(jnp.array([i_key, o_key, weight, enabled]))
return nodes, cons
@jit
def delete_connection(nodes: Array, cons: Array, i_key: int, o_key: int) -> Tuple[Array, Array]:
"""
delete a connection from the genome.
"""
idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
return delete_connection_by_idx(nodes, cons, idx)
@jit
def delete_connection_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
"""
use idx to delete a connection from the genome.
"""
cons = cons.at[idx].set(np.nan)
return nodes, cons

179
neat/genome/graph.py Normal file
View File

@@ -0,0 +1,179 @@
"""
Some graph algorithms implemented in jax.
Only used in feed-forward networks.
"""
import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
# from .utils import fetch_first, I_INT
from neat.genome.utils import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, connections: Array) -> Array:
"""
a jit-able version of topological_sort! that's crazy!
:param nodes: nodes array
:param connections: connections array
:return: topological sorted sequence
Example:
nodes = jnp.array([
[0],
[1],
[2],
[3]
])
connections = jnp.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
topological_sort(nodes, connections) -> [0, 1, 2, 3]
"""
connections_enable = connections[1, :, :] == 1
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
res = jnp.full(in_degree.shape, I_INT)
idx = 0
def scan_body(carry, _):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
def hit():
# add to res and flag it is already in it
new_res = res_.at[idx_].set(i)
new_idx = idx_ + 1
new_in_degree = in_degree_.at[i].set(-1)
# decrease in_degree of all its children
children = connections_enable[i, :]
new_in_degree = jnp.where(children, new_in_degree - 1, new_in_degree)
return new_res, new_idx, new_in_degree
def miss():
return res_, idx_, in_degree_
return jax.lax.cond(i == I_INT, miss, hit), None
scan_res, _ = jax.lax.scan(scan_body, (res, idx, in_degree), None, length=in_degree.shape[0])
res, _, _ = scan_res
return res
@jit
@vmap
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
"""
batch version of topological_sort
:param pop_nodes:
:param pop_connections:
:return:
"""
return topological_sort(pop_nodes, pop_connections)
@jit
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
"""
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
:param nodes: JAX array
The array of nodes.
:param connections: JAX array
The array of connections.
:param from_idx: int
The index of the starting node.
:param to_idx: int
The index of the ending node.
:return: JAX array
An array indicating if there is a cycle caused by the new connection.
Example:
nodes = jnp.array([
[0],
[1],
[2],
[3]
])
connections = jnp.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
check_cycles(nodes, connections, 3, 2) -> True
check_cycles(nodes, connections, 2, 3) -> False
check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False
"""
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
nodes_visited = jnp.full(nodes.shape[0], False)
nodes_visited = nodes_visited.at[to_idx].set(True)
def scan_body(visited, _):
new_visited = jnp.dot(visited, connections_enable)
new_visited = jnp.logical_or(visited, new_visited)
return new_visited, None
nodes_visited, _ = jax.lax.scan(scan_body, nodes_visited, None, length=nodes_visited.shape[0])
return nodes_visited[from_idx]
if __name__ == '__main__':
nodes = jnp.array([
[0],
[1],
[2],
[3],
[jnp.nan]
])
connections = jnp.array([
[
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
],
[
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
]
]
)
print(topological_sort(nodes, connections))
print(check_cycles(nodes, connections, 3, 2))
print(check_cycles(nodes, connections, 2, 3))
print(check_cycles(nodes, connections, 0, 3))
print(check_cycles(nodes, connections, 1, 0))

457
neat/genome/mutate.py Normal file
View File

@@ -0,0 +1,457 @@
from typing import Tuple
from functools import partial
import jax
import numpy as np
from jax import numpy as jnp
from jax import jit, vmap, Array
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
from .graph import check_cycles
# TODO: Temporally delete single_structural_mutation, for i need to run it as soon as possible.
@jit
def mutate(rand_key: Array,
nodes: Array,
connections: Array,
new_node_key: int,
input_idx: Array,
output_idx: Array,
bias_mean: float = 0,
bias_std: float = 1,
bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_mean: float = 1.,
response_std: float = 0.,
response_mutate_strength: float = 0.,
response_mutate_rate: float = 0.,
response_replace_rate: float = 0.,
weight_mean: float = 0.,
weight_std: float = 1.,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_default: int = 0,
act_list: Array = None,
act_replace_rate: float = 0.1,
agg_default: int = 0,
agg_list: Array = None,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1,
add_node_rate: float = 0.2,
delete_node_rate: float = 0.2,
add_connection_rate: float = 0.4,
delete_connection_rate: float = 0.4,
):
"""
:param output_idx:
:param input_idx:
:param agg_default:
:param act_default:
:param rand_key:
:param nodes: (N, 5)
:param connections: (2, N, N)
:param new_node_key:
:param bias_mean:
:param bias_std:
:param bias_mutate_strength:
:param bias_mutate_rate:
:param bias_replace_rate:
:param response_mean:
:param response_std:
:param response_mutate_strength:
:param response_mutate_rate:
:param response_replace_rate:
:param weight_mean:
:param weight_std:
:param weight_mutate_strength:
:param weight_mutate_rate:
:param weight_replace_rate:
:param act_list:
:param act_replace_rate:
:param agg_list:
:param agg_replace_rate:
:param enabled_reverse_rate:
:param add_node_rate:
:param delete_node_rate:
:param add_connection_rate:
:param delete_connection_rate:
:return:
"""
def m_add_node(rk, n, c):
return mutate_add_node(rk, n, c, new_node_key, bias_mean, response_mean, act_default, agg_default)
def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_idx, output_idx)
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_idx, output_idx)
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# mutate add node
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
# mutate add connection
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
# mutate delete node
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
# mutate delete connection
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength,
weight_mutate_rate, weight_replace_rate, act_list, act_replace_rate, agg_list,
agg_replace_rate, enabled_reverse_rate)
return nodes, connections
@jit
def mutate_values(rand_key: Array,
nodes: Array,
cons: Array,
bias_mean: float = 0,
bias_std: float = 1,
bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_mean: float = 1.,
response_std: float = 0.,
response_mutate_strength: float = 0.,
response_mutate_rate: float = 0.,
response_replace_rate: float = 0.,
weight_mean: float = 0.,
weight_std: float = 1.,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_list: Array = None,
act_replace_rate: float = 0.1,
agg_list: Array = None,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]:
"""
Mutate values of nodes and connections.
Args:
rand_key: A random key for generating random values.
nodes: A 2D array representing nodes.
cons: A 3D array representing connections.
bias_mean: Mean of the bias values.
bias_std: Standard deviation of the bias values.
bias_mutate_strength: Strength of the bias mutation.
bias_mutate_rate: Rate of the bias mutation.
bias_replace_rate: Rate of the bias replacement.
response_mean: Mean of the response values.
response_std: Standard deviation of the response values.
response_mutate_strength: Strength of the response mutation.
response_mutate_rate: Rate of the response mutation.
response_replace_rate: Rate of the response replacement.
weight_mean: Mean of the weight values.
weight_std: Standard deviation of the weight values.
weight_mutate_strength: Strength of the weight mutation.
weight_mutate_rate: Rate of the weight mutation.
weight_replace_rate: Rate of the weight replacement.
act_list: List of the activation function values.
act_replace_rate: Rate of the activation function replacement.
agg_list: List of the aggregation function values.
agg_replace_rate: Rate of the aggregation function replacement.
enabled_reverse_rate: Rate of reversing enabled state of connections.
Returns:
A tuple containing mutated nodes and connections.
"""
k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6)
bias_new = mutate_float_values(k1, nodes[:, 1], bias_mean, bias_std,
bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate)
weight_new = mutate_float_values(k3, cons[:, 2], weight_mean, weight_std,
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
# mutate enabled
r = jax.random.uniform(rand_key, cons[:, 3].shape)
enabled_new = jnp.where(r < enabled_reverse_rate, 1 - cons[:, 3], cons[:, 3])
enabled_new = jnp.where(~jnp.isnan(cons[:, 3]), enabled_new, jnp.nan)
nodes = nodes.at[:, 1].set(bias_new)
nodes = nodes.at[:, 2].set(response_new)
nodes = nodes.at[:, 3].set(act_new)
nodes = nodes.at[:, 4].set(agg_new)
cons = cons.at[:, 2].set(weight_new)
cons = cons.at[:, 3].set(enabled_new)
return nodes, cons
@jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
"""
Mutate float values of a given array.
Args:
rand_key: A random key for generating random values.
old_vals: A 1D array of float values to be mutated.
mean: Mean of the values.
std: Standard deviation of the values.
mutate_strength: Strength of the mutation.
mutate_rate: Rate of the mutation.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of float values.
"""
k1, k2, k3, rand_key = jax.random.split(rand_key, num=4)
noise = jax.random.normal(k1, old_vals.shape) * mutate_strength
replace = jax.random.normal(k2, old_vals.shape) * std + mean
r = jax.random.uniform(k3, old_vals.shape)
new_vals = old_vals
new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals)
new_vals = jnp.where(
jnp.logical_and(mutate_rate < r, r < mutate_rate + replace_rate),
replace,
new_vals
)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
return new_vals
@jit
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
"""
Mutate integer values (act, agg) of a given array.
Args:
rand_key: A random key for generating random values.
old_vals: A 1D array of integer values to be mutated.
val_list: List of the integer values.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of integer values.
"""
k1, k2, rand_key = jax.random.split(rand_key, num=3)
replace_val = jax.random.choice(k1, val_list, old_vals.shape)
r = jax.random.uniform(k2, old_vals.shape)
new_vals = old_vals
new_vals = jnp.where(r < replace_rate, replace_val, new_vals)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
return new_vals
@jit
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
"""
Randomly add a new node from splitting a connection.
:param rand_key:
:param new_node_key:
:param nodes:
:param cons:
:param default_bias:
:param default_response:
:param default_act:
:param default_agg:
:return:
"""
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
def nothing(): # there is no connection to split
return nodes, cons
def successful_add_node():
# disable the connection
new_nodes, new_cons = nodes, cons
new_cons = new_cons.at[idx, 3].set(False)
# add a new node
new_nodes, new_cons = \
add_node(new_nodes, new_cons, new_node_key,
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
# add two new connections
w = new_cons[idx, 2]
new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True)
new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True)
return new_nodes, new_cons
# if from_idx == I_INT, that means no connection exist, do nothing
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node)
return nodes, cons
# TODO: Need we really need to delete a node?
@jit
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
Randomly delete a node. Input and output nodes are not allowed to be deleted.
:param rand_key:
:param nodes:
:param cons:
:param input_keys:
:param output_keys:
:return:
"""
# randomly choose a node
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
def nothing():
return nodes, cons
def successful_delete_node():
# delete the node
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, node_idx)
# delete all connections
aux_cons = jnp.where(((aux_cons[:, 0] == node_key) | (aux_cons[:, 1] == node_key))[:, jnp.newaxis],
jnp.nan, aux_cons)
return aux_nodes, aux_cons
nodes, cons = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
return nodes, cons
@jit
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
cycles are not allowed.
:param rand_key:
:param nodes:
:param cons:
:param input_keys:
:param output_keys:
:return:
"""
# randomly choose two nodes
k1, k2 = jax.random.split(rand_key, num=2)
i_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=True)
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
def successful():
new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)
return new_nodes, new_cons
def already_exist():
new_cons = cons.at[con_idx, 3].set(True)
return nodes, new_cons
def cycle():
return nodes, cons
is_already_exist = con_idx != I_INT
unflattened = unflatten_connections(nodes, cons)
is_cycle = check_cycles(nodes, unflattened, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
return nodes, cons
@jit
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
"""
Randomly delete a connection.
:param rand_key:
:param nodes:
:param cons:
:return:
"""
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
def nothing():
return nodes, cons
def successfully_delete_connection():
return delete_connection_by_idx(nodes, cons, idx)
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
return nodes, cons
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param rand_key:
:param nodes:
:param input_keys:
:param output_keys:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~jnp.isnan(node_keys)
if not allow_input_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys))
if not allow_output_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys))
idx = fetch_random(rand_key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
return key, idx
@jit
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
"""
Randomly choose a connection key from the given connections.
:param rand_key:
:param nodes:
:param cons:
:return: i_key, o_key, idx
"""
idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0]))
i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan)
return i_key, o_key, idx
@jit
def rand(rand_key):
return jax.random.uniform(rand_key, ())

106
neat/genome/utils.py Normal file
View File

@@ -0,0 +1,106 @@
from functools import partial
from typing import Tuple
import jax
from jax import numpy as jnp, Array
from jax import jit, vmap
I_INT = jnp.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
EMPTY_CON = jnp.full((1, 4), jnp.nan)
@jit
def unflatten_connections(nodes, cons):
"""
transform the (C, 4) connections to (2, N, N)
this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled
connections to nan, that means we dont consider such connection when forward;
:param cons:
:param nodes:
:return:
"""
N = nodes.shape[0]
node_keys = nodes[:, 0]
i_keys, o_keys = cons[:, 0], cons[:, 1]
i_idxs = key_to_indices(i_keys, node_keys)
o_idxs = key_to_indices(o_keys, node_keys)
res = jnp.full((2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array
# however, it will do nothing set values in an array
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
# (2, N, N), (2, N, N), (2, N, N)
# res = jnp.where(res[1, :, :] == 0, jnp.nan, res)
return res
@partial(vmap, in_axes=(0, None))
def key_to_indices(key, keys):
return fetch_first(key == keys)
@jit
def fetch_first(mask, default=I_INT) -> Array:
"""
fetch the first True index
:param mask: array of bool
:param default: the default value if no element satisfying the condition
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT
example:
>>> a = jnp.array([1, 2, 3, 4, 5])
>>> fetch_first(a > 3)
3
>>> fetch_first(a > 30)
I_INT
"""
idx = jnp.argmax(mask)
return jnp.where(mask[idx], idx, default)
@jit
def fetch_last(mask, default=I_INT) -> Array:
"""
similar to fetch_first, but fetch the last True index
"""
reversed_idx = fetch_first(mask[::-1], default)
return jnp.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
"""
similar to fetch_first, but fetch a random True index
"""
true_cnt = jnp.sum(mask)
cumsum = jnp.cumsum(mask)
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx
if __name__ == '__main__':
a = jnp.array([1, 2, 3, 4, 5])
print(fetch_first(a > 3))
print(fetch_first(a > 30))
print(fetch_last(a > 3))
print(fetch_last(a > 30))
rand_key = jax.random.PRNGKey(0)
for t in [-1, 0, 1, 2, 3, 4, 5]:
for _ in range(10):
rand_key, _ = jax.random.split(rand_key)
print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
print(t, fetch_random(rand_key, a > t))