modifying
This commit is contained in:
0
algorithms/__init__.py
Normal file
0
algorithms/__init__.py
Normal file
6
algorithms/neat/__init__.py
Normal file
6
algorithms/neat/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
contains operations on a single genome. e.g. forward, mutate, crossover, etc.
|
||||
"""
|
||||
from .genome import create_forward, topological_sort, unflatten_connections, initialize_genomes, expand, expand_single
|
||||
from .operations import create_next_generation_then_speciate
|
||||
from .species import SpeciesController
|
||||
7
algorithms/neat/genome/__init__.py
Normal file
7
algorithms/neat/genome/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .mutate import mutate
|
||||
from .distance import distance
|
||||
from .crossover import crossover
|
||||
from .forward import create_forward
|
||||
from .graph import topological_sort, check_cycles
|
||||
from .utils import unflatten_connections
|
||||
from .genome import initialize_genomes, expand, expand_single
|
||||
126
algorithms/neat/genome/activations.py
Normal file
126
algorithms/neat/genome/activations.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
|
||||
|
||||
@jit
|
||||
def sigmoid_act(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return 1 / (1 + jnp.exp(-z))
|
||||
|
||||
|
||||
@jit
|
||||
def tanh_act(z):
|
||||
z = jnp.clip(z * 2.5, -60, 60)
|
||||
return jnp.tanh(z)
|
||||
|
||||
|
||||
@jit
|
||||
def sin_act(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return jnp.sin(z)
|
||||
|
||||
|
||||
@jit
|
||||
def gauss_act(z):
|
||||
z = jnp.clip(z * 5, -3.4, 3.4)
|
||||
return jnp.exp(-z ** 2)
|
||||
|
||||
|
||||
@jit
|
||||
def relu_act(z):
|
||||
return jnp.maximum(z, 0)
|
||||
|
||||
|
||||
@jit
|
||||
def elu_act(z):
|
||||
return jnp.where(z > 0, z, jnp.exp(z) - 1)
|
||||
|
||||
|
||||
@jit
|
||||
def lelu_act(z):
|
||||
leaky = 0.005
|
||||
return jnp.where(z > 0, z, leaky * z)
|
||||
|
||||
|
||||
@jit
|
||||
def selu_act(z):
|
||||
lam = 1.0507009873554804934193349852946
|
||||
alpha = 1.6732632423543772848170429916717
|
||||
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
|
||||
|
||||
|
||||
@jit
|
||||
def softplus_act(z):
|
||||
z = jnp.clip(z * 5, -60, 60)
|
||||
return 0.2 * jnp.log(1 + jnp.exp(z))
|
||||
|
||||
|
||||
@jit
|
||||
def identity_act(z):
|
||||
return z
|
||||
|
||||
|
||||
@jit
|
||||
def clamped_act(z):
|
||||
return jnp.clip(z, -1, 1)
|
||||
|
||||
|
||||
@jit
|
||||
def inv_act(z):
|
||||
z = jnp.maximum(z, 1e-7)
|
||||
return 1 / z
|
||||
|
||||
|
||||
@jit
|
||||
def log_act(z):
|
||||
z = jnp.maximum(z, 1e-7)
|
||||
return jnp.log(z)
|
||||
|
||||
|
||||
@jit
|
||||
def exp_act(z):
|
||||
z = jnp.clip(z, -60, 60)
|
||||
return jnp.exp(z)
|
||||
|
||||
|
||||
@jit
|
||||
def abs_act(z):
|
||||
return jnp.abs(z)
|
||||
|
||||
|
||||
@jit
|
||||
def hat_act(z):
|
||||
return jnp.maximum(0, 1 - jnp.abs(z))
|
||||
|
||||
|
||||
@jit
|
||||
def square_act(z):
|
||||
return z ** 2
|
||||
|
||||
|
||||
@jit
|
||||
def cube_act(z):
|
||||
return z ** 3
|
||||
|
||||
|
||||
act_name2func = {
|
||||
'sigmoid': sigmoid_act,
|
||||
'tanh': tanh_act,
|
||||
'sin': sin_act,
|
||||
'gauss': gauss_act,
|
||||
'relu': relu_act,
|
||||
'elu': elu_act,
|
||||
'lelu': lelu_act,
|
||||
'selu': selu_act,
|
||||
'softplus': softplus_act,
|
||||
'identity': identity_act,
|
||||
'clamped': clamped_act,
|
||||
'inv': inv_act,
|
||||
'log': log_act,
|
||||
'exp': exp_act,
|
||||
'abs': abs_act,
|
||||
'hat': hat_act,
|
||||
'square': square_act,
|
||||
'cube': cube_act,
|
||||
}
|
||||
60
algorithms/neat/genome/aggregations.py
Normal file
60
algorithms/neat/genome/aggregations.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
|
||||
def sum_agg(z):
    """Sum of inputs along axis 0, treating NaN entries as 0."""
    cleaned = jnp.where(jnp.isnan(z), 0, z)
    return jnp.sum(cleaned, axis=0)


def product_agg(z):
    """Product of inputs along axis 0, treating NaN entries as 1."""
    cleaned = jnp.where(jnp.isnan(z), 1, z)
    return jnp.prod(cleaned, axis=0)


def max_agg(z):
    """Maximum input; NaN entries become -inf so they can never win."""
    cleaned = jnp.where(jnp.isnan(z), -jnp.inf, z)
    return jnp.max(cleaned, axis=0)


def min_agg(z):
    """Minimum input; NaN entries become +inf so they can never win."""
    cleaned = jnp.where(jnp.isnan(z), jnp.inf, z)
    return jnp.min(cleaned, axis=0)


def maxabs_agg(z):
    """Input with the largest absolute value, sign preserved; NaN counts as 0."""
    cleaned = jnp.where(jnp.isnan(z), 0, z)
    winner = jnp.argmax(jnp.abs(cleaned))
    return cleaned[winner]


def median_agg(z):
    """Median of the non-NaN inputs (NaNs sort to the end and are ignored)."""
    valid_cnt = jnp.sum(~jnp.isnan(z), axis=0)

    ordered = jnp.sort(z)  # NaNs land at the end

    # midpoint of the two central valid elements (equal for odd counts)
    lo, hi = (valid_cnt - 1) // 2, valid_cnt // 2
    return (ordered[lo] + ordered[hi]) / 2


def mean_agg(z):
    """Mean of the non-NaN inputs."""
    total = sum_agg(z)
    valid_cnt = jnp.sum(~jnp.isnan(z), axis=0)
    return total / valid_cnt


# Registry mapping config aggregation names to their implementations.
agg_name2func = {
    'sum': sum_agg,
    'product': product_agg,
    'max': max_agg,
    'min': min_agg,
    'maxabs': maxabs_agg,
    'median': median_agg,
    'mean': mean_agg,
}
|
||||
81
algorithms/neat/genome/crossover.py
Normal file
81
algorithms/neat/genome/crossover.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Crossover two genomes to generate a new genome.
|
||||
The calculation method is the same as the crossover operation in NEAT-python.
|
||||
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.configure_crossover
|
||||
"""
|
||||
from typing import Tuple
|
||||
|
||||
import jax
|
||||
from jax import jit, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
|
||||
@jit
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) -> Tuple[Array, Array]:
    """
    Create a child genome from two parent genomes.

    genome1 (nodes1, cons1) must be the fitter parent ("genome1 is winner"):
    genes that exist in only one parent are inherited from genome1, while
    homologous genes (same key in both parents) are crossed over element-wise
    at random.

    :param randkey: jax random key
    :param nodes1: (N, 5) node array of the fitter parent
    :param cons1: (C, 4) connection array of the fitter parent
    :param nodes2: (N, 5) node array of the less fit parent
    :param cons2: (C, 4) connection array of the less fit parent
    :return: (new_nodes, new_cons) arrays of the child genome
    """
    randkey_1, randkey_2 = jax.random.split(randkey)

    # crossover nodes; node key is column 0
    keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
    # make homologous genes in nodes2 align (row-wise) with nodes1
    nodes2 = align_array(keys1, keys2, nodes2, 'node')

    # For non-homologous genes (either side NaN after alignment), keep the
    # value from nodes1 (the winner); for homologous genes, use the random
    # element-wise crossover between nodes1 and nodes2.
    new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))

    # crossover connections; connection key is the (in_key, out_key) pair in columns 0-1
    con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
    cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
    new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))

    return new_nodes, new_cons
|
||||
|
||||
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
    """
    Make ``ar2`` align row-wise with the gene ordering given by ``seq1``.

    After alignment, a row of the result holds the gene of ``ar2`` whose key
    matches the key at the same position of ``seq1``; rows of ``seq1`` with no
    matching key in ``seq2`` become all-NaN.

    NOTE: original author's warning — this is the most intricate part of the
    code; do not change the logic.

    :param seq1: keys of the reference genome; shape (len1,) for nodes or (len1, 2) for connections
    :param seq2: keys of the genome to be aligned; shape (len2,) or (len2, 2)
    :param ar2: gene array to realign, rows keyed by seq2
    :param gene_type: 'node' or 'connection' (connection keys are 2-wide pairs)
    :return: realigned copy of ar2 (same shape), non-matching rows set to NaN
    """
    # broadcast to a (len1, len2[, key_width]) pairwise key-equality mask;
    # NaN keys (empty slots) never match
    seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
    mask = (seq1 == seq2) & (~jnp.isnan(seq1))

    if gene_type == 'connection':
        # a connection matches only if BOTH in_key and out_key match
        mask = jnp.all(mask, axis=2)

    # rows of seq1 that have a homologous gene somewhere in seq2
    intersect_mask = mask.any(axis=1)
    idx = jnp.arange(0, len(seq1))
    # each mask row has at most one True, so this dot product extracts the
    # index (in seq2) of the matching gene; rows with no match get index 0
    idx_fixed = jnp.dot(mask, idx)

    # gather the matched rows; rows with no match are overwritten with NaN
    refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)

    return refactor_ar2
