modifying

This commit is contained in:
wls2002
2023-06-27 18:47:47 +08:00
parent ba369db0b2
commit 114ff2b0cc
28 changed files with 451 additions and 123 deletions

View File

@@ -0,0 +1,6 @@
"""
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
"""
from .genome import create_forward, topological_sort, unflatten_connections, initialize_genomes, expand, expand_single
from .operations import create_next_generation_then_speciate
from .species import SpeciesController

View File

@@ -0,0 +1,7 @@
from .mutate import mutate
from .distance import distance
from .crossover import crossover
from .forward import create_forward
from .graph import topological_sort, check_cycles
from .utils import unflatten_connections
from .genome import initialize_genomes, expand, expand_single

View File

@@ -0,0 +1,126 @@
import jax
import jax.numpy as jnp
from jax import jit
@jit
def sigmoid_act(z):
z = jnp.clip(z * 5, -60, 60)
return 1 / (1 + jnp.exp(-z))
@jit
def tanh_act(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z)
@jit
def sin_act(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z)
@jit
def gauss_act(z):
z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2)
@jit
def relu_act(z):
return jnp.maximum(z, 0)
@jit
def elu_act(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@jit
def lelu_act(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@jit
def selu_act(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@jit
def softplus_act(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@jit
def identity_act(z):
return z
@jit
def clamped_act(z):
return jnp.clip(z, -1, 1)
@jit
def inv_act(z):
z = jnp.maximum(z, 1e-7)
return 1 / z
@jit
def log_act(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@jit
def exp_act(z):
z = jnp.clip(z, -60, 60)
return jnp.exp(z)
@jit
def abs_act(z):
return jnp.abs(z)
@jit
def hat_act(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@jit
def square_act(z):
return z ** 2
@jit
def cube_act(z):
return z ** 3
act_name2func = {
'sigmoid': sigmoid_act,
'tanh': tanh_act,
'sin': sin_act,
'gauss': gauss_act,
'relu': relu_act,
'elu': elu_act,
'lelu': lelu_act,
'selu': selu_act,
'softplus': softplus_act,
'identity': identity_act,
'clamped': clamped_act,
'inv': inv_act,
'log': log_act,
'exp': exp_act,
'abs': abs_act,
'hat': hat_act,
'square': square_act,
'cube': cube_act,
}

View File

@@ -0,0 +1,60 @@
import jax.numpy as jnp
def sum_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)
def product_agg(z):
z = jnp.where(jnp.isnan(z), 1, z)
return jnp.prod(z, axis=0)
def max_agg(z):
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
return jnp.max(z, axis=0)
def min_agg(z):
z = jnp.where(jnp.isnan(z), jnp.inf, z)
return jnp.min(z, axis=0)
def maxabs_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
abs_z = jnp.abs(z)
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
def median_agg(z):
non_nan_mask = ~jnp.isnan(z)
n = jnp.sum(non_nan_mask, axis=0)
z = jnp.sort(z) # sort
idx1, idx2 = (n - 1) // 2, n // 2
median = (z[idx1] + z[idx2]) / 2
return median
def mean_agg(z):
non_zero_mask = ~jnp.isnan(z)
valid_values_sum = sum_agg(z)
valid_values_count = jnp.sum(non_zero_mask, axis=0)
mean_without_zeros = valid_values_sum / valid_values_count
return mean_without_zeros
agg_name2func = {
'sum': sum_agg,
'product': product_agg,
'max': max_agg,
'min': min_agg,
'maxabs': maxabs_agg,
'median': median_agg,
'mean': mean_agg,
}

View File

@@ -0,0 +1,81 @@
"""
Crossover two genomes to generate a new genome.
The calculation method is the same as the crossover operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
"""
from typing import Tuple
import jax
from jax import jit, Array
from jax import numpy as jnp
@jit
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
:param randkey:
:param nodes1:
:param cons1:
:param nodes2:
:param cons2:
:return:
"""
randkey_1, randkey_2 = jax.random.split(randkey)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = align_array(keys1, keys2, nodes2, 'node')
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
return new_nodes, new_cons
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
make ar2 align with ar1.
:param seq1:
:param seq2:
:param ar2:
:param gene_type:
:return:
align means to intersect part of ar2 will be at the same position as ar1,
non-intersect part of ar2 will be set to Nan
"""
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
if gene_type == 'connection':
mask = jnp.all(mask, axis=2)
intersect_mask = mask.any(axis=1)
idx = jnp.arange(0, len(seq1))
idx_fixed = jnp.dot(mask, idx)
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
return refactor_ar2
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)

View File

View File

@@ -0,0 +1,88 @@
from collections import defaultdict
import numpy as np
def check_array_valid(nodes, cons, input_keys, output_keys):
nodes_dict, cons_dict = array2object(nodes, cons, input_keys, output_keys)
# assert is_DAG(cons_dict.keys()), "The genome is not a DAG!"
def array2object(nodes, cons, input_keys, output_keys):
"""
Convert a genome from array to dict.
:param nodes: (N, 5)
:param cons: (C, 4)
:param output_keys:
:param input_keys:
:return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
"""
# update nodes_dict
nodes_dict = {}
for i, node in enumerate(nodes):
if np.isnan(node[0]):
continue
key = int(node[0])
assert key not in nodes_dict, f"Duplicate node key: {key}!"
if key in input_keys:
assert np.all(np.isnan(node[1:])), f"Input node {key} must has None bias, response, act, or agg!"
nodes_dict[key] = (None,) * 4
else:
assert np.all(~np.isnan(node[1:])), f"Normal node {key} must has non-None bias, response, act, or agg!"
bias = node[1]
response = node[2]
act = node[3]
agg = node[4]
nodes_dict[key] = (bias, response, act, agg)
# check nodes_dict
for i in input_keys:
assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
for o in output_keys:
assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
# update connections
cons_dict = {}
for i, con in enumerate(cons):
if np.all(np.isnan(con)):
pass
elif np.all(~np.isnan(con)):
i_key = int(con[0])
o_key = int(con[1])
if (i_key, o_key) in cons_dict:
assert False, f"Duplicate connection: {(i_key, o_key)}!"
assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
weight = con[2]
enabled = (con[3] == 1)
cons_dict[(i_key, o_key)] = (weight, enabled)
else:
assert False, f"Connection {i} must has all None or all non-None!"
return nodes_dict, cons_dict
def is_DAG(edges):
all_nodes = set()
for a, b in edges:
if a == b: # cycle
return False
all_nodes.union({a, b})
for node in all_nodes:
visited = {n: False for n in all_nodes}
def dfs(n):
if visited[n]:
return False
visited[n] = True
for a, b in edges:
if a == n:
if not dfs(b):
return False
return True
if not dfs(node):
return False
return True

View File

@@ -0,0 +1,119 @@
"""
Calculate the distance between two genomes.
The calculation method is the same as the distance calculation in NEAT-python.
See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
"""
from typing import Dict
from jax import jit, vmap, Array
from jax import numpy as jnp
from .utils import EMPTY_NODE, EMPTY_CON
@jit
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
"""
Calculate the distance between two genomes.
args:
nodes1: Array(N, 5)
cons1: Array(C, 4)
nodes2: Array(N, 5)
cons2: Array(C, 4)
returns:
distance: Array(, )
"""
nd = node_distance(nodes1, nodes2, jit_config) # node distance
cd = connection_distance(cons1, cons2, jit_config) # connection distance
return nd + cd
@jit
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
"""
Calculate the distance between nodes of two genomes.
"""
# statistics nodes count of two genomes
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
hnd = vmap(homologous_node_distance)(fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
'compatibility_weight']
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
@jit
def connection_distance(cons1: Array, cons2: Array, jit_config: Dict):
"""
Calculate the distance between connections of two genomes.
Similar process as node_distance.
"""
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((cons1, cons2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
hcd = vmap(homologous_connection_distance)(fr, sr)
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
'compatibility_weight']
return jnp.where(max_cnt == 0, 0, val / max_cnt)
@jit
def homologous_node_distance(n1: Array, n2: Array):
"""
Calculate the distance between two homologous nodes.
"""
d = 0
d += jnp.abs(n1[1] - n2[1]) # bias
d += jnp.abs(n1[2] - n2[2]) # response
d += n1[3] != n2[3] # activation
d += n1[4] != n2[4] # aggregation
return d
@jit
def homologous_connection_distance(c1: Array, c2: Array):
"""
Calculate the distance between two homologous connections.
"""
d = 0
d += jnp.abs(c1[2] - c2[2]) # weight
d += c1[3] != c2[3] # enable
return d

View File

@@ -0,0 +1,86 @@
import jax
from jax import Array, numpy as jnp
from jax import jit, vmap
from .utils import I_INT
def create_forward(config):
"""
meta method to create forward function
"""
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, config['activation_funcs'], z)
return res
def agg(idx, z):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan():
return 0.
def not_all_nan():
return jax.lax.switch(idx, config['aggregation_funcs'], z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
def forward(inputs: Array, cal_seqs: Array, nodes: Array, cons: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are a single genome
:argument inputs: (input_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (output_num, )
"""
input_idx = config['input_idx']
output_idx = config['output_idx']
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
weights = jnp.where(jnp.isnan(cons[1, :, :]), jnp.nan, cons[0, :, :]) # enabled
def cond_fun(carry):
values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def hit():
ins = values * weights[:, i]
z = agg(nodes[i, 4], ins) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z) # z = act(z)
new_values = values.at[i].set(z)
return new_values
def miss():
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
return values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[output_idx]
return forward

View File

@@ -0,0 +1,181 @@
"""
Vectorization of genome representation.
Utilizes Tuple[nodes: Array(N, 5), connections: Array(C, 4)] to encode the genome, where:
nodes: [key, bias, response, act, agg]
connections: [in_key, out_key, weight, enable]
N: Maximum number of nodes in the network.
C: Maximum number of connections in the network.
"""
from typing import Tuple, Dict
import numpy as np
from numpy.typing import NDArray
from jax import jit, numpy as jnp
from .utils import fetch_first
def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
"""
Initialize genomes with default values.
Args:
N (int): Maximum number of nodes in the network.
C (int): Maximum number of connections in the network.
config (Dict): Configuration dictionary.
Returns:
Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays.
"""
# Reserve one row for potential mutation adding an extra node
assert config['num_inputs'] + config['num_outputs'] + 1 <= N, \
f"Too small N: {N} for input_size: {config['num_inputs']} and output_size: {config['num_inputs']}!"
assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"
pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
pop_cons = np.full((config['pop_size'], C, 4), np.nan)
input_idx = config['input_idx']
output_idx = config['output_idx']
pop_nodes[:, input_idx, 0] = input_idx
pop_nodes[:, output_idx, 0] = output_idx
# pop_nodes[:, output_idx, 1] = config['bias_init_mean']
pop_nodes[:, output_idx, 1] = np.random.normal(loc=config['bias_init_mean'], scale=config['bias_init_std'],
size=(config['pop_size'], 1))
pop_nodes[:, output_idx, 2] = np.random.normal(loc=config['response_init_mean'], scale=config['response_init_std'],
size=(config['pop_size'], 1))
pop_nodes[:, output_idx, 3] = np.random.choice(config['activation_options'], size=(config['pop_size'], 1))
pop_nodes[:, output_idx, 4] = np.random.choice(config['aggregation_options'], size=(config['pop_size'], 1))
grid_a, grid_b = np.meshgrid(input_idx, output_idx)
grid_a, grid_b = grid_a.flatten(), grid_b.flatten()
p = config['num_inputs'] * config['num_outputs']
pop_cons[:, :p, 0] = grid_a
pop_cons[:, :p, 1] = grid_b
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
size=(config['pop_size'], p))
pop_cons[:, :p, 3] = 1
return pop_nodes, pop_cons
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
"""
Expand a single genome to accommodate more nodes or connections.
:param nodes: (N, 5)
:param cons: (C, 4)
:param new_N:
:param new_C:
:return: (new_N, 5), (new_C, 4)
"""
old_N, old_C = nodes.shape[0], cons.shape[0]
new_nodes = np.full((new_N, 5), np.nan)
new_nodes[:old_N, :] = nodes
new_cons = np.full((new_C, 4), np.nan)
new_cons[:old_C, :] = cons
return new_nodes, new_cons
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
"""
Expand the population to accommodate more nodes or connections.
:param pop_nodes: (pop_size, N, 5)
:param pop_cons: (pop_size, C, 4)
:param new_N:
:param new_C:
:return: (pop_size, new_N, 5), (pop_size, new_C, 4)
"""
pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1]
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
new_pop_nodes[:, :old_N, :] = pop_nodes
new_pop_cons = np.full((pop_size, new_C, 4), np.nan)
new_pop_cons[:, :old_C, :] = pop_cons
return new_pop_nodes, new_pop_cons
@jit
def count(nodes: NDArray, cons: NDArray) -> Tuple[NDArray, NDArray]:
"""
Count how many nodes and connections are in the genome.
"""
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
return node_cnt, cons_cnt
@jit
def add_node(nodes: NDArray, cons: NDArray, new_key: int,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]:
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
idx = fetch_first(jnp.isnan(exist_keys))
nodes = nodes.at[idx].set(jnp.array([new_key, bias, response, act, agg]))
return nodes, cons
@jit
def delete_node(nodes: NDArray, cons: NDArray, node_key: int) -> Tuple[NDArray, NDArray]:
"""
Delete a node from the genome. Only delete the node, regardless of connections.
Delete the node by its key.
"""
node_keys = nodes[:, 0]
idx = fetch_first(node_keys == node_key)
return delete_node_by_idx(nodes, cons, idx)
@jit
def delete_node_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]:
"""
Delete a node from the genome. Only delete the node, regardless of connections.
Delete the node by its idx.
"""
nodes = nodes.at[idx].set(np.nan)
return nodes, cons
@jit
def add_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int,
weight: float = 1.0, enabled: bool = True) -> Tuple[NDArray, NDArray]:
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = cons[:, 0]
idx = fetch_first(jnp.isnan(con_keys))
cons = cons.at[idx].set(jnp.array([i_key, o_key, weight, enabled]))
return nodes, cons
@jit
def delete_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int) -> Tuple[NDArray, NDArray]:
"""
Delete a connection from the genome.
Delete the connection by its input and output node keys.
"""
idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
return delete_connection_by_idx(nodes, cons, idx)
@jit
def delete_connection_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]:
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
cons = cons.at[idx].set(np.nan)
return nodes, cons

View File

@@ -0,0 +1,169 @@
"""
Some graph algorithms implemented in jax.
Only used in feed-forward networks.
"""
import jax
from jax import jit, Array
from jax import numpy as jnp
# from .configs import fetch_first, I_INT
from algorithms.neat.genome.utils import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, connections: Array) -> Array:
"""
a jit-able version of topological_sort! that's crazy!
:param nodes: nodes array
:param connections: connections array
:return: topological sorted sequence
Example:
nodes = jnp.array([
[0],
[1],
[2],
[3]
])
connections = jnp.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
topological_sort(nodes, connections) -> [0, 1, 2, 3]
"""
connections_enable = connections[1, :, :] == 1 # forward function. thus use enable
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
res = jnp.full(in_degree.shape, I_INT)
def cond_fun(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
return i != I_INT
def body_func(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
# add to res and flag it is already in it
res_ = res_.at[idx_].set(i)
in_degree_ = in_degree_.at[i].set(-1)
# decrease in_degree of all its children
children = connections_enable[i, :]
in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
return res_, idx_ + 1, in_degree_
res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
return res
@jit
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
"""
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
:param nodes: JAX array
The array of nodes.
:param connections: JAX array
The array of connections.
:param from_idx: int
The index of the starting node.
:param to_idx: int
The index of the ending node.
:return: JAX array
An array indicating if there is a cycle caused by the new connection.
Example:
nodes = jnp.array([
[0],
[1],
[2],
[3]
])
connections = jnp.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
check_cycles(nodes, connections, 3, 2) -> True
check_cycles(nodes, connections, 2, 3) -> False
check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False
"""
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
visited = jnp.full(nodes.shape[0], False)
new_visited = visited.at[to_idx].set(True)
def cond_func(carry):
visited_, new_visited_ = carry
end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited
end_cond2 = new_visited_[from_idx] # the starting node has been visited
return jnp.logical_not(end_cond1 | end_cond2)
def body_func(carry):
_, visited_ = carry
new_visited_ = jnp.dot(visited_, connections_enable)
new_visited_ = jnp.logical_or(visited_, new_visited_)
return visited_, new_visited_
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx]
if __name__ == '__main__':
nodes = jnp.array([
[0],
[1],
[2],
[3],
[jnp.nan]
])
connections = jnp.array([
[
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
],
[
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
]
]
)
print(topological_sort(nodes, connections))
print(check_cycles(nodes, connections, 3, 2))
print(check_cycles(nodes, connections, 2, 3))
print(check_cycles(nodes, connections, 0, 3))
print(check_cycles(nodes, connections, 1, 0))

View File

@@ -0,0 +1,351 @@
"""
Mutate a genome.
The calculation method is the same as the mutation operation in NEAT-python.
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
"""
from typing import Tuple, Dict
from functools import partial
import jax
from jax import numpy as jnp
from jax import jit, Array
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
from .graph import check_cycles
@jit
def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict):
"""
:param rand_key:
:param nodes: (N, 5)
:param connections: (2, N, N)
:param new_node_key:
:param jit_config:
:return:
"""
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# structural mutations
# mutate add node
r = rand(r1)
aux_nodes, aux_connections = mutate_add_node(r1, nodes, connections, new_node_key, jit_config)
nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections)
# mutate add connection
r = rand(r2)
aux_nodes, aux_connections = mutate_add_connection(r3, nodes, connections, jit_config)
nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections)
# mutate delete node
r = rand(r3)
aux_nodes, aux_connections = mutate_delete_node(r2, nodes, connections, jit_config)
nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections)
# mutate delete connection
r = rand(r4)
aux_nodes, aux_connections = mutate_delete_connection(r4, nodes, connections)
nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes)
connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections)
# value mutations
nodes, connections = mutate_values(rand_key, nodes, connections, jit_config)
return nodes, connections
def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
"""
Mutate values of nodes and connections.
Args:
rand_key: A random key for generating random values.
nodes: A 2D array representing nodes.
cons: A 3D array representing connections.
jit_config: A dict containing configuration for jit-able functions.
Returns:
A tuple containing mutated nodes and connections.
"""
k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6)
# bias
bias_new = mutate_float_values(k1, nodes[:, 1], jit_config['bias_init_mean'], jit_config['bias_init_std'],
jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'],
jit_config['bias_replace_rate'])
# response
response_new = mutate_float_values(k2, nodes[:, 2], jit_config['response_init_mean'],
jit_config['response_init_std'], jit_config['response_mutate_power'],
jit_config['response_mutate_rate'], jit_config['response_replace_rate'])
# weight
weight_new = mutate_float_values(k3, cons[:, 2], jit_config['weight_init_mean'], jit_config['weight_init_std'],
jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'],
jit_config['weight_replace_rate'])
# activation
act_new = mutate_int_values(k4, nodes[:, 3], jit_config['activation_options'],
jit_config['activation_replace_rate'])
# aggregation
agg_new = mutate_int_values(k5, nodes[:, 4], jit_config['aggregation_options'],
jit_config['aggregation_replace_rate'])
# enabled
r = jax.random.uniform(rand_key, cons[:, 3].shape)
enabled_new = jnp.where(r < jit_config['enable_mutate_rate'], 1 - cons[:, 3], cons[:, 3])
# merge
nodes = jnp.column_stack([nodes[:, 0], bias_new, response_new, act_new, agg_new])
cons = jnp.column_stack([cons[:, 0], cons[:, 1], weight_new, enabled_new])
return nodes, cons
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
"""
Mutate float values of a given array.
Args:
rand_key: A random key for generating random values.
old_vals: A 1D array of float values to be mutated.
mean: Mean of the values.
std: Standard deviation of the values.
mutate_strength: Strength of the mutation.
mutate_rate: Rate of the mutation.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of float values.
"""
k1, k2, k3, rand_key = jax.random.split(rand_key, num=4)
noise = jax.random.normal(k1, old_vals.shape) * mutate_strength
replace = jax.random.normal(k2, old_vals.shape) * std + mean
r = jax.random.uniform(k3, old_vals.shape)
# default
new_vals = old_vals
# r in [0, mutate_rate), mutate
new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals)
# r in [mutate_rate, mutate_rate + replace_rate), replace
new_vals = jnp.where(
(mutate_rate < r) & (r < mutate_rate + replace_rate),
replace + new_vals * 0.0, # in case of nan replace to values
new_vals
)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
return new_vals
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
"""
Mutate integer values (act, agg) of a given array.
Args:
rand_key: A random key for generating random values.
old_vals: A 1D array of integer values to be mutated.
val_list: List of the integer values.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of integer values.
"""
k1, k2, rand_key = jax.random.split(rand_key, num=3)
replace_val = jax.random.choice(k1, val_list, old_vals.shape)
r = jax.random.uniform(k2, old_vals.shape)
new_vals = jnp.where(r < replace_rate, replace_val + old_vals * 0.0, old_vals) # in case of nan replace to values
return new_vals
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
jit_config: Dict) -> Tuple[Array, Array]:
"""
Randomly add a new node from splitting a connection.
:param rand_key:
:param new_node_key:
:param nodes:
:param cons:
:param jit_config:
:return:
"""
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
def nothing(): # there is no connection to split
return nodes, cons
def successful_add_node():
# disable the connection
new_nodes, new_cons = nodes, cons
# set enable to false
new_cons = new_cons.at[idx, 3].set(False)
# add a new node
new_nodes, new_cons = add_node(new_nodes, new_cons, new_node_key, bias=0, response=1,
act=jit_config['activation_default'], agg=jit_config['aggregation_default'])
# add two new connections
w = new_cons[idx, 2]
new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True)
new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True)
return new_nodes, new_cons
# if from_idx == I_INT, that means no connection exist, do nothing
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node)
return nodes, cons
# TODO: Do we really need to delete a node?
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
"""
Randomly delete a node. Input and output nodes are not allowed to be deleted.
:param rand_key:
:param nodes:
:param cons:
:param jit_config:
:return:
"""
# randomly choose a node
key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'],
allow_input_keys=False, allow_output_keys=False)
def nothing():
return nodes, cons
def successful_delete_node():
# delete the node
aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx)
# delete all connections
aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None],
jnp.nan, aux_cons)
return aux_nodes, aux_cons
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
return nodes, cons
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
"""
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
cycles are not allowed.
:param rand_key:
:param nodes:
:param cons:
:param jit_config:
:return:
"""
# randomly choose two nodes
k1, k2 = jax.random.split(rand_key, num=2)
i_key, from_idx = choice_node_key(k1, nodes, jit_config['input_idx'], jit_config['output_idx'],
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2, nodes, jit_config['input_idx'], jit_config['output_idx'],
allow_input_keys=False, allow_output_keys=True)
con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
def successful():
new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)
return new_nodes, new_cons
def already_exist():
new_cons = cons.at[con_idx, 3].set(True)
return nodes, new_cons
def cycle():
return nodes, cons
is_already_exist = con_idx != I_INT
u_cons = unflatten_connections(nodes, cons)
is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
return nodes, cons
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
"""
Randomly delete a connection.
:param rand_key:
:param nodes:
:param cons:
:return:
"""
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)
def nothing():
return nodes, cons
def successfully_delete_connection():
return delete_connection_by_idx(nodes, cons, idx)
nodes, cons = jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
return nodes, cons
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param rand_key:
:param nodes:
:param input_keys:
:param output_keys:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~jnp.isnan(node_keys)
if not allow_input_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys))
if not allow_output_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys))
idx = fetch_random(rand_key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
return key, idx
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
"""
Randomly choose a connection key from the given connections.
:param rand_key:
:param nodes:
:param cons:
:return: i_key, o_key, idx
"""
idx = fetch_random(rand_key, ~jnp.isnan(cons[:, 0]))
i_key = jnp.where(idx != I_INT, cons[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, cons[idx, 1], jnp.nan)
return i_key, o_key, idx
def rand(rand_key):
return jax.random.uniform(rand_key, ())

View File

@@ -0,0 +1,70 @@
from functools import partial
import numpy as np
import jax
from jax import numpy as jnp, Array
from jax import jit, vmap
I_INT = np.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = np.full((1, 5), jnp.nan)
EMPTY_CON = np.full((1, 4), jnp.nan)
@jit
def unflatten_connections(nodes: Array, cons: Array):
"""
transform the (C, 4) connections to (2, N, N)
:param nodes: (N, 5)
:param cons: (C, 4)
:return:
"""
N = nodes.shape[0]
node_keys = nodes[:, 0]
i_keys, o_keys = cons[:, 0], cons[:, 1]
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
res = jnp.full((2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array
# however, it will do nothing set values in an array
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
return res
def key_to_indices(key, keys):
return fetch_first(key == keys)
@jit
def fetch_first(mask, default=I_INT) -> Array:
"""
fetch the first True index
:param mask: array of bool
:param default: the default value if no element satisfying the condition
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value
"""
idx = jnp.argmax(mask)
return jnp.where(mask[idx], idx, default)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
"""
similar to fetch_first, but fetch a random True index
"""
true_cnt = jnp.sum(mask)
cumsum = jnp.cumsum(mask)
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
"""
rank the element in the array.
if reverse is True, the rank is from large to small.
"""
if reverse:
array = -array
return jnp.argsort(jnp.argsort(array))

View File

@@ -0,0 +1,160 @@
from functools import partial
import jax
from jax import jit, numpy as jnp, vmap
from .genome.utils import rank_elements
@jit
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
"""
args:
randkey: random key
fitness: Array[(pop_size,), float], the fitness of each individual
species_keys: Array[(species_size, 3), float], the information of each species
[species_key, best_score, last_update]
idx2species: Array[(pop_size,), int], map the individual to its species
center_nodes: Array[(species_size, N, 4), float], the center nodes of each species
center_cons: Array[(species_size, C, 4), float], the center connections of each species
generation: int, current generation
jit_config: Dict, the configuration of jit functions
"""
# update the fitness of each species
species_fitness = update_species_fitness(species_info, idx2species, fitness)
# stagnation species
species_fitness, species_info, center_nodes, center_cons = \
stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config)
# sort species_info by their fitness. (push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
species_info = species_info[sort_indices]
center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices]
# decide the number of members of each species by their fitness
spawn_number = cal_spawn_numbers(species_info, jit_config)
# crossover info
winner, loser, elite_mask = \
create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)
jax.debug.print("{}, {}", fitness, winner)
jax.debug.print("{}", fitness[winner])
return species_info, center_nodes, center_cons, winner, loser, elite_mask
def update_species_fitness(species_info, idx2species, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
"""
def aux_func(idx):
species_key = species_info[idx, 0]
s_fitness = jnp.where(idx2species == species_key, fitness, -jnp.inf)
f = jnp.max(s_fitness)
return f
return vmap(aux_func)(jnp.arange(species_info.shape[0]))
def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
elitism species never stagnation
"""
def aux_func(idx):
s_fitness = species_fitness[idx]
species_key, best_score, last_update = species_info[idx]
# stagnation condition
return (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])
st = vmap(aux_func)(jnp.arange(species_info.shape[0]))
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
st = jnp.where(species_rank < jit_config['species_elitism'], False, st) # elitism never stagnation
# set stagnation species to nan
species_info = jnp.where(st[:, None], jnp.nan, species_info)
center_nodes = jnp.where(st[:, None, None], jnp.nan, center_nodes)
center_cons = jnp.where(st[:, None, None], jnp.nan, center_cons)
species_fitness = jnp.where(st, jnp.nan, species_fitness)
return species_fitness, species_info, center_nodes, center_cons
def cal_spawn_numbers(species_info, jit_config):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
Linear ranking selection
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
is_species_valid = ~jnp.isnan(species_info[:, 0])
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - jnp.arange(species_info.shape[0]) # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
spawn_number = jnp.floor(spawn_number_rate * jit_config['pop_size']).astype(jnp.int32) # calculate member
# must control the sum of spawn_number to be equal to pop_size
error = jit_config['pop_size'] - jnp.sum(spawn_number)
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
return spawn_number
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
species_size = species_info.shape[0]
pop_size = fitness.shape[0]
s_idx = jnp.arange(species_size)
p_idx = jnp.arange(pop_size)
def aux_func(key, idx):
members = idx2species == species_info[idx, 0]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, jnp.nan)
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
elite_size = jit_config['genome_elitism']
survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32)
select_pro = (p_idx < survive_size) / survive_size
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)
# elite
fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
elite = jnp.where(p_idx < elite_size, True, False)
return fa, ma, elite
fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
loc = jnp.argmax(idx < spawn_number_cum)
# elite genomes are at the beginning of the species
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
part1, part2, elite_mask = vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return winner, loser, elite_mask

View File

@@ -0,0 +1,166 @@
"""
contains operations on the population: creating the next generation and population speciation.
"""
import jax
from jax import jit, vmap, Array, numpy as jnp
from .genome import distance, mutate, crossover
from .genome.utils import I_INT, fetch_first
@jit
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys,
center_nodes, center_cons, species_keys, new_species_key_start,
jit_config):
# create next generation
pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask,
new_node_keys, jit_config)
# speciate
idx2specie, spe_center_nodes, spe_center_cons, species_keys = \
speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config)
return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config):
# prepare random keys
pop_size = pop_nodes.shape[0]
k1, k2 = jax.random.split(rand_key, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn, wpc = pop_nodes[winner], pop_cons[winner] # winner pop nodes, winner pop connections
lpn, lpc = pop_nodes[loser], pop_cons[loser] # loser pop nodes, loser pop connections
npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
# batch mutation
mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None))
m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)
return pop_nodes, pop_cons
@jit
def speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config):
"""
args:
pop_nodes: (pop_size, N, 5)
pop_cons: (pop_size, C, 4)
spe_center_nodes: (species_size, N, 5)
spe_center_cons: (species_size, C, 4)
"""
pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0]
# prepare distance functions
o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None)) # one to population
s2p_distance_func = vmap(
o2p_distance_func, in_axes=(0, 0, None, None, None) # center to population
)
# idx to specie key
idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32) # I_INT means not assigned to any species
# part 1: find new centers
# the distance between each species' center and each genome in population
s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)
def find_new_centers(i, carry):
i2s, cn, cc = carry
# find new center
idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)
# check species[i] exist or not
# if not exist, set idx and i to I_INT, jax will not do array value assignment
idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
i = jnp.where(species_keys[i] != I_INT, i, I_INT)
i2s = i2s.at[idx].set(species_keys[i])
cn = cn.at[i].set(pop_nodes[idx])
cc = cc.at[i].set(pop_cons[idx])
return i2s, cn, cc
idx2specie, center_nodes, center_cons = \
jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons))
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cn, cc, sk, ck = carry # sk is short for species_keys, ck is short for current key
not_all_assigned = ~jnp.all(i2s != I_INT)
not_reach_species_upper_bounds = i < species_size
return not_all_assigned & not_reach_species_upper_bounds
def body_func(carry):
i, i2s, cn, cc, sk, ck = carry # scn is short for spe_center_nodes, scc is short for spe_center_cons
i2s, scn, scc, sk, ck = jax.lax.cond(
sk[i] == I_INT, # whether the current species is existing or not
create_new_specie, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cn, cc, sk, ck)
)
return i + 1, i2s, scn, scc, sk, ck
def create_new_specie(carry):
i, i2s, cn, cc, sk, ck = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(i2s == I_INT)
# assign it to the new species
sk = sk.at[i].set(ck)
i2s = i2s.at[idx].set(ck)
# update center genomes
cn = cn.at[i].set(pop_nodes[idx])
cc = cc.at[i].set(pop_cons[idx])
i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
return i2s, cn, cc, sk, ck + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cn, cc, sk, ck = carry
i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
return i2s, cn, cc, sk, ck
def speciate_by_threshold(carry):
i, i2s, cn, cc, sk = carry
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
close_enough_mask = o2p_distance < jit_config['compatibility_threshold']
# when it is close enough, assign it to the species, remember not to update genome has already been assigned
i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
return i2s
current_new_key = new_species_key_start
# update idx2specie
_, idx2specie, center_nodes, center_cons, species_keys, _ = jax.lax.while_loop(
cond_func,
body_func,
(0, idx2specie, center_nodes, center_cons, species_keys, current_new_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition seems to be only happened when the number of species is reached species upper bounds
idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)
return idx2specie, center_nodes, center_cons, species_keys
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx

283
algorithms/neat/species.py Normal file
View File

@@ -0,0 +1,283 @@
"""
Species Controller in NEAT.
The code are modified from neat-python.
See
https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation
https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction
https://neat-python.readthedocs.io/en/latest/module_summaries.html#species
"""
from typing import List, Tuple, Dict
import numpy as np
from numpy.typing import NDArray
from .genome.utils import I_INT
class Species(object):
def __init__(self, key, generation):
self.key = key
self.created = generation
self.last_improved = generation
self.representative: Tuple[NDArray, NDArray] = (None, None) # (center_nodes, center_connections)
self.members: NDArray = None # idx in pop_nodes, pop_connections,
self.fitness = None
self.member_fitnesses = None
self.adjusted_fitness = None
self.fitness_history: List[float] = []
def update(self, representative, members):
self.representative = representative
self.members = members
def get_fitnesses(self, fitnesses):
return fitnesses[self.members]
class SpeciesController:
"""
A class to control the species
"""
def __init__(self, config):
self.config = config
self.species_elitism = self.config['species_elitism']
self.pop_size = self.config['pop_size']
self.max_stagnation = self.config['max_stagnation']
self.min_species_size = self.config['min_species_size']
self.genome_elitism = self.config['genome_elitism']
self.survival_threshold = self.config['survival_threshold']
self.species: Dict[int, Species] = {} # species_id -> species
def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray):
"""
speciate for the first generation
:param pop_connections:
:param pop_nodes:
:return:
"""
pop_size = pop_nodes.shape[0]
species_id = 0 # the first species
s = Species(species_id, 0)
members = np.array(list(range(pop_size)))
s.update((pop_nodes[0], pop_connections[0]), members)
self.species[species_id] = s
def __update_species_fitnesses(self, fitnesses):
"""
update the fitness of each species
:param fitnesses:
:return:
"""
for sid, s in self.species.items():
s.member_fitnesses = s.get_fitnesses(fitnesses)
# use the max score to represent the fitness of the species
s.fitness = np.max(s.member_fitnesses)
s.fitness_history.append(s.fitness)
s.adjusted_fitness = None
def __stagnation(self, generation):
"""
:param generation:
:return: whether the species is stagnated
"""
species_data = []
for sid, s in self.species.items():
if s.fitness_history:
prev_fitness = max(s.fitness_history)
else:
prev_fitness = float('-inf')
if s.fitness > prev_fitness:
s.last_improved = generation
species_data.append((sid, s))
# Sort in descending fitness order.
species_data.sort(key=lambda x: x[1].fitness, reverse=True)
result = []
for idx, (sid, s) in enumerate(species_data):
if idx < self.species_elitism: # elitism species never stagnate!
is_stagnant = False
else:
stagnant_time = generation - s.last_improved
is_stagnant = stagnant_time > self.max_stagnation
result.append((sid, s, is_stagnant))
return result
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
"""
:param fitnesses:
:param generation:
:return: crossover_pair for next generation.
# int -> idx in the pop_nodes, pop_connections of elitism
# (int, int) -> the father and mother idx to be crossover
"""
# Filter out stagnated species, collect the set of non-stagnated
# species members, and compute their average adjusted fitness.
# The average adjusted fitness scheme (normalized to the interval
# [0, 1]) allows the use of negative fitness values without
# interfering with the shared fitness scheme.
min_fitness = np.inf
max_fitness = -np.inf
remaining_species = []
for stag_sid, stag_s, stagnant in self.__stagnation(generation):
if not stagnant:
min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses))
max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses))
remaining_species.append(stag_s)
# No species left.
assert remaining_species
# TODO: Too complex!
# Compute each species' member size in the next generation.
# Do not allow the fitness range to be zero, as we divide by it below.
# TODO: The ``1.0`` below is rather arbitrary, and should be configurable.
fitness_range = max(1.0, max_fitness - min_fitness)
for afs in remaining_species:
# Compute adjusted fitness.
msf = afs.fitness
af = (msf - min_fitness) / fitness_range # make adjusted fitness in [0, 1]
afs.adjusted_fitness = af
adjusted_fitnesses = [s.adjusted_fitness for s in remaining_species]
previous_sizes = [len(s.members) for s in remaining_species]
min_species_size = max(self.min_species_size, self.genome_elitism)
spawn_amounts = compute_spawn(adjusted_fitnesses, previous_sizes, self.pop_size, min_species_size)
assert sum(spawn_amounts) == self.pop_size
# generate new population and speciate
self.species = {}
# int -> idx in the pop_nodes, pop_connections of elitism
# (int, int) -> the father and mother idx to be crossover
part1, part2, elite_mask = [], [], []
for spawn, s in zip(spawn_amounts, remaining_species):
assert spawn >= self.genome_elitism
# retain remain species to next generation
old_members, member_fitnesses = s.members, s.member_fitnesses
s.members = []
self.species[s.key] = s
# add elitism genomes to next generation
sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, member_fitnesses)
if self.genome_elitism > 0:
for m in sorted_members[:self.genome_elitism]:
part1.append(m)
part2.append(m)
elite_mask.append(True)
spawn -= 1
if spawn <= 0:
continue
# add genome to be crossover to next generation
repro_cutoff = int(np.ceil(self.survival_threshold * len(sorted_members)))
repro_cutoff = max(repro_cutoff, 2)
# only use good genomes to crossover
sorted_members = sorted_members[:repro_cutoff]
# TODO: Genome with higher fitness should be more likely to be selected?
list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
part1.extend(sorted_members[list_idx1])
part2.extend(sorted_members[list_idx2])
elite_mask.extend([False] * spawn)
part1_fitness, part2_fitness = fitnesses[part1], fitnesses[part2]
is_part1_win = part1_fitness >= part2_fitness
winner_part = np.where(is_part1_win, part1, part2)
loser_part = np.where(is_part1_win, part2, part1)
return winner_part, loser_part, np.array(elite_mask)
def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation):
for idx, key in enumerate(species_keys):
if key == I_INT:
continue
members = np.where(idx2specie == key)[0]
assert len(members) > 0
if key not in self.species:
# the new specie created in this generation
s = Species(key, generation)
self.species[key] = s
self.species[key].update((center_nodes[idx], center_cons[idx]), members)
def ask(self, fitnesses, generation, symbols):
self.__update_species_fitnesses(fitnesses)
winner, loser, elite_mask = self.__reproduce(fitnesses, generation)
center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan)
center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan)
species_keys = np.full((symbols['S'], ), I_INT)
for idx, (key, specie) in enumerate(self.species.items()):
center_nodes[idx], center_cons[idx] = specie.representative
species_keys[idx] = key
next_new_specie_key = max(self.species.keys()) + 1
return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
"""
Code from neat-python, the only modification is to fix the population size for each generation.
Compute the proper number of offspring per species (proportional to fitness).
"""
af_sum = sum(adjusted_fitness)
spawn_amounts = []
for af, ps in zip(adjusted_fitness, previous_sizes):
if af_sum > 0:
s = max(min_species_size, af / af_sum * pop_size)
else:
s = min_species_size
d = (s - ps) * 0.5
c = int(round(d))
spawn = ps
if abs(c) > 0:
spawn += c
elif d > 0:
spawn += 1
elif d < 0:
spawn -= 1
spawn_amounts.append(spawn)
# Normalize the spawn amounts so that the next generation is roughly
# the population size requested by the user.
total_spawn = sum(spawn_amounts)
norm = pop_size / total_spawn
spawn_amounts = [max(min_species_size, int(round(n * norm))) for n in spawn_amounts]
# for batch parallelization, pop size must be a fixed value.
total_amounts = sum(spawn_amounts)
spawn_amounts[0] += pop_size - total_amounts
assert sum(spawn_amounts) == pop_size, "Population size is not stable."
return spawn_amounts
def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \
-> Tuple[NDArray, NDArray]:
sorted_idx = np.argsort(fitnesses)[::-1]
return members[sorted_idx], fitnesses[sorted_idx]