|
||||
|
||||
|
||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
"""
|
||||
crossover two genes
|
||||
:param rand_key:
|
||||
:param g1:
|
||||
:param g2:
|
||||
:return:
|
||||
only gene with the same key will be crossover, thus don't need to consider change key
|
||||
"""
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
0
algorithms/neat/genome/debug/__init__.py
Normal file
0
algorithms/neat/genome/debug/__init__.py
Normal file
88
algorithms/neat/genome/debug/tools.py
Normal file
88
algorithms/neat/genome/debug/tools.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def check_array_valid(nodes, cons, input_keys, output_keys):
    """
    Sanity-check an array-encoded genome by round-tripping it to dict form.

    All validation happens via the asserts inside ``array2object``; the DAG
    check below is currently disabled.
    """
    nodes_dict, cons_dict = array2object(nodes, cons, input_keys, output_keys)
    # assert is_DAG(cons_dict.keys()), "The genome is not a DAG!"
|
||||
|
||||
|
||||
def array2object(nodes, cons, input_keys, output_keys):
    """
    Convert a genome from its array encoding to plain dictionaries,
    validating the encoding along the way.

    :param nodes: (N, 5) array of [key, bias, response, act, agg] rows; NaN key marks an empty slot
    :param cons: (C, 4) array of [in_key, out_key, weight, enabled] rows; all-NaN marks an empty slot
    :param input_keys: keys of the input nodes (their attributes must all be NaN)
    :param output_keys: keys of the output nodes
    :return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
    """
    # nodes: every row with a non-NaN key becomes one dict entry
    nodes_dict = {}
    for row in nodes:
        if np.isnan(row[0]):
            continue  # empty slot
        key = int(row[0])
        assert key not in nodes_dict, f"Duplicate node key: {key}!"

        attrs = row[1:]
        if key in input_keys:
            # input nodes carry no attributes of their own
            assert np.all(np.isnan(attrs)), f"Input node {key} must has None bias, response, act, or agg!"
            nodes_dict[key] = (None, None, None, None)
        else:
            assert np.all(~np.isnan(attrs)), f"Normal node {key} must has non-None bias, response, act, or agg!"
            nodes_dict[key] = (attrs[0], attrs[1], attrs[2], attrs[3])

    # every declared input/output key must actually be present
    for i in input_keys:
        assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
    for o in output_keys:
        assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"

    # connections: each row is either entirely NaN (empty) or entirely non-NaN
    cons_dict = {}
    for i, con in enumerate(cons):
        if np.all(np.isnan(con)):
            continue  # empty slot
        assert np.all(~np.isnan(con)), f"Connection {i} must has all None or all non-None!"
        i_key, o_key = int(con[0]), int(con[1])
        assert (i_key, o_key) not in cons_dict, f"Duplicate connection: {(i_key, o_key)}!"
        assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
        assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
        cons_dict[(i_key, o_key)] = (con[2], con[3] == 1)

    return nodes_dict, cons_dict
|
||||
|
||||
|
||||
def is_DAG(edges):
    """
    Return True iff the directed graph given by ``edges`` is acyclic.

    Fixes two bugs in the previous version:
    - ``all_nodes.union({a, b})`` discarded its result, so the node set stayed
      empty and the function only ever detected self-loops;
    - the DFS marked nodes as globally visited, so any diamond-shaped DAG
      (two paths reaching the same node) was misreported as cyclic.

    :param edges: iterable of (a, b) pairs, each a directed edge a -> b
    :return: True if the graph has no directed cycle
    """
    edge_list = list(edges)  # may be a one-shot iterable (e.g. dict keys view)
    adjacency = defaultdict(list)
    all_nodes = set()
    for a, b in edge_list:
        if a == b:  # self-loop is a cycle
            return False
        adjacency[a].append(b)
        all_nodes.update((a, b))

    # Iterative DFS with three colors: 0 = unvisited, 1 = on the current
    # recursion stack, 2 = fully explored. A back edge to a color-1 node
    # means a cycle.
    color = {n: 0 for n in all_nodes}
    for start in all_nodes:
        if color[start]:
            continue
        color[start] = 1
        stack = [(start, iter(adjacency[start]))]
        while stack:
            node, children = stack[-1]
            child = next(children, None)
            if child is None:
                color[node] = 2
                stack.pop()
            elif color[child] == 1:  # back edge -> cycle
                return False
            elif color[child] == 0:
                color[child] = 1
                stack.append((child, iter(adjacency[child])))
    return True
|
||||
119
algorithms/neat/genome/distance.py
Normal file
119
algorithms/neat/genome/distance.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
Calculate the distance between two genomes.
|
||||
The calculation method is the same as the distance calculation in NEAT-python.
|
||||
See https://github.com/CodeReclaimers/neat-python/blob/master/neat/genome.py
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
from .utils import EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
@jit
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
    """
    Calculate the compatibility distance between two genomes: the sum of the
    node distance and the connection distance (same scheme as NEAT-python).

    args:
        nodes1: Array(N, 5) node array of genome 1
        cons1: Array(C, 4) connection array of genome 1
        nodes2: Array(N, 5) node array of genome 2
        cons2: Array(C, 4) connection array of genome 2
        jit_config: dict holding 'compatibility_disjoint' and 'compatibility_weight' coefficients
    returns:
        distance: Array(, ) scalar
    """
    nd = node_distance(nodes1, nodes2, jit_config)  # node distance
    cd = connection_distance(cons1, cons2, jit_config)  # connection distance
    return nd + cd
|
||||
|
||||
|
||||
@jit
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
    """
    Calculate the distance contribution of the node genes of two genomes:
    disjoint-gene count weighted by 'compatibility_disjoint' plus homologous
    attribute distance weighted by 'compatibility_weight', normalized by the
    larger genome's node count.
    """
    # statistics nodes count of two genomes (non-NaN key = real node)
    node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
    node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
    max_cnt = jnp.maximum(node_cnt1, node_cnt2)

    # align homologous nodes: concatenate both genomes and sort by key, so
    # genes sharing a key land on adjacent rows (similar to np.intersect1d)
    nodes = jnp.concatenate((nodes1, nodes2), axis=0)
    keys = nodes[:, 0]
    sorted_indices = jnp.argsort(keys, axis=0)
    nodes = nodes[sorted_indices]
    nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0)  # add a nan row to the end
    fr, sr = nodes[:-1], nodes[1:]  # first row, second row

    # flag location of homologous nodes: adjacent rows with identical non-NaN key
    intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])

    # calculate the count of non_homologous (disjoint/excess) genes of two genomes
    non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)

    # calculate the distance of homologous nodes (pairwise over adjacent rows;
    # NaN results come from non-homologous pairs and are zeroed out)
    hnd = vmap(homologous_node_distance)(fr, sr)
    hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
    homologous_distance = jnp.sum(hnd * intersect_mask)

    val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
        'compatibility_weight']

    return jnp.where(max_cnt == 0, 0, val / max_cnt)  # avoid zero division
|
||||
|
||||
|
||||
@jit
def connection_distance(cons1: Array, cons2: Array, jit_config: Dict):
    """
    Calculate the distance contribution of the connection genes of two
    genomes. Same alignment scheme as node_distance, except the gene key is
    the (in_key, out_key) pair, so rows are lexicographically sorted on both
    key columns.
    """
    con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
    con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
    max_cnt = jnp.maximum(con_cnt1, con_cnt2)

    # sort by (in_key, out_key) so homologous connections become adjacent
    cons = jnp.concatenate((cons1, cons2), axis=0)
    keys = cons[:, :2]
    sorted_indices = jnp.lexsort(keys.T[::-1])
    cons = cons[sorted_indices]
    cons = jnp.concatenate([cons, EMPTY_CON], axis=0)  # add a nan row to the end
    fr, sr = cons[:-1], cons[1:]  # first row, second row

    # both genomes have such connection: adjacent rows with identical key pair
    intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])

    non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
    hcd = vmap(homologous_connection_distance)(fr, sr)
    hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
    homologous_distance = jnp.sum(hcd * intersect_mask)

    val = non_homologous_cnt * jit_config['compatibility_disjoint'] + homologous_distance * jit_config[
        'compatibility_weight']

    return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
|
||||
@jit
|
||||
def homologous_node_distance(n1: Array, n2: Array):
|
||||
"""
|
||||
Calculate the distance between two homologous nodes.
|
||||
"""
|
||||
d = 0
|
||||
d += jnp.abs(n1[1] - n2[1]) # bias
|
||||
d += jnp.abs(n1[2] - n2[2]) # response
|
||||
d += n1[3] != n2[3] # activation
|
||||
d += n1[4] != n2[4] # aggregation
|
||||
return d
|
||||
|
||||
|
||||
@jit
|
||||
def homologous_connection_distance(c1: Array, c2: Array):
|
||||
"""
|
||||
Calculate the distance between two homologous connections.
|
||||
"""
|
||||
d = 0
|
||||
d += jnp.abs(c1[2] - c2[2]) # weight
|
||||
d += c1[3] != c2[3] # enable
|
||||
return d
|
||||
86
algorithms/neat/genome/forward.py
Normal file
86
algorithms/neat/genome/forward.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import jax
|
||||
from jax import Array, numpy as jnp
|
||||
from jax import jit, vmap
|
||||
|
||||
from .utils import I_INT
|
||||
|
||||
|
||||
def create_forward(config):
    """
    Factory that builds the jax forward function for a given config.

    config must provide 'activation_funcs', 'aggregation_funcs' (sequences of
    jax-compatible callables indexed by the act/agg columns of the node
    array), 'input_idx' and 'output_idx'.
    """

    def act(idx, z):
        """
        Apply the node's activation function, selected by index via lax.switch.
        """
        idx = jnp.asarray(idx, dtype=jnp.int32)
        # change idx from float (as stored in the node array) to int
        res = jax.lax.switch(idx, config['activation_funcs'], z)
        return res

    def agg(idx, z):
        """
        Aggregate the inputs of a node with its aggregation function,
        selected by index; returns 0 when every input is NaN (no enabled
        incoming connection delivered a value).
        """
        idx = jnp.asarray(idx, dtype=jnp.int32)

        def all_nan():
            return 0.

        def not_all_nan():
            return jax.lax.switch(idx, config['aggregation_funcs'], z)

        return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)

    def forward(inputs: Array, cal_seqs: Array, nodes: Array, cons: Array) -> Array:
        """
        jax forward pass for a single input sample shaped (input_num, );
        nodes and cons encode a single genome.

        :argument inputs: (input_num, ) input values
        :argument cal_seqs: (N, ) topological calculation order, padded with I_INT
        :argument nodes: (N, 5) node array [key, bias, response, act, agg]
        :argument cons: (2, N, N) stacked [weights, enabled] adjacency matrices

        :return (output_num, ) values of the output nodes
        """

        input_idx = config['input_idx']
        output_idx = config['output_idx']

        N = nodes.shape[0]
        # node values start as NaN ("not yet computed"); inputs are seeded directly
        ini_vals = jnp.full((N,), jnp.nan)
        ini_vals = ini_vals.at[input_idx].set(inputs)

        # weight matrix with disabled connections masked out as NaN
        weights = jnp.where(jnp.isnan(cons[1, :, :]), jnp.nan, cons[0, :, :])  # enabled

        def cond_fun(carry):
            # stop at the end of the sequence or at the I_INT padding
            values, idx = carry
            return (idx < N) & (cal_seqs[idx] != I_INT)

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def hit():
                # gather weighted inputs, aggregate, affine-transform, activate
                ins = values * weights[:, i]
                z = agg(nodes[i, 4], ins)  # z = agg(ins)
                z = z * nodes[i, 2] + nodes[i, 1]  # z = z * response + bias
                z = act(nodes[i, 3], z)  # z = act(z)

                new_values = values.at[i].set(z)
                return new_values

            def miss():
                return values

            # the val of input nodes is obtained by the task, not by calculation
            values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))

        return vals[output_idx]

    return forward
|
||||
181
algorithms/neat/genome/genome.py
Normal file
181
algorithms/neat/genome/genome.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Vectorization of genome representation.
|
||||
|
||||
Utilizes Tuple[nodes: Array(N, 5), connections: Array(C, 4)] to encode the genome, where:
|
||||
nodes: [key, bias, response, act, agg]
|
||||
connections: [in_key, out_key, weight, enable]
|
||||
N: Maximum number of nodes in the network.
|
||||
C: Maximum number of connections in the network.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Dict
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from jax import jit, numpy as jnp
|
||||
|
||||
from .utils import fetch_first
|
||||
|
||||
|
||||
def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
    """
    Initialize a population of genomes with the default topology: the fixed
    input and output nodes, fully connected input -> output.

    Fixes vs. previous version: the first assert's message interpolated
    ``num_inputs`` where it claimed to print ``output_size``; and the random
    node-attribute draws used ``size=(pop_size, 1)``, which broadcast one
    value across all output nodes of a genome instead of initializing each
    output node independently (unchanged behavior when num_outputs == 1).

    Args:
        N (int): Maximum number of nodes in the network.
        C (int): Maximum number of connections in the network.
        config (Dict): Configuration dictionary ('pop_size', 'num_inputs',
            'num_outputs', 'input_idx', 'output_idx', and the *_init_mean/std
            and activation/aggregation option entries).

    Returns:
        Tuple[NDArray, NDArray]: pop_nodes (pop_size, N, 5) and
        pop_cons (pop_size, C, 4) arrays; unused slots are NaN.
    """
    # Reserve one row for potential mutation adding an extra node
    assert config['num_inputs'] + config['num_outputs'] + 1 <= N, \
        f"Too small N: {N} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"

    assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
        f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"

    pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
    pop_cons = np.full((config['pop_size'], C, 4), np.nan)
    input_idx = config['input_idx']
    output_idx = config['output_idx']
    num_outputs = len(output_idx)

    # keys of the fixed input/output nodes; input nodes keep NaN attributes
    pop_nodes[:, input_idx, 0] = input_idx
    pop_nodes[:, output_idx, 0] = output_idx

    # independently randomized attributes for every output node of every genome
    pop_nodes[:, output_idx, 1] = np.random.normal(loc=config['bias_init_mean'], scale=config['bias_init_std'],
                                                   size=(config['pop_size'], num_outputs))
    pop_nodes[:, output_idx, 2] = np.random.normal(loc=config['response_init_mean'], scale=config['response_init_std'],
                                                   size=(config['pop_size'], num_outputs))
    pop_nodes[:, output_idx, 3] = np.random.choice(config['activation_options'],
                                                   size=(config['pop_size'], num_outputs))
    pop_nodes[:, output_idx, 4] = np.random.choice(config['aggregation_options'],
                                                   size=(config['pop_size'], num_outputs))

    # full input x output connection grid
    grid_a, grid_b = np.meshgrid(input_idx, output_idx)
    grid_a, grid_b = grid_a.flatten(), grid_b.flatten()

    p = config['num_inputs'] * config['num_outputs']
    pop_cons[:, :p, 0] = grid_a
    pop_cons[:, :p, 1] = grid_b
    pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
                                          size=(config['pop_size'], p))
    pop_cons[:, :p, 3] = 1  # all initial connections enabled

    return pop_nodes, pop_cons
|
||||
|
||||
|
||||
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Grow a single genome's arrays so they can hold more nodes/connections.
    Existing rows are copied to the front; the new tail rows are NaN (empty).

    :param nodes: (N, 5) node array
    :param cons: (C, 4) connection array
    :param new_N: new node capacity (>= N)
    :param new_C: new connection capacity (>= C)
    :return: (new_N, 5), (new_C, 4) expanded copies
    """
    grown_nodes = np.full((new_N, 5), np.nan)
    grown_nodes[:nodes.shape[0], :] = nodes

    grown_cons = np.full((new_C, 4), np.nan)
    grown_cons[:cons.shape[0], :] = cons

    return grown_nodes, grown_cons
|
||||
|
||||
|
||||
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Grow the whole population's arrays so they can hold more nodes and
    connections per genome. Existing rows are copied to the front of each
    genome; the new tail rows are NaN (empty).

    :param pop_nodes: (pop_size, N, 5) population node arrays
    :param pop_cons: (pop_size, C, 4) population connection arrays
    :param new_N: new node capacity (>= N)
    :param new_C: new connection capacity (>= C)
    :return: (pop_size, new_N, 5), (pop_size, new_C, 4) expanded copies
    """
    pop_size = pop_nodes.shape[0]
    old_N, old_C = pop_nodes.shape[1], pop_cons.shape[1]

    grown_nodes = np.full((pop_size, new_N, 5), np.nan)
    grown_nodes[:, :old_N, :] = pop_nodes

    grown_cons = np.full((pop_size, new_C, 4), np.nan)
    grown_cons[:, :old_C, :] = pop_cons

    return grown_nodes, grown_cons
|
||||
|
||||
|
||||
@jit
def count(nodes: NDArray, cons: NDArray) -> Tuple[NDArray, NDArray]:
    """
    Count the real (non-empty) nodes and connections in a genome.
    A row is real iff its key column (index 0) is not NaN.
    """
    valid_nodes = ~jnp.isnan(nodes[:, 0])
    valid_cons = ~jnp.isnan(cons[:, 0])
    return jnp.sum(valid_nodes), jnp.sum(valid_cons)
|
||||
|
||||
|
||||
@jit
def add_node(nodes: NDArray, cons: NDArray, new_key: int,
             bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]:
    """
    Add a new node gene to the genome, writing it into the first empty
    (NaN-keyed) row of the node array.
    """
    free_slot = fetch_first(jnp.isnan(nodes[:, 0]))
    new_row = jnp.array([new_key, bias, response, act, agg])
    return nodes.at[free_slot].set(new_row), cons


@jit
def delete_node(nodes: NDArray, cons: NDArray, node_key: int) -> Tuple[NDArray, NDArray]:
    """
    Delete the node with the given key. Only the node row is cleared;
    connections are left untouched.
    """
    row = fetch_first(nodes[:, 0] == node_key)
    return delete_node_by_idx(nodes, cons, row)


@jit
def delete_node_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]:
    """
    Delete the node at row ``idx`` by overwriting the row with NaN.
    Connections are left untouched.
    """
    return nodes.at[idx].set(np.nan), cons


@jit
def add_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int,
                   weight: float = 1.0, enabled: bool = True) -> Tuple[NDArray, NDArray]:
    """
    Add a new connection gene to the genome, writing it into the first empty
    (NaN-keyed) row of the connection array.
    """
    free_slot = fetch_first(jnp.isnan(cons[:, 0]))
    new_row = jnp.array([i_key, o_key, weight, enabled])
    return nodes, cons.at[free_slot].set(new_row)


@jit
def delete_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int) -> Tuple[NDArray, NDArray]:
    """
    Delete the connection identified by its (input key, output key) pair.
    """
    row = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
    return delete_connection_by_idx(nodes, cons, row)


@jit
def delete_connection_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]:
    """
    Delete the connection at row ``idx`` by overwriting the row with NaN.
    """
    return nodes, cons.at[idx].set(np.nan)
|
||||
169
algorithms/neat/genome/graph.py
Normal file
169
algorithms/neat/genome/graph.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Some graph algorithms implemented in jax.
|
||||
Only used in feed-forward networks.
|
||||
"""
|
||||
|
||||
import jax
|
||||
from jax import jit, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
# from .configs import fetch_first, I_INT
|
||||
from algorithms.neat.genome.utils import fetch_first, I_INT
|
||||
|
||||
|
||||
@jit
def topological_sort(nodes: Array, connections: Array) -> Array:
    """
    A jit-able topological sort (Kahn's algorithm over the enabled-connection
    adjacency matrix). Empty node slots (NaN key) are ignored; the result is
    padded with I_INT after the last sorted index.

    :param nodes: (N, >=1) node array, key in column 0
    :param connections: (2, N, N) connection tensor; layer 1 is the enabled flag
    :return: (N,) array of node indices in topological order, padded with I_INT

    Example:
        nodes = jnp.array([
            [0],
            [1],
            [2],
            [3]
        ])
        connections = jnp.array([
            [
                [0, 0, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]
            ],
            [
                [0, 0, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]
            ]
        ])

        topological_sort(nodes, connections) -> [0, 1, 2, 3]
    """
    connections_enable = connections[1, :, :] == 1  # forward function. thus use enable
    # in-degree per node; NaN for empty slots so they are never selected
    in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
    res = jnp.full(in_degree.shape, I_INT)

    def cond_fun(carry):
        # continue while some node still has in-degree exactly 0
        res_, idx_, in_degree_ = carry
        i = fetch_first(in_degree_ == 0.)
        return i != I_INT

    def body_func(carry):
        res_, idx_, in_degree_ = carry
        i = fetch_first(in_degree_ == 0.)

        # add to res and flag it is already in it (set in-degree to -1)
        res_ = res_.at[idx_].set(i)
        in_degree_ = in_degree_.at[i].set(-1)

        # decrease in_degree of all its children
        children = connections_enable[i, :]
        in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
        return res_, idx_ + 1, in_degree_

    res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
    return res
|
||||
|
||||
|
||||
@jit
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
    """
    Check whether adding a new connection (from_idx -> to_idx) would create a
    cycle, by breadth-first reachability (via repeated boolean matrix-vector
    products) from to_idx back to from_idx.

    :param nodes: JAX array
        The array of nodes (only the row count is used).
    :param connections: JAX array
        The (2, N, N) connection tensor; a non-NaN weight marks an existing edge.
    :param from_idx: int
        The index of the starting node of the candidate connection.
    :param to_idx: int
        The index of the ending node of the candidate connection.
    :return: JAX array
        Scalar bool: True if the new connection would cause a cycle.

    Example:
        nodes = jnp.array([
            [0],
            [1],
            [2],
            [3]
        ])
        connections = jnp.array([
            [
                [0, 0, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]
            ],
            [
                [0, 0, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]
            ]
        ])

        check_cycles(nodes, connections, 3, 2) -> True
        check_cycles(nodes, connections, 2, 3) -> False
        check_cycles(nodes, connections, 0, 3) -> False
        check_cycles(nodes, connections, 1, 0) -> False
    """

    # adjacency including the candidate edge
    connections_enable = ~jnp.isnan(connections[0, :, :])
    connections_enable = connections_enable.at[from_idx, to_idx].set(True)

    # start the reachability frontier at the candidate edge's target
    visited = jnp.full(nodes.shape[0], False)
    new_visited = visited.at[to_idx].set(True)

    def cond_func(carry):
        visited_, new_visited_ = carry
        end_cond1 = jnp.all(visited_ == new_visited_)  # no new nodes been visited
        end_cond2 = new_visited_[from_idx]  # the starting node has been visited
        return jnp.logical_not(end_cond1 | end_cond2)

    def body_func(carry):
        # expand the frontier one hop: visited . adjacency
        _, visited_ = carry
        new_visited_ = jnp.dot(visited_, connections_enable)
        new_visited_ = jnp.logical_or(visited_, new_visited_)
        return visited_, new_visited_

    _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
    # cycle iff from_idx is reachable from to_idx
    return visited[from_idx]
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: 4 real nodes plus one empty (NaN) slot; edges 0->2, 1->2,
    # 1->3, 2->3. Expected: a valid topological order, then cycle checks.
    nodes = jnp.array([
        [0],
        [1],
        [2],
        [3],
        [jnp.nan]
    ])
    connections = jnp.array([
        [
            [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
            [jnp.nan, jnp.nan, 1, 1, jnp.nan],
            [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
            [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
            [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
        ],
        [
            [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
            [jnp.nan, jnp.nan, 1, 1, jnp.nan],
            [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
            [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
            [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
        ]
    ]
    )

    print(topological_sort(nodes, connections))

    # 3->2 closes the loop 2->3->2, the rest do not
    print(check_cycles(nodes, connections, 3, 2))
    print(check_cycles(nodes, connections, 2, 3))
    print(check_cycles(nodes, connections, 0, 3))
    print(check_cycles(nodes, connections, 1, 0))
||||
351
algorithms/neat/genome/mutate.py
Normal file
351
algorithms/neat/genome/mutate.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
Mutate a genome.
|
||||
The calculation method is the same as the mutation operation in NEAT-python.
|
||||
See https://neat-python.readthedocs.io/en/latest/_modules/genome.html#DefaultGenome.mutate
|
||||
"""
|
||||
from typing import Tuple, Dict
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import jit, Array
|
||||
|
||||
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
|
||||
from .graph import check_cycles
|
||||
|
||||
|
||||
@jit
def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, jit_config: Dict):
    """
    Mutate a single genome. Each structural mutation (add/delete node,
    add/delete connection) is applied with its configured probability, then
    value mutations are applied to the surviving genes.

    :param rand_key: jax random key
    :param nodes: (N, 5) node array [key, bias, response, act, agg]
    :param connections: (2, N, N) connection tensor
    :param new_node_key: key assigned to a node created by an add-node mutation
    :param jit_config: dict with mutation probabilities ('node_add_prob', ...)
    :return: mutated (nodes, connections)
    """
    r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)

    # structural mutations
    # Each mutation result is computed unconditionally, then kept or discarded
    # via jnp.where, so the whole function remains jit-able.
    # mutate add node
    r = rand(r1)
    # NOTE(review): r1 is reused both for the probability draw above and as
    # the key for mutate_add_node; likewise the pairings below (r = rand(r2)
    # but mutate_add_connection(r3, ...), r = rand(r3) but
    # mutate_delete_node(r2, ...)) look inconsistent — possible unintended
    # key reuse/swap; confirm against the intended design.
    aux_nodes, aux_connections = mutate_add_node(r1, nodes, connections, new_node_key, jit_config)
    nodes = jnp.where(r < jit_config['node_add_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['node_add_prob'], aux_connections, connections)

    # mutate add connection
    r = rand(r2)
    aux_nodes, aux_connections = mutate_add_connection(r3, nodes, connections, jit_config)
    nodes = jnp.where(r < jit_config['conn_add_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['conn_add_prob'], aux_connections, connections)

    # mutate delete node
    r = rand(r3)
    aux_nodes, aux_connections = mutate_delete_node(r2, nodes, connections, jit_config)
    nodes = jnp.where(r < jit_config['node_delete_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['node_delete_prob'], aux_connections, connections)

    # mutate delete connection
    r = rand(r4)
    aux_nodes, aux_connections = mutate_delete_connection(r4, nodes, connections)
    nodes = jnp.where(r < jit_config['conn_delete_prob'], aux_nodes, nodes)
    connections = jnp.where(r < jit_config['conn_delete_prob'], aux_connections, connections)

    # value mutations
    nodes, connections = mutate_values(rand_key, nodes, connections, jit_config)

    return nodes, connections
|
||||
|
||||
|
||||
def mutate_values(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Mutate the value columns of a genome: node bias / response / activation /
    aggregation, and connection weight / enabled flag.

    Args:
        rand_key: A random key for generating random values.
        nodes: (N, 5) array [key, bias, response, act, agg].
        cons: (C, 4) array [in_key, out_key, weight, enabled].
        jit_config: dict of mutation hyper-parameters.

    Returns:
        The (nodes, cons) pair with mutated values; key columns untouched.
    """
    keys = jax.random.split(rand_key, num=6)

    # float attributes: perturb or re-sample per-entry
    new_bias = mutate_float_values(
        keys[0], nodes[:, 1],
        jit_config['bias_init_mean'], jit_config['bias_init_std'],
        jit_config['bias_mutate_power'], jit_config['bias_mutate_rate'],
        jit_config['bias_replace_rate'],
    )
    new_response = mutate_float_values(
        keys[1], nodes[:, 2],
        jit_config['response_init_mean'], jit_config['response_init_std'],
        jit_config['response_mutate_power'], jit_config['response_mutate_rate'],
        jit_config['response_replace_rate'],
    )
    new_weight = mutate_float_values(
        keys[2], cons[:, 2],
        jit_config['weight_init_mean'], jit_config['weight_init_std'],
        jit_config['weight_mutate_power'], jit_config['weight_mutate_rate'],
        jit_config['weight_replace_rate'],
    )

    # categorical attributes: re-sample from the allowed option lists
    new_act = mutate_int_values(keys[3], nodes[:, 3],
                                jit_config['activation_options'],
                                jit_config['activation_replace_rate'])
    new_agg = mutate_int_values(keys[4], nodes[:, 4],
                                jit_config['aggregation_options'],
                                jit_config['aggregation_replace_rate'])

    # enabled flag: flip 0 <-> 1 with probability enable_mutate_rate
    flip = jax.random.uniform(keys[5], cons[:, 3].shape)
    new_enabled = jnp.where(flip < jit_config['enable_mutate_rate'], 1 - cons[:, 3], cons[:, 3])

    # reassemble the arrays, keeping the key columns as-is
    nodes = jnp.column_stack([nodes[:, 0], new_bias, new_response, new_act, new_agg])
    cons = jnp.column_stack([cons[:, 0], cons[:, 1], new_weight, new_enabled])

    return nodes, cons
|
||||
|
||||
|
||||
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
                        mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
    """
    Perturb or re-sample a vector of float attributes.

    For each entry, with r ~ U[0, 1):
      * r < mutate_rate: add Gaussian noise scaled by ``mutate_strength``;
      * mutate_rate < r < mutate_rate + replace_rate: replace with a fresh
        sample from N(mean, std);
      * otherwise: keep the old value.
    NaN entries (unused gene slots) stay NaN.

    Args:
        rand_key: A random key for generating random values.
        old_vals: 1D array of float values to be mutated.
        mean: mean of the replacement distribution.
        std: standard deviation of the replacement distribution.
        mutate_strength: scale of the additive perturbation.
        mutate_rate: per-entry perturbation probability.
        replace_rate: per-entry replacement probability.

    Returns:
        A mutated 1D array of float values.
    """
    k_noise, k_resample, k_uniform, rand_key = jax.random.split(rand_key, num=4)

    perturbation = jax.random.normal(k_noise, old_vals.shape) * mutate_strength
    resampled = jax.random.normal(k_resample, old_vals.shape) * std + mean
    r = jax.random.uniform(k_uniform, old_vals.shape)

    # r in [0, mutate_rate): perturb
    vals = jnp.where(r < mutate_rate, old_vals + perturbation, old_vals)

    # r in (mutate_rate, mutate_rate + replace_rate): replace.
    # `vals * 0.0` keeps NaN propagating so dead genes stay NaN even here.
    in_replace_band = (mutate_rate < r) & (r < mutate_rate + replace_rate)
    vals = jnp.where(in_replace_band, resampled + vals * 0.0, vals)

    # dead genes (NaN markers) must remain NaN
    return jnp.where(jnp.isnan(old_vals), jnp.nan, vals)
|
||||
|
||||
|
||||
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
    """
    Re-sample categorical attributes (activation / aggregation codes).

    Each entry is replaced, with probability ``replace_rate``, by a value
    drawn uniformly from ``val_list``; NaN entries (unused slots) stay NaN.

    Args:
        rand_key: A random key for generating random values.
        old_vals: 1D array of integer-coded values to be mutated.
        val_list: array of allowed values to sample from.
        replace_rate: per-entry replacement probability.

    Returns:
        A mutated 1D array.
    """
    k_choice, k_uniform, rand_key = jax.random.split(rand_key, num=3)
    candidates = jax.random.choice(k_choice, val_list, old_vals.shape)
    r = jax.random.uniform(k_uniform, old_vals.shape)
    # `old_vals * 0.0` keeps NaN markers NaN even when "replaced"
    return jnp.where(r < replace_rate, candidates + old_vals * 0.0, old_vals)
|
||||
|
||||
|
||||
def mutate_add_node(rand_key: Array, nodes: Array, cons: Array, new_node_key: int,
                    jit_config: Dict) -> Tuple[Array, Array]:
    """
    Randomly add a new node by splitting an existing connection.

    The chosen connection i->o is disabled and replaced by i->new (weight 1)
    and new->o (the old weight), as in standard NEAT.  If the genome has no
    connection at all, it is returned unchanged.

    :param rand_key: JAX PRNG key used to pick the connection to split.
    :param nodes: (N, 5) node array.
    :param cons: (C, 4) connection array [in_key, out_key, weight, enabled].
    :param new_node_key: key assigned to the freshly created node.
    :param jit_config: provides the default activation/aggregation for the new node.
    :return: (nodes, cons) after the attempted mutation.
    """
    # randomly choose a connection; idx == I_INT means none exists
    i_key, o_key, idx = choice_connection_key(rand_key, nodes, cons)

    def nothing():  # there is no connection to split
        return nodes, cons

    def successful_add_node():
        # disable the connection
        new_nodes, new_cons = nodes, cons

        # set enable to false
        new_cons = new_cons.at[idx, 3].set(False)

        # add a new node with neutral bias/response and default act/agg
        new_nodes, new_cons = add_node(new_nodes, new_cons, new_node_key, bias=0, response=1,
                                       act=jit_config['activation_default'], agg=jit_config['aggregation_default'])

        # add the two replacement connections: i->new carries weight 1,
        # new->o carries the original weight, preserving the signal path
        w = new_cons[idx, 2]
        new_nodes, new_cons = add_connection(new_nodes, new_cons, i_key, new_node_key, weight=1, enabled=True)
        new_nodes, new_cons = add_connection(new_nodes, new_cons, new_node_key, o_key, weight=w, enabled=True)
        return new_nodes, new_cons

    # if idx == I_INT, no connection exists -- do nothing
    nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_add_node)

    return nodes, cons
|
||||
|
||||
|
||||
# TODO: Do we really need to delete a node?
def mutate_delete_node(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Randomly delete a hidden node together with all of its connections.

    Input and output nodes are never candidates.  If no deletable node
    exists, the genome is returned unchanged.

    :param rand_key: JAX PRNG key used to pick the node.
    :param nodes: (N, 5) node array.
    :param cons: (C, 4) connection array.
    :param jit_config: provides 'input_idx' / 'output_idx' to exclude.
    :return: (nodes, cons) after the attempted deletion.
    """
    # randomly choose a hidden node; idx == I_INT means none available
    key, idx = choice_node_key(rand_key, nodes, jit_config['input_idx'], jit_config['output_idx'],
                               allow_input_keys=False, allow_output_keys=False)

    def nothing():
        return nodes, cons

    def successful_delete_node():
        # delete the node row itself
        aux_nodes, aux_cons = delete_node_by_idx(nodes, cons, idx)

        # wipe every connection that referenced the deleted node
        # (NaN marks an empty connection slot)
        aux_cons = jnp.where(((aux_cons[:, 0] == key) | (aux_cons[:, 1] == key))[:, None],
                             jnp.nan, aux_cons)

        return aux_nodes, aux_cons

    nodes, cons = jax.lax.cond(idx == I_INT, nothing, successful_delete_node)

    return nodes, cons
|
||||
|
||||
|
||||
def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config: Dict) -> Tuple[Array, Array]:
    """
    Randomly add a new connection between two existing nodes.

    The target node may not be an input node.  Connections that would close
    a cycle are rejected (this path always checks cycles -- presumably the
    feedforward-only setting; TODO confirm for recurrent configs).  If the
    chosen pair is already connected, the existing connection is re-enabled
    instead.

    :param rand_key: JAX PRNG key.
    :param nodes: (N, 5) node array.
    :param cons: (C, 4) connection array.
    :param jit_config: provides 'input_idx' / 'output_idx'.
    :return: (nodes, cons) after the attempted mutation.
    """
    # randomly choose the two endpoints (source may be any node,
    # target may not be an input node)
    k1, k2 = jax.random.split(rand_key, num=2)
    i_key, from_idx = choice_node_key(k1, nodes, jit_config['input_idx'], jit_config['output_idx'],
                                      allow_input_keys=True, allow_output_keys=True)
    o_key, to_idx = choice_node_key(k2, nodes, jit_config['input_idx'], jit_config['output_idx'],
                                    allow_input_keys=False, allow_output_keys=True)

    # does this exact connection already exist?  (I_INT -> not found)
    con_idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))

    def successful():
        new_nodes, new_cons = add_connection(nodes, cons, i_key, o_key, weight=1, enabled=True)
        return new_nodes, new_cons

    def already_exist():
        # re-enable the existing connection instead of duplicating it
        new_cons = cons.at[con_idx, 3].set(True)
        return nodes, new_cons

    def cycle():
        # adding the edge would create a cycle -- reject
        return nodes, cons

    is_already_exist = con_idx != I_INT

    # cycle check works on the dense (2, N, N) representation
    u_cons = unflatten_connections(nodes, cons)
    is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx)

    # branch index: 0 = already exists, 1 = would create cycle, 2 = add it
    choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
    nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
    return nodes, cons
|
||||
|
||||
|
||||
def mutate_delete_connection(rand_key: Array, nodes: Array, cons: Array):
    """
    Randomly delete one existing connection.

    If the genome has no connection at all, it is returned unchanged.

    :param rand_key: JAX PRNG key used to pick the connection.
    :param nodes: (N, 5) node array.
    :param cons: (C, 4) connection array.
    :return: (nodes, cons) after the attempted deletion.
    """
    # pick a random existing connection; idx == I_INT means none exists
    _, _, chosen_idx = choice_connection_key(rand_key, nodes, cons)

    def keep_unchanged():
        return nodes, cons

    def delete_chosen():
        return delete_connection_by_idx(nodes, cons, chosen_idx)

    return jax.lax.cond(chosen_idx == I_INT, keep_unchanged, delete_chosen)
|
||||
|
||||
|
||||
def choice_node_key(rand_key: Array, nodes: Array,
                    input_keys: Array, output_keys: Array,
                    allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
    """
    Pick a uniformly random valid node from ``nodes``.

    NaN-keyed rows (empty slots) are never candidates; input / output nodes
    can additionally be excluded via the two ``allow_*`` flags.

    :param rand_key: JAX PRNG key.
    :param nodes: (N, 5) node array; column 0 holds the node keys.
    :param input_keys: keys of the input nodes.
    :param output_keys: keys of the output nodes.
    :param allow_input_keys: whether input nodes may be chosen.
    :param allow_output_keys: whether output nodes may be chosen.
    :return: (key, idx); key is NaN and idx is I_INT when no candidate exists.
    """
    keys_col = nodes[:, 0]
    candidate = ~jnp.isnan(keys_col)  # NaN rows are empty slots

    if not allow_input_keys:
        candidate = candidate & ~jnp.isin(keys_col, input_keys)
    if not allow_output_keys:
        candidate = candidate & ~jnp.isin(keys_col, output_keys)

    idx = fetch_random(rand_key, candidate)
    chosen_key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
    return chosen_key, idx
|
||||
|
||||
|
||||
def choice_connection_key(rand_key: Array, nodes: Array, cons: Array) -> Tuple[Array, Array, Array]:
    """
    Pick a uniformly random existing connection from ``cons``.

    :param rand_key: JAX PRNG key.
    :param nodes: (N, 5) node array (unused here; kept for a uniform signature).
    :param cons: (C, 4) connection array; a NaN in-key marks an empty slot.
    :return: (i_key, o_key, idx); both keys are NaN and idx is I_INT when the
        genome has no connection.
    """
    existing = ~jnp.isnan(cons[:, 0])
    idx = fetch_random(rand_key, existing)

    found = idx != I_INT
    i_key = jnp.where(found, cons[idx, 0], jnp.nan)
    o_key = jnp.where(found, cons[idx, 1], jnp.nan)
    return i_key, o_key, idx
|
||||
|
||||
|
||||
def rand(rand_key):
    """Draw one uniform sample from [0, 1) as a scalar array."""
    return jax.random.uniform(rand_key, shape=())
|
||||
70
algorithms/neat/genome/utils.py
Normal file
70
algorithms/neat/genome/utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array
|
||||
from jax import jit, vmap
|
||||
|
||||
I_INT = np.iinfo(jnp.int32).max # infinite int
|
||||
EMPTY_NODE = np.full((1, 5), jnp.nan)
|
||||
EMPTY_CON = np.full((1, 4), jnp.nan)
|
||||
|
||||
|
||||
@jit
def unflatten_connections(nodes: Array, cons: Array):
    """
    Convert flat (C, 4) connection rows into a dense (2, N, N) encoding.

    Channel 0 holds weights, channel 1 holds enabled flags; cells with no
    connection stay NaN.

    :param nodes: (N, 5) node array; column 0 holds the node keys.
    :param cons: (C, 4) rows [in_key, out_key, weight, enabled].
    :return: (2, N, N) array.
    """
    n = nodes.shape[0]
    node_keys = nodes[:, 0]

    # map connection endpoint keys to row/column indices in the dense matrix
    key2idx = vmap(key_to_indices, in_axes=(0, None))
    rows = key2idx(cons[:, 0], node_keys)
    cols = key2idx(cons[:, 1], node_keys)

    dense = jnp.full((2, n, n), jnp.nan)
    # empty connection slots map to an out-of-range index (I_INT); JAX's
    # scatter handles out-of-bounds indices without raising, so those rows
    # are harmless here
    dense = dense.at[0, rows, cols].set(cons[:, 2])
    dense = dense.at[1, rows, cols].set(cons[:, 3])

    return dense
|
||||
|
||||
|
||||
def key_to_indices(key, keys):
    """Return the position of ``key`` in ``keys`` (I_INT when absent)."""
    return fetch_first(keys == key)
|
||||
|
||||
|
||||
@jit
def fetch_first(mask, default=I_INT) -> Array:
    """
    Return the index of the first True element of ``mask``.

    :param mask: boolean array.
    :param default: value returned when no element is True.
    :return: index of the first True element, or ``default`` if none.
    """
    # argmax returns 0 on an all-False mask, so verify the hit is real
    first = jnp.argmax(mask)
    return jnp.where(mask[first], first, default)
|
||||
|
||||
|
||||
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
    """
    Return the index of a uniformly random True element of ``mask``
    (``default`` when no element is True).
    """
    n_true = jnp.sum(mask)
    # pick the k-th True element, k ~ U{1..n_true}, located via the running count
    target = jax.random.randint(rand_key, shape=(), minval=1, maxval=n_true + 1)
    hit = jnp.cumsum(mask) >= target
    hit = jnp.where(n_true == 0, False, hit)
    return fetch_first(hit, default)
|
||||
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
    """
    Return the rank of every element of ``array`` (0 = smallest).

    With ``reverse=True`` the ranking is descending (0 = largest).
    """
    scores = -array if reverse else array
    # double argsort converts the sort order into per-element ranks
    return jnp.argsort(jnp.argsort(scores))
|
||||
160
algorithms/neat/jit_species.py
Normal file
160
algorithms/neat/jit_species.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import jit, numpy as jnp, vmap
|
||||
|
||||
from .genome.utils import rank_elements
|
||||
|
||||
|
||||
@jit
def update_species(randkey, fitness, species_info, idx2species, center_nodes, center_cons, generation, jit_config):
    """
    One species-management step: refresh species fitness, remove stagnated
    species, re-rank, allocate offspring, and build crossover pairs.

    args:
        randkey: random key
        fitness: Array[(pop_size,), float], the fitness of each individual
        species_info: Array[(species_size, 3), float], per-species record
            [species_key, best_score, last_update]
        idx2species: Array[(pop_size,), int], map the individual to its species
        center_nodes: Array[(species_size, N, 5), float], the center nodes of each species
        center_cons: Array[(species_size, C, 4), float], the center connections of each species
        generation: int, current generation
        jit_config: Dict, the configuration of jit functions
    returns:
        (species_info, center_nodes, center_cons, winner, loser, elite_mask)
    """
    # update the fitness of each species
    species_fitness = update_species_fitness(species_info, idx2species, fitness)

    # remove stagnated species (their rows become NaN)
    species_fitness, species_info, center_nodes, center_cons = \
        stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config)

    # sort species by fitness, descending, with NaN (removed species) at the
    # end.  argsort treats NaN as the largest value, so simply reversing the
    # ascending order would put removed species FIRST and break the
    # front-loaded ranking in cal_spawn_numbers -- map NaN to -inf instead.
    sort_fitness = jnp.where(jnp.isnan(species_fitness), -jnp.inf, species_fitness)
    sort_indices = jnp.argsort(sort_fitness)[::-1]
    species_info = species_info[sort_indices]
    center_nodes, center_cons = center_nodes[sort_indices], center_cons[sort_indices]

    # decide the number of members of each species by their fitness
    spawn_number = cal_spawn_numbers(species_info, jit_config)

    # crossover info (leftover jax.debug.print calls removed)
    winner, loser, elite_mask = \
        create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config)

    return species_info, center_nodes, center_cons, winner, loser, elite_mask
|
||||
|
||||
|
||||
def update_species_fitness(species_info, idx2species, fitness):
    """
    Compute each species' fitness as the maximum fitness of its members.

    A species with no members (or a NaN key that matches nothing) gets -inf.
    """

    def species_max(row):
        key = species_info[row, 0]
        member_fitness = jnp.where(idx2species == key, fitness, -jnp.inf)
        return jnp.max(member_fitness)

    return vmap(species_max)(jnp.arange(species_info.shape[0]))
|
||||
|
||||
|
||||
def stagnation(species_fitness, species_info, center_nodes, center_cons, generation, jit_config):
    """
    Remove species whose fitness has not improved for too long.

    A species is stagnant when its current fitness does not beat its recorded
    best AND it has not improved within ``max_stagnation`` generations.  The
    ``species_elitism`` best species are never removed.  Removed species have
    their rows set to NaN in every returned array.

    :return: (species_fitness, species_info, center_nodes, center_cons)
    """

    def aux_func(idx):
        s_fitness = species_fitness[idx]
        species_key, best_score, last_update = species_info[idx]
        # stagnation condition: no improvement over the recorded best, and
        # the last improvement is too long ago
        return (s_fitness <= best_score) & (generation - last_update > jit_config['max_stagnation'])

    st = vmap(aux_func)(jnp.arange(species_info.shape[0]))

    # the top `species_elitism` species never stagnate.  rank_elements must be
    # descending (rank 0 = best); the ascending default would protect the
    # WORST species, contradicting the pure-Python SpeciesController, which
    # sorts by fitness descending before applying elitism.
    species_rank = rank_elements(species_fitness, reverse=True)
    st = jnp.where(species_rank < jit_config['species_elitism'], False, st)

    # wipe stagnated species: NaN marks an empty species slot
    species_info = jnp.where(st[:, None], jnp.nan, species_info)
    center_nodes = jnp.where(st[:, None, None], jnp.nan, center_nodes)
    center_cons = jnp.where(st[:, None, None], jnp.nan, center_cons)
    species_fitness = jnp.where(st, jnp.nan, species_fitness)

    return species_fitness, species_info, center_nodes, center_cons
|
||||
|
||||
|
||||
def cal_spawn_numbers(species_info, jit_config):
    """
    Allocate offspring counts to species by linear ranking selection.

    Species are assumed sorted best-first.  With V valid species, the v-th
    best receives weight (V - v) / (V * (V + 1) / 2); e.g. V = 3, pop = 10
    -> rates [0.5, 0.33, 0.17] -> counts [5, 3, 2].  Invalid (NaN-key)
    species receive 0.

    Returns:
        int32 array of per-species offspring counts summing to pop_size.
    """
    valid = ~jnp.isnan(species_info[:, 0])
    valid_count = jnp.sum(valid)
    # normalizer: sum of 1..V (e.g. 3 + 2 + 1 = 6 for V = 3)
    total_weight = (valid_count + 1) * valid_count / 2

    weights = valid_count - jnp.arange(species_info.shape[0])  # V, V-1, ...
    rates = weights / total_weight
    rates = jnp.where(valid, rates, 0)  # invalid species spawn nothing

    spawn = jnp.floor(rates * jit_config['pop_size']).astype(jnp.int32)

    # flooring loses a few slots; hand the remainder to the best species so
    # the total is exactly pop_size
    remainder = jit_config['pop_size'] - jnp.sum(spawn)
    return spawn.at[0].add(remainder)
|
||||
|
||||
|
||||
def create_crossover_pair(randkey, species_info, idx2species, spawn_number, fitness, jit_config):
    """
    Build, for every slot of the next generation, a (winner, loser) parent
    pair and an elite flag.

    Per species: members are sorted by fitness, the top `genome_elitism`
    are copied through unchanged (elite), and the remaining slots draw both
    parents uniformly from the top `survival_threshold` fraction of members.
    The per-species pairs are then flattened into population order using
    the cumulative spawn counts.

    :return: (winner, loser, elite_mask), each Array[(pop_size,)].
    """

    species_size = species_info.shape[0]
    pop_size = fitness.shape[0]
    s_idx = jnp.arange(species_size)
    p_idx = jnp.arange(pop_size)

    def aux_func(key, idx):
        # membership mask of this species over the population
        members = idx2species == species_info[idx, 0]
        members_num = jnp.sum(members)

        # sort member indices by fitness, best first
        # NOTE(review): argsort puts NaN (non-members) last, so the [::-1]
        # reversal brings non-member indices to the FRONT -- verify that the
        # elite rows selected below are actual members.
        members_fitness = jnp.where(members, fitness, jnp.nan)
        sorted_member_indices = jnp.argsort(members_fitness)[::-1]

        elite_size = jit_config['genome_elitism']
        survive_size = jnp.floor(jit_config['survival_threshold'] * members_num).astype(jnp.int32)

        # uniform over the surviving prefix
        # NOTE(review): survive_size == 0 makes this 0/0 (NaN weights) -- confirm
        # callers guarantee at least one survivor per species.
        select_pro = (p_idx < survive_size) / survive_size
        fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)

        # elite slots copy the top members verbatim (father == mother)
        fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
        ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
        elite = jnp.where(p_idx < elite_size, True, False)
        return fa, ma, elite

    fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)

    spawn_number_cum = jnp.cumsum(spawn_number)

    def aux_func(idx):
        # which species does population slot `idx` belong to?
        loc = jnp.argmax(idx < spawn_number_cum)

        # elite genomes sit at the beginning of each species' slot range
        idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
        return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]

    part1, part2, elite_mask = vmap(aux_func)(p_idx)

    # order each pair so the fitter parent is the "winner"
    is_part1_win = fitness[part1] >= fitness[part2]
    winner = jnp.where(is_part1_win, part1, part2)
    loser = jnp.where(is_part1_win, part2, part1)

    return winner, loser, elite_mask
|
||||
166
algorithms/neat/operations.py
Normal file
166
algorithms/neat/operations.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
contains operations on the population: creating the next generation and population speciation.
|
||||
"""
|
||||
import jax
|
||||
from jax import jit, vmap, Array, numpy as jnp
|
||||
|
||||
from .genome import distance, mutate, crossover
|
||||
from .genome.utils import I_INT, fetch_first
|
||||
|
||||
|
||||
@jit
def create_next_generation_then_speciate(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys,
                                         center_nodes, center_cons, species_keys, new_species_key_start,
                                         jit_config):
    """
    One full generation step: breed the next population (crossover +
    mutation), then re-assign every genome to a species.

    :return: (pop_nodes, pop_cons, idx2specie, spe_center_nodes,
        spe_center_cons, species_keys)
    """
    # create next generation
    pop_nodes, pop_cons = create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask,
                                                 new_node_keys, jit_config)

    # speciate the freshly created population
    idx2specie, spe_center_nodes, spe_center_cons, species_keys = \
        speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config)

    return pop_nodes, pop_cons, idx2specie, spe_center_nodes, spe_center_cons, species_keys
|
||||
|
||||
|
||||
@jit
def create_next_generation(rand_key, pop_nodes, pop_cons, winner, loser, elite_mask, new_node_keys, jit_config):
    """
    Breed the next population: crossover each (winner, loser) pair, then
    mutate the offspring.  Elite slots keep the un-mutated crossover result.

    :param winner: Array[(pop_size,), int], fitter parent index per slot.
    :param loser: Array[(pop_size,), int], weaker parent index per slot.
    :param elite_mask: Array[(pop_size,), bool], slots copied without mutation.
    :param new_node_keys: per-slot key for a node possibly added by mutation.
    :return: (pop_nodes, pop_cons) of the next generation.
    """
    # prepare one random key per individual for each stage
    pop_size = pop_nodes.shape[0]
    k1, k2 = jax.random.split(rand_key, 2)
    crossover_rand_keys = jax.random.split(k1, pop_size)
    mutate_rand_keys = jax.random.split(k2, pop_size)

    # batch crossover
    wpn, wpc = pop_nodes[winner], pop_cons[winner]  # winner pop nodes, winner pop connections
    lpn, lpc = pop_nodes[loser], pop_cons[loser]  # loser pop nodes, loser pop connections
    npn, npc = vmap(crossover)(crossover_rand_keys, wpn, wpc, lpn, lpc)  # new pop nodes, new pop connections

    # batch mutation (jit_config is shared, hence in_axes=None)
    mutate_func = vmap(mutate, in_axes=(0, 0, 0, 0, None))
    m_npn, m_npc = mutate_func(mutate_rand_keys, npn, npc, new_node_keys, jit_config)  # mutate_new_pop_nodes

    # elite slots bypass mutation
    pop_nodes = jnp.where(elite_mask[:, None, None], npn, m_npn)
    pop_cons = jnp.where(elite_mask[:, None, None], npc, m_npc)

    return pop_nodes, pop_cons
|
||||
|
||||
|
||||
@jit
def speciate(pop_nodes, pop_cons, center_nodes, center_cons, species_keys, new_species_key_start, jit_config):
    """
    Assign every genome in the population to a species.

    Two phases:
      1. each existing species picks the unassigned genome closest to its
         old center as its new center;
      2. genomes within `compatibility_threshold` of a center join that
         species; genomes left over seed new species (new keys counted up
         from ``new_species_key_start``) until the species slots run out.

    args:
        pop_nodes: (pop_size, N, 5)
        pop_cons: (pop_size, C, 4)
        center_nodes: (species_size, N, 5) species representative nodes
        center_cons: (species_size, C, 4) species representative connections
        species_keys: (species_size,) keys; I_INT marks an empty slot
        new_species_key_start: first key to use for newly created species
    return:
        (idx2specie, center_nodes, center_cons, species_keys)
    """
    pop_size, species_size = pop_nodes.shape[0], center_nodes.shape[0]

    # prepare distance functions
    o2p_distance_func = vmap(distance, in_axes=(None, None, 0, 0, None))  # one genome vs. whole population
    s2p_distance_func = vmap(
        o2p_distance_func, in_axes=(0, 0, None, None, None)  # every center vs. whole population
    )

    # idx -> species key; I_INT means not assigned to any species yet
    idx2specie = jnp.full((pop_size,), I_INT, dtype=jnp.int32)

    # part 1: find new centers
    # the distance between each species' center and each genome in population
    s2p_distance = s2p_distance_func(center_nodes, center_cons, pop_nodes, pop_cons, jit_config)

    def find_new_centers(i, carry):
        i2s, cn, cc = carry
        # closest still-unassigned genome becomes the new center
        idx = argmin_with_mask(s2p_distance[i], mask=i2s == I_INT)

        # if species slot i is empty, force both indices to I_INT so the
        # .at[...].set below becomes an out-of-bounds no-op
        idx = jnp.where(species_keys[i] != I_INT, idx, I_INT)
        i = jnp.where(species_keys[i] != I_INT, i, I_INT)

        i2s = i2s.at[idx].set(species_keys[i])
        cn = cn.at[i].set(pop_nodes[idx])
        cc = cc.at[i].set(pop_cons[idx])
        return i2s, cn, cc

    idx2specie, center_nodes, center_cons = \
        jax.lax.fori_loop(0, species_size, find_new_centers, (idx2specie, center_nodes, center_cons))

    # part 2: assign members to each species
    def cond_func(carry):
        i, i2s, cn, cc, sk, ck = carry  # sk is short for species_keys, ck is short for current key
        # keep looping while genomes remain unassigned and slots remain
        not_all_assigned = ~jnp.all(i2s != I_INT)
        not_reach_species_upper_bounds = i < species_size
        return not_all_assigned & not_reach_species_upper_bounds

    def body_func(carry):
        i, i2s, cn, cc, sk, ck = carry  # scn is short for spe_center_nodes, scc is short for spe_center_cons

        i2s, scn, scc, sk, ck = jax.lax.cond(
            sk[i] == I_INT,  # whether slot i currently holds a species
            create_new_specie,  # empty slot: seed a new species here
            update_exist_specie,  # occupied slot: just absorb close genomes
            (i, i2s, cn, cc, sk, ck)
        )

        return i + 1, i2s, scn, scc, sk, ck

    def create_new_specie(carry):
        i, i2s, cn, cc, sk, ck = carry

        # pick the first genome that has not been assigned to any species
        idx = fetch_first(i2s == I_INT)

        # assign it to the new species
        sk = sk.at[i].set(ck)
        i2s = i2s.at[idx].set(ck)

        # the seeding genome is the new species' center
        cn = cn.at[i].set(pop_nodes[idx])
        cc = cc.at[i].set(pop_cons[idx])

        i2s = speciate_by_threshold((i, i2s, cn, cc, sk))
        return i2s, cn, cc, sk, ck + 1  # advance to the next fresh species key

    def update_exist_specie(carry):
        i, i2s, cn, cc, sk, ck = carry

        i2s = speciate_by_threshold((i, i2s, cn, cc, sk))

        return i2s, cn, cc, sk, ck

    def speciate_by_threshold(carry):
        i, i2s, cn, cc, sk = carry

        # distance between this species' center and every genome
        o2p_distance = o2p_distance_func(cn[i], cc[i], pop_nodes, pop_cons, jit_config)
        close_enough_mask = o2p_distance < jit_config['compatibility_threshold']

        # absorb close genomes, but never reassign an already-assigned genome
        i2s = jnp.where(close_enough_mask & (i2s == I_INT), sk[i], i2s)
        return i2s

    current_new_key = new_species_key_start

    # update idx2specie
    _, idx2specie, center_nodes, center_cons, species_keys, _ = jax.lax.while_loop(
        cond_func,
        body_func,
        (0, idx2specie, center_nodes, center_cons, species_keys, current_new_key)
    )

    # any genome still unassigned falls into the last species -- this only
    # happens when the species count hit its upper bound
    idx2specie = jnp.where(idx2specie == I_INT, species_keys[-1], idx2specie)

    return idx2specie, center_nodes, center_cons, species_keys
|
||||
|
||||
|
||||
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
    """Index of the smallest ``arr`` entry among positions where ``mask`` is True."""
    return jnp.argmin(jnp.where(mask, arr, jnp.inf))
|
||||
283
algorithms/neat/species.py
Normal file
283
algorithms/neat/species.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Species Controller in NEAT.
|
||||
The code are modified from neat-python.
|
||||
See
|
||||
https://neat-python.readthedocs.io/en/latest/_modules/stagnation.html#DefaultStagnation
|
||||
https://neat-python.readthedocs.io/en/latest/module_summaries.html#reproduction
|
||||
https://neat-python.readthedocs.io/en/latest/module_summaries.html#species
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .genome.utils import I_INT
|
||||
|
||||
|
||||
class Species(object):
    """One NEAT species: a representative genome plus its member set."""

    def __init__(self, key, generation):
        self.key = key                    # unique species id
        self.created = generation         # generation this species appeared
        self.last_improved = generation   # last generation the fitness improved
        # (center_nodes, center_connections) of the representative genome
        self.representative: Tuple[NDArray, NDArray] = (None, None)
        # member indices into pop_nodes / pop_connections
        self.members: NDArray = None
        self.fitness = None
        self.member_fitnesses = None
        self.adjusted_fitness = None
        self.fitness_history: List[float] = []

    def update(self, representative, members):
        """Replace the representative genome and the member index array."""
        self.representative = representative
        self.members = members

    def get_fitnesses(self, fitnesses):
        """Slice the members' fitness values out of the population array."""
        return fitnesses[self.members]
|
||||
|
||||
|
||||
class SpeciesController:
|
||||
"""
|
||||
A class to control the species
|
||||
"""
|
||||
|
||||
def __init__(self, config):
    """Cache the speciation hyper-parameters and start with no species."""
    self.config = config

    # reproduction / stagnation hyper-parameters, pulled out for convenience
    self.species_elitism = config['species_elitism']
    self.pop_size = config['pop_size']
    self.max_stagnation = config['max_stagnation']
    self.min_species_size = config['min_species_size']
    self.genome_elitism = config['genome_elitism']
    self.survival_threshold = config['survival_threshold']

    self.species: Dict[int, Species] = {}  # species_id -> Species
|
||||
|
||||
def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray):
    """
    Speciate the very first generation: the whole population goes into a
    single species (key 0) represented by the first genome.

    :param pop_nodes: population node arrays, indexed by genome.
    :param pop_connections: population connection arrays, indexed by genome.
    """
    first_species_key = 0
    first_species = Species(first_species_key, 0)

    # every genome index is a member
    all_members = np.arange(pop_nodes.shape[0])
    first_species.update((pop_nodes[0], pop_connections[0]), all_members)

    self.species[first_species_key] = first_species
|
||||
|
||||
def __update_species_fitnesses(self, fitnesses):
    """
    Refresh each species' fitness record from the population fitness array.

    :param fitnesses: per-genome fitness, indexed like the population.
    """
    for s in self.species.values():
        s.member_fitnesses = s.get_fitnesses(fitnesses)
        # a species is scored by its best member
        s.fitness = np.max(s.member_fitnesses)
        s.fitness_history.append(s.fitness)
        s.adjusted_fitness = None
|
||||
|
||||
def __stagnation(self, generation):
|
||||
"""
|
||||
:param generation:
|
||||
:return: whether the species is stagnated
|
||||
"""
|
||||
species_data = []
|
||||
for sid, s in self.species.items():
|
||||
if s.fitness_history:
|
||||
prev_fitness = max(s.fitness_history)
|
||||
else:
|
||||
prev_fitness = float('-inf')
|
||||
|
||||
if s.fitness > prev_fitness:
|
||||
s.last_improved = generation
|
||||
|
||||
species_data.append((sid, s))
|
||||
|
||||
# Sort in descending fitness order.
|
||||
species_data.sort(key=lambda x: x[1].fitness, reverse=True)
|
||||
|
||||
result = []
|
||||
for idx, (sid, s) in enumerate(species_data):
|
||||
|
||||
if idx < self.species_elitism: # elitism species never stagnate!
|
||||
is_stagnant = False
|
||||
else:
|
||||
stagnant_time = generation - s.last_improved
|
||||
is_stagnant = stagnant_time > self.max_stagnation
|
||||
|
||||
result.append((sid, s, is_stagnant))
|
||||
return result
|
||||
|
||||
def __reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]:
|
||||
"""
|
||||
:param fitnesses:
|
||||
:param generation:
|
||||
:return: crossover_pair for next generation.
|
||||
# int -> idx in the pop_nodes, pop_connections of elitism
|
||||
# (int, int) -> the father and mother idx to be crossover
|
||||
"""
|
||||
# Filter out stagnated species, collect the set of non-stagnated
|
||||
# species members, and compute their average adjusted fitness.
|
||||
# The average adjusted fitness scheme (normalized to the interval
|
||||
# [0, 1]) allows the use of negative fitness values without
|
||||
# interfering with the shared fitness scheme.
|
||||
|
||||
min_fitness = np.inf
|
||||
max_fitness = -np.inf
|
||||
|
||||
remaining_species = []
|
||||
for stag_sid, stag_s, stagnant in self.__stagnation(generation):
|
||||
if not stagnant:
|
||||
min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses))
|
||||
max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses))
|
||||
remaining_species.append(stag_s)
|
||||
|
||||
# No species left.
|
||||
assert remaining_species
|
||||
|
||||
|
||||
# TODO: Too complex!
|
||||
# Compute each species' member size in the next generation.
|
||||
|
||||
# Do not allow the fitness range to be zero, as we divide by it below.
|
||||
# TODO: The ``1.0`` below is rather arbitrary, and should be configurable.
|
||||
fitness_range = max(1.0, max_fitness - min_fitness)
|
||||
for afs in remaining_species:
|
||||
# Compute adjusted fitness.
|
||||
msf = afs.fitness
|
||||
af = (msf - min_fitness) / fitness_range # make adjusted fitness in [0, 1]
|
||||
afs.adjusted_fitness = af
|
||||
adjusted_fitnesses = [s.adjusted_fitness for s in remaining_species]
|
||||
previous_sizes = [len(s.members) for s in remaining_species]
|
||||
min_species_size = max(self.min_species_size, self.genome_elitism)
|
||||
spawn_amounts = compute_spawn(adjusted_fitnesses, previous_sizes, self.pop_size, min_species_size)
|
||||
assert sum(spawn_amounts) == self.pop_size
|
||||
|
||||
# generate new population and speciate
|
||||
self.species = {}
|
||||
# int -> idx in the pop_nodes, pop_connections of elitism
|
||||
# (int, int) -> the father and mother idx to be crossover
|
||||
|
||||
part1, part2, elite_mask = [], [], []
|
||||
|
||||
for spawn, s in zip(spawn_amounts, remaining_species):
|
||||
assert spawn >= self.genome_elitism
|
||||
|
||||
# retain remain species to next generation
|
||||
old_members, member_fitnesses = s.members, s.member_fitnesses
|
||||
s.members = []
|
||||
self.species[s.key] = s
|
||||
|
||||
# add elitism genomes to next generation
|
||||
sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, member_fitnesses)
|
||||
if self.genome_elitism > 0:
|
||||
for m in sorted_members[:self.genome_elitism]:
|
||||
part1.append(m)
|
||||
part2.append(m)
|
||||
elite_mask.append(True)
|
||||
spawn -= 1
|
||||
|
||||
if spawn <= 0:
|
||||
continue
|
||||
|
||||
# add genome to be crossover to next generation
|
||||
repro_cutoff = int(np.ceil(self.survival_threshold * len(sorted_members)))
|
||||
repro_cutoff = max(repro_cutoff, 2)
|
||||
# only use good genomes to crossover
|
||||
sorted_members = sorted_members[:repro_cutoff]
|
||||
|
||||
# TODO: Genome with higher fitness should be more likely to be selected?
|
||||
list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True)
|
||||
part1.extend(sorted_members[list_idx1])
|
||||
part2.extend(sorted_members[list_idx2])
|
||||
elite_mask.extend([False] * spawn)
|
||||
|
||||
part1_fitness, part2_fitness = fitnesses[part1], fitnesses[part2]
|
||||
is_part1_win = part1_fitness >= part2_fitness
|
||||
winner_part = np.where(is_part1_win, part1, part2)
|
||||
loser_part = np.where(is_part1_win, part2, part1)
|
||||
|
||||
return winner_part, loser_part, np.array(elite_mask)
|
||||
|
||||
def tell(self, idx2specie, center_nodes, center_cons, species_keys, generation):
|
||||
for idx, key in enumerate(species_keys):
|
||||
if key == I_INT:
|
||||
continue
|
||||
|
||||
members = np.where(idx2specie == key)[0]
|
||||
assert len(members) > 0
|
||||
|
||||
if key not in self.species:
|
||||
# the new specie created in this generation
|
||||
s = Species(key, generation)
|
||||
self.species[key] = s
|
||||
|
||||
self.species[key].update((center_nodes[idx], center_cons[idx]), members)
|
||||
|
||||
def ask(self, fitnesses, generation, symbols):
|
||||
self.__update_species_fitnesses(fitnesses)
|
||||
|
||||
winner, loser, elite_mask = self.__reproduce(fitnesses, generation)
|
||||
|
||||
center_nodes = np.full((symbols['S'], symbols['N'], 5), np.nan)
|
||||
center_cons = np.full((symbols['S'], symbols['C'], 4), np.nan)
|
||||
species_keys = np.full((symbols['S'], ), I_INT)
|
||||
|
||||
for idx, (key, specie) in enumerate(self.species.items()):
|
||||
center_nodes[idx], center_cons[idx] = specie.representative
|
||||
species_keys[idx] = key
|
||||
|
||||
next_new_specie_key = max(self.species.keys()) + 1
|
||||
|
||||
return winner, loser, elite_mask, center_nodes, center_cons, species_keys, next_new_specie_key
|
||||
|
||||
|
||||
def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size):
    """
    Code from neat-python, the only modification is to fix the population size
    for each generation.
    Compute the proper number of offspring per species (proportional to fitness).
    """
    total_adjusted = sum(adjusted_fitness)

    spawn_amounts = []
    for fitness_share, prev_size in zip(adjusted_fitness, previous_sizes):
        # Target size proportional to the species' share of adjusted fitness,
        # never below the configured minimum.
        if total_adjusted > 0:
            target = max(min_species_size, fitness_share / total_adjusted * pop_size)
        else:
            target = min_species_size

        # Move halfway toward the target; always move at least one member
        # when the rounded step would be zero but the sizes differ.
        delta = 0.5 * (target - prev_size)
        step = int(round(delta))
        if step != 0:
            next_size = prev_size + step
        elif delta > 0:
            next_size = prev_size + 1
        elif delta < 0:
            next_size = prev_size - 1
        else:
            next_size = prev_size

        spawn_amounts.append(next_size)

    # Normalize the spawn amounts so that the next generation is roughly
    # the population size requested by the user.
    norm = pop_size / sum(spawn_amounts)
    spawn_amounts = [max(min_species_size, int(round(amount * norm))) for amount in spawn_amounts]

    # for batch parallelization, pop size must be a fixed value.
    spawn_amounts[0] += pop_size - sum(spawn_amounts)
    assert sum(spawn_amounts) == pop_size, "Population size is not stable."

    return spawn_amounts
|
||||
|
||||
|
||||
def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \
        -> Tuple[NDArray, NDArray]:
    """Return (members, fitnesses) reordered from highest to lowest fitness."""
    ascending = np.argsort(fitnesses)
    descending = ascending[::-1]
    return members[descending], fitnesses[descending]
|
||||
Reference in New Issue
Block a user