initial commit

This commit is contained in:
wls2002
2023-05-05 14:19:13 +08:00
commit 6faa07f507
43 changed files with 2517 additions and 0 deletions

0
algorithms/__init__.py Normal file
View File

Binary file not shown.

View File

View File

@@ -0,0 +1 @@
from .pipeline import Pipeline

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,4 @@
from .genome import create_initialize_function
from .distance import distance
from .mutate import create_mutate_function
from .forward import create_forward_function

View File

@@ -0,0 +1,138 @@
import jax
import jax.numpy as jnp
from jax import jit
@jit
def sigmoid_act(z):
z = jnp.clip(z * 5, -60, 60)
return 1 / (1 + jnp.exp(-z))
@jit
def tanh_act(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z)
@jit
def sin_act(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z)
@jit
def gauss_act(z):
z = jnp.clip(z, -3.4, 3.4)
return jnp.exp(-5 * z ** 2)
@jit
def relu_act(z):
return jnp.maximum(z, 0)
@jit
def elu_act(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@jit
def lelu_act(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@jit
def selu_act(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@jit
def softplus_act(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@jit
def identity_act(z):
return z
@jit
def clamped_act(z):
return jnp.clip(z, -1, 1)
@jit
def inv_act(z):
return 1 / z
@jit
def log_act(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@jit
def exp_act(z):
z = jnp.clip(z, -60, 60)
return jnp.exp(z)
@jit
def abs_act(z):
return jnp.abs(z)
@jit
def hat_act(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@jit
def square_act(z):
return z ** 2
@jit
def cube_act(z):
return z ** 3
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
act_name2key = {
'sigmoid': 0,
'tanh': 1,
'sin': 2,
'gauss': 3,
'relu': 4,
'elu': 5,
'lelu': 6,
'selu': 7,
'softplus': 8,
'identity': 9,
'clamped': 10,
'inv': 11,
'log': 12,
'exp': 13,
'abs': 14,
'hat': 15,
'square': 16,
'cube': 17,
}
@jit
def act(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
vectorized_act = jax.vmap(act, in_axes=(0, 0))

View File

@@ -0,0 +1,109 @@
"""
aggregations, two special case need to consider:
1. extra 0s
2. full of 0s
"""
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
@jit
def sum_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)
@jit
def product_agg(z):
z = jnp.where(jnp.isnan(z), 1, z)
return jnp.prod(z, axis=0)
@jit
def max_agg(z):
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
return jnp.max(z, axis=0)
@jit
def min_agg(z):
z = jnp.where(jnp.isnan(z), jnp.inf, z)
return jnp.min(z, axis=0)
@jit
def maxabs_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
abs_z = jnp.abs(z)
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
@jit
def median_agg(z):
non_zero_mask = ~jnp.isnan(z)
n = jnp.sum(non_zero_mask, axis=0)
z = jnp.where(jnp.isnan(z), jnp.inf, z)
sorted_valid_values = jnp.sort(z)
def _even_case():
return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2
def _odd_case():
return sorted_valid_values[n // 2]
median = jax.lax.cond(n % 2 == 0, _even_case, _odd_case)
return median
@jit
def mean_agg(z):
non_zero_mask = ~jnp.isnan(z)
valid_values_sum = sum_agg(z)
valid_values_count = jnp.sum(non_zero_mask, axis=0)
mean_without_zeros = valid_values_sum / valid_values_count
return mean_without_zeros
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
agg_name2key = {
'sum': 0,
'product': 1,
'max': 2,
'min': 3,
'maxabs': 4,
'median': 5,
'mean': 6,
}
@jit
def agg(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
def full_zero():
return 0.
def not_full_zero():
return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
return jax.lax.cond(jnp.all(z == 0.), full_zero, not_full_zero)
vectorized_agg = jax.vmap(agg, in_axes=(0, 0))
if __name__ == '__main__':
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array))
array2 = jnp.asarray([0, 0, 0, 0], dtype=jnp.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array2))

View File

@@ -0,0 +1,151 @@
from functools import partial
from typing import Tuple
import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
# from .utils import flatten_connections, unflatten_connections
from algorithms.neat.genome.utils import flatten_connections, unflatten_connections
@vmap
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
batch_connections2: Array) -> Tuple[Array, Array]:
"""
crossover a batch of genomes
:param randkeys: batches of random keys
:param batch_nodes1:
:param batch_connections1:
:param batch_nodes2:
:param batch_connections2:
:return:
"""
return crossover(randkeys, batch_nodes1, batch_connections1, batch_nodes2, batch_connections2)
@jit
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
-> Tuple[Array, Array]:
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
:param randkey:
:param nodes1:
:param connections1:
:param nodes2:
:param connections2:
:return:
"""
randkey_1, randkey_2 = jax.random.split(randkey)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
nodes2 = align_array(keys1, keys2, nodes2, 'node')
new_nodes = jnp.where(jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
new_cons = jnp.where(jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
new_cons = unflatten_connections(len(keys1), new_cons)
return new_nodes, new_cons
@partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
"""
make ar2 align with ar1.
:param seq1:
:param seq2:
:param ar2:
:param gene_type:
:return:
align means to intersect part of ar2 will be at the same position as ar1,
non-intersect part of ar2 will be set to Nan
"""
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
if gene_type == 'connection':
mask = jnp.all(mask, axis=2)
intersect_mask = mask.any(axis=1)
idx = jnp.arange(0, len(seq1))
idx_fixed = jnp.dot(mask, idx)
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
return refactor_ar2
@jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)
if __name__ == '__main__':
import numpy as np
randkey = jax.random.PRNGKey(40)
nodes1 = np.array([
[4, 1, 1, 1, 1],
[6, 2, 2, 2, 2],
[1, 3, 3, 3, 3],
[5, 4, 4, 4, 4],
[np.nan, np.nan, np.nan, np.nan, np.nan]
])
nodes2 = np.array([
[4, 1.5, 1.5, 1.5, 1.5],
[7, 3.5, 3.5, 3.5, 3.5],
[5, 4.5, 4.5, 4.5, 4.5],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
])
weights1 = np.array([
[
[1, 2, 3, 4., np.nan],
[5, np.nan, 7, 8, np.nan],
[9, 10, 11, 12, np.nan],
[13, 14, 15, 16, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
],
[
[0, 1, 0, 1, np.nan],
[0, np.nan, 0, 1, np.nan],
[0, 1, 0, 1, np.nan],
[0, 1, 0, 1, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
]
])
weights2 = np.array([
[
[1.5, 2.5, 3.5, np.nan, np.nan],
[3.5, 4.5, 5.5, np.nan, np.nan],
[6.5, 7.5, 8.5, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
],
[
[1, 0, 1, np.nan, np.nan],
[1, 0, 1, np.nan, np.nan],
[1, 0, 1, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
]
])
res = crossover(randkey, nodes1, weights1, nodes2, weights2)
print(*res, sep='\n')

View File

@@ -0,0 +1,71 @@
from functools import partial
from jax import jit, vmap, Array
from jax import numpy as jnp
from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis
@jit
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array:
"""
Calculate the distance between two genomes.
nodes are a 2-d array with shape (N, 5), its columns are [key, bias, response, act, agg]
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
"""
node_distance = gene_distance(nodes1, nodes2, 'node')
# refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
connection_distance = gene_distance(cons1, cons2, 'connection')
return node_distance + connection_distance
@partial(jit, static_argnames=["gene_type"])
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
if gene_type == 'node':
keys1, keys2 = ar1[:, :1], ar2[:, :1]
else: # connection
keys1, keys2 = ar1[:, :2], ar2[:, :2]
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
nodes = jnp.concatenate((ar1, ar2), axis=0)
sorted_nodes = nodes[n_sorted_indices]
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = jnp.where(jnp.isnan(node_distance), 0, node_distance)
homologous_distance = jnp.sum(node_distance * n_intersect_mask[:-1])
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1))
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1))
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return val / jnp.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2)
@partial(vmap, in_axes=(0, 0))
def homologous_node_distance(n1, n2):
d = 0
d += jnp.abs(n1[1] - n2[1]) # bias
d += jnp.abs(n1[2] - n2[2]) # response
d += n1[3] != n2[3] # activation
d += n1[4] != n2[4]
return d
@partial(vmap, in_axes=(0, 0))
def homologous_connection_distance(c1, c2):
d = 0
d += jnp.abs(c1[2] - c2[2]) # weight
d += c1[3] != c2[3] # enable
return d

View File

@@ -0,0 +1,171 @@
from functools import partial
import jax
from jax import Array, numpy as jnp
from jax import jit, vmap
from numpy.typing import NDArray
from .aggregations import agg
from .activations import act
from .graph import topological_sort, batch_topological_sort, topological_sort_debug
from .utils import I_INT
def create_forward_function(nodes: NDArray, connections: NDArray,
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False):
"""
create forward function for different situations
:param nodes: shape (N, 5) or (pop_size, N, 5)
:param connections: shape (2, N, N) or (pop_size, 2, N, N)
:param N:
:param input_idx:
:param output_idx:
:param batch: using batch or not
:param debug: debug mode
:return:
"""
if debug:
cal_seqs = topological_sort(nodes, connections)
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
if nodes.ndim == 2: # single genome
cal_seqs = topological_sort(nodes, connections)
if not batch:
return lambda inputs: forward_single(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
else:
return lambda batch_inputs: forward_batch(batch_inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
elif nodes.ndim == 3: # pop genome
pop_cal_seqs = batch_topological_sort(nodes, connections)
if not batch:
return lambda inputs: pop_forward_single(inputs, N, input_idx, output_idx,
pop_cal_seqs, nodes, connections)
else:
return lambda batch_inputs: pop_forward_batch(batch_inputs, N, input_idx, output_idx,
pop_cal_seqs, nodes, connections)
else:
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
# @partial(jit, static_argnames=['N', 'input_idx', 'output_idx'])
@partial(jit, static_argnames=['N'])
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are single genome
:argument inputs: (input_num, )
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (output_num, )
"""
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
def scan_body(carry, i):
def hit():
ins = carry * connections[0, :, i]
z = agg(nodes[i, 4], ins)
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
new_vals = jnp.where(jnp.isnan(z), carry, carry.at[i].set(z))
return new_vals
def miss():
return carry
return jax.lax.cond(i == I_INT, miss, hit), None
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
return vals[output_idx]
def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections):
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
vals = ini_vals
for i in cal_seqs:
if i == I_INT:
break
ins = vals * connections[0, :, i]
z = agg(nodes[i, 4], ins)
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
vals = jnp.where(jnp.isnan(z), vals, vals.at[i].set(z))
return vals[output_idx]
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
"""
jax forward for batch_inputs shaped (batch_size, input_num)
nodes, connections are single genome
:argument batch_inputs: (batch_size, input_num)
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (batch_size, output_num)
"""
return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
pop_nodes, pop_connections are population of genomes
:argument inputs: (input_num, )
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument pop_cal_seqs: (pop_size, N)
:argument pop_nodes: (pop_size, N, 5)
:argument pop_connections: (pop_size, 2, N, N)
:return (pop_size, output_num)
"""
return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
"""
jax forward for batch input shaped (batch, input_num)
pop_nodes, pop_connections are population of genomes
:argument batch_inputs: (batch_size, input_num)
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument pop_cal_seqs: (pop_size, N)
:argument pop_nodes: (pop_size, N, 5)
:argument pop_connections: (pop_size, 2, N, N)
:return (pop_size, batch_size, output_num)
"""
return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)

View File

@@ -0,0 +1,195 @@
"""
Vectorization of genome representation.
Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where:
1. N is a pre-set value that determines the maximum number of nodes in the network, and will increase if the genome becomes
too large to be represented by the current value of N.
2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function
(act), and aggregation function (agg).
3. connections is an array of shape (2, N, N), dtype=float, with the first axis representing weight and connection enabled
status.
Empty nodes or connections are represented using np.nan.
"""
from typing import Tuple
from functools import partial
import numpy as np
from numpy.typing import NDArray
from jax import numpy as jnp
from jax import jit
from jax import Array
from algorithms.neat.genome.utils import fetch_first, fetch_last
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
def create_initialize_function(config):
pop_size = config.neat.population.pop_size
N = config.basic.init_maximum_nodes
num_inputs = config.basic.num_inputs
num_outputs = config.basic.num_outputs
default_bias = config.neat.gene.bias.init_mean
default_response = config.neat.gene.response.init_mean
# default_act = config.neat.gene.activation.default
# default_agg = config.neat.gene.aggregation.default
default_act = 0
default_agg = 0
default_weight = config.neat.gene.weight.init_mean
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
default_act, default_agg, default_weight)
def initialize_genomes(pop_size: int,
N: int,
num_inputs: int, num_outputs: int,
default_bias: float = 0.0,
default_response: float = 1.0,
default_act: int = 0,
default_agg: int = 0,
default_weight: float = 1.0) \
-> Tuple[NDArray, NDArray, NDArray, NDArray]:
"""
Initialize genomes with default values.
Args:
pop_size (int): Number of genomes to initialize.
N (int): Maximum number of nodes in the network.
num_inputs (int): Number of input nodes.
num_outputs (int): Number of output nodes.
default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0.
default_response (float, optional): Default response value for output nodes. Defaults to 1.0.
default_act (int, optional): Default activation function index for output nodes. Defaults to 1.
default_agg (int, optional): Default aggregation function index for output nodes. Defaults to 0.
default_weight (float, optional): Default weight value for connections. Defaults to 0.0.
Raises:
AssertionError: If the sum of num_inputs, num_outputs, and 1 is greater than N.
Returns:
Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays.
"""
# Reserve one row for potential mutation adding an extra node
assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \
f"{num_inputs} and output_size: {num_outputs}!"
pop_nodes = np.full((pop_size, N, 5), np.nan)
pop_connections = np.full((pop_size, 2, N, N), np.nan)
input_idx = np.arange(num_inputs)
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
pop_nodes[:, input_idx, 0] = input_idx
pop_nodes[:, output_idx, 0] = output_idx
pop_nodes[:, output_idx, 1] = default_bias
pop_nodes[:, output_idx, 2] = default_response
pop_nodes[:, output_idx, 3] = default_act
pop_nodes[:, output_idx, 4] = default_agg
for i in input_idx:
for j in output_idx:
pop_connections[:, 0, i, j] = default_weight
pop_connections[:, 1, i, j] = 1
return pop_nodes, pop_connections, input_idx, output_idx
def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]:
"""
Expand the genome to accommodate more nodes.
:param pop_nodes:
:param pop_connections:
:param new_N:
:return:
"""
pop_size, old_N = pop_nodes.shape[0], pop_nodes.shape[1]
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
new_pop_nodes[:, :old_N, :] = pop_nodes
new_pop_connections = np.full((pop_size, 2, new_N, new_N), np.nan)
new_pop_connections[:, :, :old_N, :old_N] = pop_connections
return new_pop_nodes, new_pop_connections
@jit
def add_node(new_node_key: int, nodes: Array, connections: Array,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
"""
add a new node to the genome.
"""
exist_keys = nodes[:, 0]
idx = fetch_first(jnp.isnan(exist_keys))
nodes = nodes.at[idx].set(jnp.array([new_node_key, bias, response, act, agg]))
return nodes, connections
@jit
def delete_node(node_key: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
"""
delete a node from the genome. only delete the node, regardless of connections.
"""
node_keys = nodes[:, 0]
idx = fetch_first(node_keys == node_key)
return delete_node_by_idx(idx, nodes, connections)
@jit
def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
"""
delete a node from the genome. only delete the node, regardless of connections.
"""
node_keys = nodes[:, 0]
# move the last node to the deleted node's position
last_real_idx = fetch_last(~jnp.isnan(node_keys))
nodes = nodes.at[idx].set(nodes[last_real_idx])
nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
return nodes, connections
@jit
def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array,
weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
"""
add a new connection to the genome.
"""
node_keys = nodes[:, 0]
from_idx = fetch_first(node_keys == from_node)
to_idx = fetch_first(node_keys == to_node)
return add_connection_by_idx(from_idx, to_idx, nodes, connections, weight, enabled)
@jit
def add_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array,
weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
"""
add a new connection to the genome.
"""
connections = connections.at[:, from_idx, to_idx].set(jnp.array([weight, enabled]))
return nodes, connections
@jit
def delete_connection(from_node: int, to_node: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
"""
delete a connection from the genome.
"""
node_keys = nodes[:, 0]
from_idx = fetch_first(node_keys == from_node)
to_idx = fetch_first(node_keys == to_node)
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
@jit
def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
"""
delete a connection from the genome.
"""
connections = connections.at[:, from_idx, to_idx].set(np.nan)
return nodes, connections
# if __name__ == '__main__':
# pop_nodes, pop_connections, input_keys, output_keys = initialize_genomes(100, 5, 2, 1)
# print(pop_nodes, pop_connections)

View File

@@ -0,0 +1,198 @@
"""
Some graph algorithms implemented in jax.
Only used in feed-forward networks.
"""
import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
# from .utils import fetch_first, I_INT
from algorithms.neat.genome.utils import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, connections: Array) -> Array:
"""
a jit-able version of topological_sort! that's crazy!
:param nodes: nodes array
:param connections: connections array
:return: topological sorted sequence
Example:
nodes = jnp.array([
[0],
[1],
[2],
[3]
])
connections = jnp.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
topological_sort(nodes, connections) -> [0, 1, 2, 3]
"""
connections_enable = connections[1, :, :] == 1
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
res = jnp.full(in_degree.shape, I_INT)
idx = 0
def scan_body(carry, _):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
def hit():
# add to res and flag it is already in it
new_res = res_.at[idx_].set(i)
new_idx = idx_ + 1
new_in_degree = in_degree_.at[i].set(-1)
# decrease in_degree of all its children
children = connections_enable[i, :]
new_in_degree = jnp.where(children, new_in_degree - 1, new_in_degree)
return new_res, new_idx, new_in_degree
def miss():
return res_, idx_, in_degree_
return jax.lax.cond(i == I_INT, miss, hit), None
scan_res, _ = jax.lax.scan(scan_body, (res, idx, in_degree), None, length=in_degree.shape[0])
res, _, _ = scan_res
return res
# @jit
def topological_sort_debug(nodes: Array, connections: Array) -> Array:
connections_enable = connections[1, :, :] == 1
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
res = jnp.full(in_degree.shape, I_INT)
idx = 0
for _ in range(in_degree.shape[0]):
i = fetch_first(in_degree == 0.)
if i == I_INT:
break
res = res.at[idx].set(i)
idx += 1
in_degree = in_degree.at[i].set(-1)
children = connections_enable[i, :]
in_degree = jnp.where(children, in_degree - 1, in_degree)
return res
@vmap
def batch_topological_sort(nodes: Array, connections: Array) -> Array:
"""
batch version of topological_sort
:param nodes:
:param connections:
:return:
"""
return topological_sort(nodes, connections)
@jit
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
"""
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
:param nodes: JAX array
The array of nodes.
:param connections: JAX array
The array of connections.
:param from_idx: int
The index of the starting node.
:param to_idx: int
The index of the ending node.
:return: JAX array
An array indicating if there is a cycle caused by the new connection.
Example:
nodes = jnp.array([
[0],
[1],
[2],
[3]
])
connections = jnp.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
check_cycles(nodes, connections, 3, 2) -> True
check_cycles(nodes, connections, 2, 3) -> False
check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False
"""
connections_enable = connections[1, :, :] == 1
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
nodes_visited = jnp.full(nodes.shape[0], False)
nodes_visited = nodes_visited.at[to_idx].set(True)
def scan_body(visited, _):
new_visited = jnp.dot(visited, connections_enable)
new_visited = jnp.logical_or(visited, new_visited)
return new_visited, None
nodes_visited, _ = jax.lax.scan(scan_body, nodes_visited, None, length=nodes_visited.shape[0])
return nodes_visited[from_idx]
if __name__ == '__main__':
nodes = jnp.array([
[0],
[1],
[2],
[3],
[jnp.nan]
])
connections = jnp.array([
[
[0, 0, 1, 0, jnp.nan],
[0, 0, 1, 1, jnp.nan],
[0, 0, 0, 1, jnp.nan],
[0, 0, 0, 0, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
],
[
[0, 0, 1, 0, jnp.nan],
[0, 0, 1, 1, jnp.nan],
[0, 0, 0, 1, jnp.nan],
[0, 0, 0, 0, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
]
]
)
print(topological_sort_debug(nodes, connections))
print(topological_sort(nodes, connections))
print(check_cycles(nodes, connections, 3, 2))
print(check_cycles(nodes, connections, 2, 3))
print(check_cycles(nodes, connections, 0, 3))
print(check_cycles(nodes, connections, 1, 0))

View File

@@ -0,0 +1,538 @@
from typing import Tuple
from functools import partial
import jax
from jax import numpy as jnp
from jax import jit, vmap, Array
from .utils import fetch_random, fetch_first, I_INT
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx
from .graph import check_cycles
def create_mutate_function(config, input_keys, output_keys, batch: bool):
"""
create mutate function for different situations
:param output_keys:
:param input_keys:
:param config:
:param batch: mutate for population or not
:return:
"""
bias = config.neat.gene.bias
bias_default = bias.init_mean
bias_mean = bias.init_mean
bias_std = bias.init_stdev
bias_mutate_strength = bias.mutate_power
bias_mutate_rate = bias.mutate_rate
bias_replace_rate = bias.replace_rate
response = config.neat.gene.response
response_default = response.init_mean
response_mean = response.init_mean
response_std = response.init_stdev
response_mutate_strength = response.mutate_power
response_mutate_rate = response.mutate_rate
response_replace_rate = response.replace_rate
weight = config.neat.gene.weight
weight_mean = weight.init_mean
weight_std = weight.init_stdev
weight_mutate_strength = weight.mutate_power
weight_mutate_rate = weight.mutate_rate
weight_replace_rate = weight.replace_rate
activation = config.neat.gene.activation
# act_default = activation.default
act_default = 0
act_range = len(activation.options)
act_replace_rate = activation.mutate_rate
aggregation = config.neat.gene.aggregation
# agg_default = aggregation.default
agg_default = 0
agg_range = len(aggregation.options)
agg_replace_rate = aggregation.mutate_rate
enabled = config.neat.gene.enabled
enabled_reverse_rate = enabled.mutate_rate
genome = config.neat.genome
add_node_rate = genome.node_add_prob
delete_node_rate = genome.node_delete_prob
add_connection_rate = genome.conn_add_prob
delete_connection_rate = genome.conn_delete_prob
single_structure_mutate = genome.single_structural_mutation
if not batch:
return lambda rand_key, nodes, connections, new_node_key: \
mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_range, act_replace_rate,
agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
else:
batched_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, *(None,) * 31))
return lambda rand_keys, pop_nodes, pop_connections, new_node_keys: \
batched_mutate(rand_keys, pop_nodes, pop_connections, new_node_keys, input_keys, output_keys,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_range, act_replace_rate,
agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
@partial(jit, static_argnames=["single_structure_mutate"])
def mutate(rand_key: Array,
nodes: Array,
connections: Array,
new_node_key: int,
input_keys: Array,
output_keys: Array,
bias_default: float = 0,
bias_mean: float = 0,
bias_std: float = 1,
bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_default: float = 1,
response_mean: float = 1.,
response_std: float = 0.,
response_mutate_strength: float = 0.,
response_mutate_rate: float = 0.,
response_replace_rate: float = 0.,
weight_mean: float = 0.,
weight_std: float = 1.,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_default: int = 0,
act_range: int = 5,
act_replace_rate: float = 0.1,
agg_default: int = 0,
agg_range: int = 5,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1,
add_node_rate: float = 0.2,
delete_node_rate: float = 0.2,
add_connection_rate: float = 0.4,
delete_connection_rate: float = 0.4,
single_structure_mutate: bool = True):
"""
:param output_keys:
:param input_keys:
:param agg_default:
:param act_default:
:param response_default:
:param bias_default:
:param rand_key:
:param nodes: (N, 5)
:param connections: (2, N, N)
:param new_node_key:
:param bias_mean:
:param bias_std:
:param bias_mutate_strength:
:param bias_mutate_rate:
:param bias_replace_rate:
:param response_mean:
:param response_std:
:param response_mutate_strength:
:param response_mutate_rate:
:param response_replace_rate:
:param weight_mean:
:param weight_std:
:param weight_mutate_strength:
:param weight_mutate_rate:
:param weight_replace_rate:
:param act_range:
:param act_replace_rate:
:param agg_range:
:param agg_replace_rate:
:param enabled_reverse_rate:
:param add_node_rate:
:param delete_node_rate:
:param add_connection_rate:
:param delete_connection_rate:
:param single_structure_mutate: a genome is structurally mutate at most once
:return:
"""
# mutate_structure
def nothing(rk, n, c):
return n, c
def m_add_node(rk, n, c):
return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default)
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_keys, output_keys)
def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_keys, output_keys)
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
mutate_structure_li = [nothing, m_add_node, m_delete_node, m_add_connection, m_delete_connection]
if single_structure_mutate:
r1, r2, rand_key = jax.random.split(rand_key, 3)
d = jnp.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate)
# shorten variable names for beauty
anr, dnr = add_node_rate / d, delete_node_rate / d
acr, dcr = add_connection_rate / d, delete_connection_rate / d
r = rand(r1)
branch = 0
branch = jnp.where(r <= anr, 1, branch)
branch = jnp.where((anr < r) & (r <= anr + dnr), 2, branch)
branch = jnp.where((anr + dnr < r) & (r <= anr + dnr + acr), 3, branch)
branch = jnp.where((anr + dnr + acr) < r & r <= (anr + dnr + acr + dcr), 4, branch)
nodes, connections = jax.lax.switch(branch, mutate_structure_li, (r2, nodes, connections))
else:
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# mutate add node
aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)
# mutate delete node
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
# mutate add connection
aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
# mutate delete connection
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength,
weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range,
agg_replace_rate, enabled_reverse_rate)
return nodes, connections
@jit
def mutate_values(rand_key: Array,
nodes: Array,
connections: Array,
bias_mean: float = 0,
bias_std: float = 1,
bias_mutate_strength: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_mean: float = 1.,
response_std: float = 0.,
response_mutate_strength: float = 0.,
response_mutate_rate: float = 0.,
response_replace_rate: float = 0.,
weight_mean: float = 0.,
weight_std: float = 1.,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_range: int = 5,
act_replace_rate: float = 0.1,
agg_range: int = 5,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]:
"""
Mutate values of nodes and connections.
Args:
rand_key: A random key for generating random values.
nodes: A 2D array representing nodes.
connections: A 3D array representing connections.
bias_mean: Mean of the bias values.
bias_std: Standard deviation of the bias values.
bias_mutate_strength: Strength of the bias mutation.
bias_mutate_rate: Rate of the bias mutation.
bias_replace_rate: Rate of the bias replacement.
response_mean: Mean of the response values.
response_std: Standard deviation of the response values.
response_mutate_strength: Strength of the response mutation.
response_mutate_rate: Rate of the response mutation.
response_replace_rate: Rate of the response replacement.
weight_mean: Mean of the weight values.
weight_std: Standard deviation of the weight values.
weight_mutate_strength: Strength of the weight mutation.
weight_mutate_rate: Rate of the weight mutation.
weight_replace_rate: Rate of the weight replacement.
act_range: Range of the activation function values.
act_replace_rate: Rate of the activation function replacement.
agg_range: Range of the aggregation function values.
agg_replace_rate: Rate of the aggregation function replacement.
enabled_reverse_rate: Rate of reversing enabled state of connections.
Returns:
A tuple containing mutated nodes and connections.
"""
k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6)
bias_new = mutate_float_values(k1, nodes[:, 1], bias_mean, bias_std,
bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
response_new = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate)
weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std,
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
act_new = mutate_int_values(k4, nodes[:, 3], act_range, act_replace_rate)
agg_new = mutate_int_values(k5, nodes[:, 4], agg_range, agg_replace_rate)
# refactor enabled
r = jax.random.uniform(rand_key, connections[1, :, :].shape)
enabled_new = connections[1, :, :] == 1
enabled_new = jnp.where(r < enabled_reverse_rate, ~enabled_new, enabled_new)
enabled_new = jnp.where(~jnp.isnan(connections[0, :, :]), enabled_new, jnp.nan)
nodes = nodes.at[:, 1].set(bias_new)
nodes = nodes.at[:, 2].set(response_new)
nodes = nodes.at[:, 3].set(act_new)
nodes = nodes.at[:, 4].set(agg_new)
connections = connections.at[0, :, :].set(weight_new)
connections = connections.at[1, :, :].set(enabled_new)
return nodes, connections
@jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
"""
Mutate float values of a given array.
Args:
rand_key: A random key for generating random values.
old_vals: A 1D array of float values to be mutated.
mean: Mean of the values.
std: Standard deviation of the values.
mutate_strength: Strength of the mutation.
mutate_rate: Rate of the mutation.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of float values.
"""
k1, k2, k3, rand_key = jax.random.split(rand_key, num=4)
noise = jax.random.normal(k1, old_vals.shape) * mutate_strength
replace = jax.random.normal(k2, old_vals.shape) * std + mean
r = jax.random.uniform(k3, old_vals.shape)
new_vals = old_vals
new_vals = jnp.where(r < mutate_rate, new_vals + noise, new_vals)
new_vals = jnp.where(
jnp.logical_and(mutate_rate < r, r < mutate_rate + replace_rate),
replace,
new_vals
)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
return new_vals
@jit
def mutate_int_values(rand_key: Array, old_vals: Array, range: int, replace_rate: float) -> Array:
"""
Mutate integer values (act, agg) of a given array.
Args:
rand_key: A random key for generating random values.
old_vals: A 1D array of integer values to be mutated.
range: Range of the integer values.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of integer values.
"""
k1, k2, rand_key = jax.random.split(rand_key, num=3)
replace_val = jax.random.randint(k1, old_vals.shape, 0, range)
r = jax.random.uniform(k2, old_vals.shape)
new_vals = old_vals
new_vals = jnp.where(r < replace_rate, replace_val, new_vals)
new_vals = jnp.where(~jnp.isnan(old_vals), new_vals, jnp.nan)
return new_vals
@jit
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
"""
Randomly add a new node from splitting a connection.
:param rand_key:
:param new_node_key:
:param nodes:
:param connections:
:param default_bias:
:param default_response:
:param default_act:
:param default_agg:
:return:
"""
# randomly choose a connection
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
# disable the connection
connections = connections.at[1, from_idx, to_idx].set(False)
# add a new node
nodes, connections = add_node(new_node_key, nodes, connections,
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
new_idx = fetch_first(nodes[:, 0] == new_node_key)
# add two new connections
weight = connections[0, from_idx, to_idx]
nodes, connections = add_connection_by_idx(from_idx, new_idx, nodes, connections, weight=0, enabled=True)
nodes, connections = add_connection_by_idx(new_idx, to_idx, nodes, connections, weight=weight, enabled=True)
return nodes, connections
@jit
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
Randomly delete a node. Input and output nodes are not allowed to be deleted.
:param rand_key:
:param nodes:
:param connections:
:param input_keys:
:param output_keys:
:return:
"""
# randomly choose a node
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
# delete the node
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
# delete connections
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
# check node_key valid
nodes = jnp.where(jnp.isnan(node_key), nodes, aux_nodes) # if node_key is nan, do not delete the node
connections = jnp.where(jnp.isnan(node_key), connections, aux_connections)
return nodes, connections
@jit
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
cycles are not allowed.
:param rand_key:
:param nodes:
:param connections:
:param input_keys:
:param output_keys:
:return:
"""
# randomly choose two nodes
k1, k2 = jax.random.split(rand_key, num=2)
from_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
allow_input_keys=True, allow_output_keys=True)
to_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=True)
def successful():
new_nodes, new_connections = add_connection_by_idx(from_idx, to_idx, nodes, connections)
return new_nodes, new_connections
def already_exist():
new_connections = connections.at[1, from_idx, to_idx].set(True)
return nodes, new_connections
def cycle():
return nodes, connections
is_already_exist = ~jnp.isnan(connections[0, from_idx, to_idx])
is_cycle = check_cycles(nodes, connections, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
nodes, connections = jax.lax.switch(choice, [already_exist, cycle, successful])
return nodes, connections
@jit
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
"""
Randomly delete a connection.
:param rand_key:
:param nodes:
:param connections:
:return:
"""
# randomly choose a connection
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
nodes, connections = delete_connection_by_idx(from_idx, to_idx, nodes, connections)
return nodes, connections
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param rand_key:
:param nodes:
:param input_keys:
:param output_keys:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~jnp.isnan(node_keys)
if not allow_input_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys))
if not allow_output_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys))
idx = fetch_random(rand_key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
return key, idx
@jit
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
"""
Randomly choose a connection key from the given connections.
:param rand_key:
:param nodes:
:param connection:
:return: from_key, to_key, from_idx, to_idx
"""
k1, k2 = jax.random.split(rand_key, num=2)
has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1)
from_idx = fetch_random(k1, has_connections_row)
col = connection[0, from_idx, :]
to_idx = fetch_random(k2, ~jnp.isnan(col))
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
return from_key, to_key, from_idx, to_idx
@jit
def rand(rand_key):
return jax.random.uniform(rand_key, ())

View File

@@ -0,0 +1,134 @@
from functools import partial
from typing import Tuple
import jax
from jax import numpy as jnp, Array
from jax import jit
I_INT = jnp.iinfo(jnp.int32).max # infinite int
@jit
def flatten_connections(keys, connections):
"""
flatten the (2, N, N) connections to (N * N, 4)
:param keys:
:param connections:
:return:
the first two columns are the index of the node
the 3rd column is the weight, and the 4th column is the enabled status
"""
indices_x, indices_y = jnp.meshgrid(keys, keys, indexing='ij')
indices = jnp.stack((indices_x, indices_y), axis=-1).reshape(-1, 2)
# make (2, N, N) to (N, N, 2)
con = jnp.transpose(connections, (1, 2, 0))
# make (N, N, 2) to (N * N, 2)
con = jnp.reshape(con, (-1, 2))
con = jnp.concatenate((indices, con), axis=1)
return con
@partial(jit, static_argnames=['N'])
def unflatten_connections(N, cons):
"""
restore the (N * N, 4) connections to (2, N, N)
:param N:
:param cons:
:return:
"""
cons = cons[:, 2:] # remove the indices
unflatten_cons = jnp.moveaxis(cons.reshape(N, N, 2), -1, 0)
return unflatten_cons
@jit
def set_operation_analysis(ar1: Array, ar2: Array) -> Tuple[Array, Array, Array]:
"""
Analyze the intersection and union of two arrays by returning their sorted concatenation indices,
intersection mask, and union mask.
:param ar1: JAX array of shape (N, M)
First input array. Should have the same shape as ar2.
:param ar2: JAX array of shape (N, M)
Second input array. Should have the same shape as ar1.
:return: tuple of 3 arrays
- sorted_indices: Indices that would sort the concatenation of ar1 and ar2.
- intersect_mask: A boolean array indicating the positions of the common elements between ar1 and ar2
in the sorted concatenation.
- union_mask: A boolean array indicating the positions of the unique elements in the union of ar1 and ar2
in the sorted concatenation.
Examples:
a = jnp.array([[1, 2], [3, 4], [5, 6]])
b = jnp.array([[1, 2], [7, 8], [9, 10]])
sorted_indices, intersect_mask, union_mask = set_operation_analysis(a, b)
sorted_indices -> array([0, 1, 2, 3, 4, 5])
intersect_mask -> array([True, False, False, False, False, False])
union_mask -> array([False, True, True, True, True, True])
"""
ar = jnp.concatenate((ar1, ar2), axis=0)
sorted_indices = jnp.lexsort(ar.T[::-1])
aux = ar[sorted_indices]
aux = jnp.concatenate((aux, jnp.full((1, ar1.shape[1]), jnp.nan)), axis=0)
nan_mask = jnp.any(jnp.isnan(aux), axis=1)
fr, sr = aux[:-1], aux[1:] # first row, second row
intersect_mask = jnp.all(fr == sr, axis=1) & ~nan_mask[:-1]
union_mask = jnp.any(fr != sr, axis=1) & ~nan_mask[:-1]
return sorted_indices, intersect_mask, union_mask
@jit
def fetch_first(mask, default=I_INT) -> Array:
"""
fetch the first True index
:param mask: array of bool
:param default: the default value if no element satisfying the condition
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT
example:
>>> a = jnp.array([1, 2, 3, 4, 5])
>>> fetch_first(a > 3)
3
>>> fetch_first(a > 30)
I_INT
"""
idx = jnp.argmax(mask)
return jnp.where(mask[idx], idx, default)
@jit
def fetch_last(mask, default=I_INT) -> Array:
"""
similar to fetch_first, but fetch the last True index
"""
reversed_idx = fetch_first(mask[::-1], default)
return jnp.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
"""
similar to fetch_first, but fetch a random True index
"""
true_cnt = jnp.sum(mask)
cumsum = jnp.cumsum(mask)
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
return fetch_first(cumsum >= target, default)
if __name__ == '__main__':
a = jnp.array([1, 2, 3, 4, 5])
print(fetch_first(a > 3))
print(fetch_first(a > 30))
print(fetch_last(a > 3))
print(fetch_last(a > 30))
rand_key = jax.random.PRNGKey(0)
for _ in range(100):
rand_key, _ = jax.random.split(rand_key)
print(fetch_random(rand_key, a > 0))

View File

@@ -0,0 +1,41 @@
import jax
from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function
class Pipeline:
"""
Neat algorithm pipeline.
"""
def __init__(self, config):
self.config = config
self.N = config.basic.init_maximum_nodes
self.species_controller = SpeciesController(config)
self.initialize_func = create_initialize_function(config)
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
self.generation = 0
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
def ask(self, batch: bool):
"""
Create a forward function for the population.
:param batch:
:return:
Algorithm gives the population a forward function, then environment gives back the fitnesses.
"""
func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx,
batch=batch)
return func
def tell(self, fitnesses):
self.generation += 1
print(type(fitnesses), fitnesses)
self.species_controller.update_species_fitnesses(fitnesses)

190
algorithms/neat/species.py Normal file
View File

@@ -0,0 +1,190 @@
from typing import List, Tuple, Dict
from itertools import count
import jax
import numpy as np
from numpy.typing import NDArray
from .genome import distance
class Species(object):
def __init__(self, key, generation):
self.key = key
self.created = generation
self.last_improved = generation
self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections)
self.members: List[int] = [] # idx in pop_nodes, pop_connections
self.fitness = None
self.member_fitnesses = None
self.adjusted_fitness = None
self.fitness_history: List[float] = []
def update(self, representative, members):
self.representative = representative
self.members = members
def get_fitnesses(self, fitnesses):
return [fitnesses[m] for m in self.members]
class SpeciesController:
"""
A class to control the species
"""
def __init__(self, config):
self.config = config
self.compatibility_threshold = self.config.neat.species.compatibility_threshold
self.species_elitism = self.config.neat.species.species_elitism
self.max_stagnation = self.config.neat.species.max_stagnation
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
self.genome_to_species: Dict[int, int] = {}
self.o2m_distance_func = jax.vmap(distance, in_axes=(None, None, 0, 0)) # one to many
# self.o2o_distance_func = np_distance # one to one
self.o2o_distance_func = distance
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
"""
:param pop_nodes:
:param pop_connections:
:param generation: use to flag the created time of new species
:return:
"""
unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool)
previous_species_list = list(self.species.keys())
# Find the best representatives for each existing species.
new_representatives = {}
new_members = {}
for sid, species in self.species.items():
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances) # fetch the data from gpu
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx
new_members[sid] = [min_idx]
unspeciated[min_idx] = False
# Partition population into species based on genetic similarity.
# First, fast match the population to previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
jax.device_get(
[
self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list
]
)
]
pop_res_distance = np.stack(res_pop_distance, axis=0).T
for i in range(pop_res_distance.shape[0]):
if not unspeciated[i]:
continue
min_idx = np.argmin(pop_res_distance[i])
min_val = pop_res_distance[i, min_idx]
if min_val <= self.compatibility_threshold:
species_id = previous_species_list[min_idx]
new_members[species_id].append(i)
unspeciated[i] = False
# Second, slowly match the lonely population to new-created species.
# lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
# the new representatives.
new_species_list = []
for i in range(pop_nodes.shape[0]):
if not unspeciated[i]:
continue
unspeciated[i] = False
if len(new_representatives) != 0:
rid = [new_representatives[sid] for sid in new_representatives] # the representatives of new species
distances = [
self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)
min_idx = np.argmin(distances)
min_val = distances[min_idx]
if min_val <= self.compatibility_threshold:
species_id = new_species_list[min_idx]
new_members[species_id].append(i)
continue
# create a new species
species_id = next(self.species_idxer)
new_species_list.append(species_id)
new_representatives[species_id] = i
new_members[species_id] = [i]
assert np.all(~unspeciated)
# Update species collection based on new speciation.
self.genome_to_species = {}
for sid, rid in new_representatives.items():
s = self.species.get(sid)
if s is None:
s = Species(sid, generation)
self.species[sid] = s
members = new_members[sid]
for gid in members:
self.genome_to_species[gid] = sid
s.update((pop_nodes[rid], pop_connections[rid]), members)
def update_species_fitnesses(self, fitnesses):
"""
update the fitness of each species
:param fitnesses:
:return:
"""
for sid, s in self.species.items():
# TODO: here use mean to measure the fitness of a species, but it may be other functions
s.member_fitnesses = s.get_fitnesses(fitnesses)
s.fitness = np.mean(s.member_fitnesses)
s.fitness_history.append(s.fitness)
s.adjusted_fitness = None
def stagnation(self, generation):
"""
code modified from neat-python!
:param generation:
:return: whether the species is stagnated
"""
species_data = []
for sid, s in self.species.items():
if s.fitness_history:
prev_fitness = max(s.fitness_history)
else:
prev_fitness = float('-inf')
if prev_fitness is None or s.fitness > prev_fitness:
s.last_improved = generation
species_data.append((sid, s))
# Sort in descending fitness order.
species_data.sort(key=lambda x: x[1].fitness, reverse=True)
result = []
for idx, (sid, s) in enumerate(species_data):
if idx < self.species_elitism: # elitism species never stagnate!
is_stagnant = False
else:
stagnant_time = generation - s.last_improved
is_stagnant = stagnant_time > self.max_stagnation
result.append((sid, s, is_stagnant))
return result
def find_min_with_mask(arr: NDArray, mask: NDArray) -> int:
masked_arr = np.where(mask, arr, np.inf)
min_idx = np.argmin(masked_arr)
return min_idx

View File

@@ -0,0 +1,62 @@
"""
Code modified from NEAT-Python library
Keeps track of whether species are making progress and helps remove those which are not.
"""
class Stagnation:
"""Keeps track of whether species are making progress and helps remove ones that are not."""
def __init__(self, config):
self.config = config
def update(self, species_set, generation):
"""
Required interface method. Updates species fitness history information,
checking for ones that have not improved in max_stagnation generations,
and - unless it would result in the number of species dropping below the configured
species_elitism parameter if they were removed,
in which case the highest-fitness species are spared -
returns a list with stagnant species marked for removal.
"""
species_data = []
for sid, s in species_set.species.items():
if s.fitness_history:
prev_fitness = max(s.fitness_history)
else:
prev_fitness = float('-inf')
s.fitness = max(s.get_fitnesses())
s.fitness_history.append(s.fitness)
s.adjusted_fitness = None
if prev_fitness is None or s.fitness > prev_fitness:
s.last_improved = generation
species_data.append((sid, s))
# Sort in ascending fitness order.
species_data.sort(key=lambda x: x[1].fitness)
result = []
species_fitnesses = []
num_non_stagnant = len(species_data)
for idx, (sid, s) in enumerate(species_data):
# Override stagnant state if marking this species as stagnant would
# result in the total number of species dropping below the limit.
# Because species are in ascending fitness order, less fit species
# will be marked as stagnant first.
stagnant_time = generation - s.last_improved
is_stagnant = False
if num_non_stagnant > self.config.stagnation.species_elitism:
is_stagnant = stagnant_time >= self.config.stagnation.max_stagnation
if (len(species_data) - idx) <= self.config.stagnation.species_elitism:
is_stagnant = False
if is_stagnant:
num_non_stagnant -= 1
result.append((sid, s, is_stagnant))
species_fitnesses.append(s.fitness)
return result

View File

@@ -0,0 +1,5 @@
"""
numpy version of functions in genome
"""
from .distance import distance
from .utils import *

View File

@@ -0,0 +1,58 @@
import numpy as np
from .utils import flatten_connections, set_operation_analysis
def distance(nodes1, connections1, nodes2, connections2):
node_distance = gene_distance(nodes1, nodes2, 'node')
# refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
connection_distance = gene_distance(cons1, cons2, 'connection')
return node_distance + connection_distance
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
if gene_type == 'node':
keys1, keys2 = ar1[:, :1], ar2[:, :1]
else: # connection
keys1, keys2 = ar1[:, :2], ar2[:, :2]
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
nodes = np.concatenate((ar1, ar2), axis=0)
sorted_nodes = nodes[n_sorted_indices]
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = np.where(np.isnan(node_distance), 0, node_distance)
homologous_distance = np.sum(node_distance * n_intersect_mask[:-1])
gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1))
gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1))
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return val / np.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2)
def homologous_node_distance(n1, n2):
d = 0
d += np.abs(n1[:, 1] - n2[:, 1]) # bias
d += np.abs(n1[:, 2] - n2[:, 2]) # response
d += n1[:, 3] != n2[:, 3] # activation
d += n1[:, 4] != n2[:, 4]
return d
def homologous_connection_distance(c1, c2):
d = 0
d += np.abs(c1[:, 2] - c2[:, 2]) # weight
d += c1[:, 3] != c2[:, 3] # enable
return d

55
algorithms/numpy/utils.py Normal file
View File

@@ -0,0 +1,55 @@
import numpy as np
I_INT = np.iinfo(np.int32).max # infinite int
def flatten_connections(keys, connections):
indices_x, indices_y = np.meshgrid(keys, keys, indexing='ij')
indices = np.stack((indices_x, indices_y), axis=-1).reshape(-1, 2)
# make (2, N, N) to (N, N, 2)
con = np.transpose(connections, (1, 2, 0))
# make (N, N, 2) to (N * N, 2)
con = np.reshape(con, (-1, 2))
con = np.concatenate((indices, con), axis=1)
return con
def unflatten_connections(N, cons):
cons = cons[:, 2:] # remove the indices
unflatten_cons = np.moveaxis(cons.reshape(N, N, 2), -1, 0)
return unflatten_cons
def set_operation_analysis(ar1, ar2):
ar = np.concatenate((ar1, ar2), axis=0)
sorted_indices = np.lexsort(ar.T[::-1])
aux = ar[sorted_indices]
aux = np.concatenate((aux, np.full((1, ar1.shape[1]), np.nan)), axis=0)
nan_mask = np.any(np.isnan(aux), axis=1)
fr, sr = aux[:-1], aux[1:] # first row, second row
intersect_mask = np.all(fr == sr, axis=1) & ~nan_mask[:-1]
union_mask = np.any(fr != sr, axis=1) & ~nan_mask[:-1]
return sorted_indices, intersect_mask, union_mask
def fetch_first(mask, default=I_INT):
idx = np.argmax(mask)
return np.where(mask[idx], idx, default)
def fetch_last(mask, default=I_INT):
reversed_idx = fetch_first(mask[::-1], default)
return np.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1)
def fetch_random(rand_key, mask, default=I_INT):
"""
similar to fetch_first, but fetch a random True index
"""
true_cnt = np.sum(mask)
cumsum = np.cumsum(mask)
target = np.random.randint(rand_key, shape=(), minval=0, maxval=true_cnt + 1)
return fetch_first(cumsum >= target, default)

0
examples/__init__.py Normal file
View File

71
examples/genome_test.py Normal file
View File

@@ -0,0 +1,71 @@
import time
import jax.random
from utils import Configer
from algorithms.neat.genome.genome import *
from algorithms.neat.species import SpeciesController
from algorithms.neat.genome.forward import create_forward_function
from algorithms.neat.genome.mutate import create_mutate_function
if __name__ == '__main__':
N = 10
pop_nodes, pop_connections, input_idx, output_idx = initialize_genomes(10000, N, 2, 1,
default_act=9, default_agg=0)
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# forward = create_forward_function(pop_nodes, pop_connections, 5, input_idx, output_idx, batch=True)
nodes, connections = pop_nodes[0], pop_connections[0]
forward = create_forward_function(pop_nodes, pop_connections, N, input_idx, output_idx, batch=True)
out = forward(inputs)
print(out.shape)
print(out)
config = Configer.load_config()
s_c = SpeciesController(config.neat)
s_c.speciate(pop_nodes, pop_connections, 0)
s_c.speciate(pop_nodes, pop_connections, 0)
print(s_c.genome_to_species)
start = time.time()
for i in range(100):
print(i)
s_c.speciate(pop_nodes, pop_connections, i)
print(time.time() - start)
seed = jax.random.PRNGKey(42)
mutate_func = create_mutate_function(config, input_idx, output_idx, batch=False)
print(nodes, connections, sep='\n')
print(*mutate_func(seed, nodes, connections, 100), sep='\n')
randseeds = jax.random.split(seed, 10000)
new_node_keys = jax.random.randint(randseeds[0], minval=0, maxval=10000, shape=(10000,))
batch_mutate_func = create_mutate_function(config, input_idx, output_idx, batch=True)
pop_nodes, pop_connections = batch_mutate_func(randseeds, pop_nodes, pop_connections, new_node_keys)
print(pop_nodes, pop_connections, sep='\n')
start = time.time()
for i in range(100):
print(i)
pop_nodes, pop_connections = batch_mutate_func(randseeds, pop_nodes, pop_connections, new_node_keys)
print(time.time() - start)
print(nodes, connections, sep='\n')
nodes, connections = add_node(6, nodes, connections)
nodes, connections = add_node(7, nodes, connections)
print(nodes, connections, sep='\n')
nodes, connections = add_connection(6, 7, nodes, connections)
nodes, connections = add_connection(0, 7, nodes, connections)
nodes, connections = add_connection(1, 7, nodes, connections)
print(nodes, connections, sep='\n')
nodes, connections = delete_connection(6, 7, nodes, connections)
print(nodes, connections, sep='\n')
nodes, connections = delete_node(6, nodes, connections)
print(nodes, connections, sep='\n')
nodes, connections = delete_node(7, nodes, connections)
print(nodes, connections, sep='\n')

View File

@@ -0,0 +1,37 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax import vmap, jit
def plus1(x):
return x + 1
def minus1(x):
return x - 1
def func(rand_key, x):
r = jax.random.uniform(rand_key, shape=())
return jax.lax.cond(r > 0.5, plus1, minus1, x)
def func2(rand_key):
r = jax.random.uniform(rand_key, ())
if r < 0.3:
return 1
elif r < 0.5:
return 2
else:
return 3
key = random.PRNGKey(0)
print(func(key, 0))
batch_func = vmap(jit(func))
keys = random.split(key, 100)
print(batch_func(keys, jnp.zeros(100)))

40
examples/xor.py Normal file
View File

@@ -0,0 +1,40 @@
from typing import Callable, List
import jax
import numpy as np
from utils import Configer
from algorithms.neat import Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
xor_outputs = np.array([[0], [1], [1], [0]])
def evaluate(forward_func: Callable) -> List[float]:
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses.tolist() # returns a list
def main():
config = Configer.load_config()
pipeline = Pipeline(config)
forward_func = pipeline.ask(batch=True)
fitnesses = evaluate(forward_func)
pipeline.tell(fitnesses)
# for i in range(100):
# forward_func = pipeline.ask(batch=True)
# fitnesses = evaluate(forward_func)
# pipeline.tell(fitnesses)
if __name__ == '__main__':
main()

1
utils/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .config import Configer

Binary file not shown.

Binary file not shown.

Binary file not shown.

78
utils/config.py Normal file
View File

@@ -0,0 +1,78 @@
import json
import os
import warnings
from .dotdict import DotDict
class Configer:
@classmethod
def __load_default_config(cls):
par_dir = os.path.dirname(os.path.abspath(__file__))
default_config_path = os.path.join(par_dir, "./default_config.json")
return cls.__load_config(default_config_path)
@classmethod
def __load_config(cls, config_path):
with open(config_path, "r") as f:
text = "".join(f.readlines())
try:
j = json.loads(text)
except ValueError:
raise Exception("Invalid config")
return DotDict.from_dict(j, "root")
@classmethod
def __check_redundant_config(cls, default_config, config):
for key in config:
if key not in default_config:
warnings.warn(f"Redundant config: {key} in {config.name}")
continue
if isinstance(default_config[key], DotDict):
cls.__check_redundant_config(default_config[key], config[key])
@classmethod
def __complete_config(cls, default_config, config):
for key in default_config:
if key not in config:
config[key] = default_config[key]
continue
if isinstance(default_config[key], DotDict):
cls.__complete_config(default_config[key], config[key])
@classmethod
def __decorate_config(cls, config):
if config.neat.gene.activation.options == 'all':
config.neat.gene.activation.options = [
"sigmoid", "tanh", "sin", "gauss", "relu", "elu", "lelu", "selu", "softplus", "identity", "clamped",
"inv", "log", "exp", "abs", "hat", "square", "cube"
]
if isinstance(config.neat.gene.activation.options, str):
config.neat.gene.activation.options = [config.neat.gene.activation.options]
if config.neat.gene.aggregation.options == 'all':
config.neat.gene.aggregation.options = ["product", "sum", "max", "min", "median", "mean"]
if isinstance(config.neat.gene.aggregation.options, str):
config.neat.gene.aggregation.options = [config.neat.gene.aggregation.options]
@classmethod
def load_config(cls, config_path=None):
default_config = cls.__load_default_config()
if config_path is None:
config = DotDict("root")
elif not os.path.exists(config_path):
warnings.warn(f"config file {config_path} not exist!")
config = DotDict("root")
else:
config = cls.__load_config(config_path)
cls.__check_redundant_config(default_config, config)
cls.__complete_config(default_config, config)
cls.__decorate_config(config)
return config
@classmethod
def write_config(cls, config, write_path):
text = json.dumps(config, indent=2)
with open(write_path, "w") as f:
f.write(text)

108
utils/default_config.json Normal file
View File

@@ -0,0 +1,108 @@
{
"basic": {
"num_inputs": 2,
"num_outputs": 1,
"init_maximum_nodes": 20,
"expands_coe": 1.5
},
"neat": {
"population": {
"fitness_criterion": "max",
"fitness_threshold": 43.9999,
"generation_limit": 100,
"pop_size": 1000,
"reset_on_extinction": "False"
},
"gene": {
"bias": {
"init_mean": 0.0,
"init_stdev": 1.0,
"max_value": 30.0,
"min_value": -30.0,
"mutate_power": 0.5,
"mutate_rate": 0.7,
"replace_rate": 0.1
},
"response": {
"init_mean": 1.0,
"init_stdev": 0.0,
"max_value": 30.0,
"min_value": -30.0,
"mutate_power": 0.0,
"mutate_rate": 0.0,
"replace_rate": 0.0
},
"activation": {
"default": "sigmoid",
"options": "sigmoid",
"mutate_rate": 0.01
},
"aggregation": {
"default": "sum",
"options": [
"product",
"sum",
"max",
"min",
"median",
"mean"
],
"mutate_rate": 0.01
},
"weight": {
"init_mean": 0.0,
"init_stdev": 1.0,
"max_value": 30.0,
"min_value": -30.0,
"mutate_power": 0.5,
"mutate_rate": 0.8,
"replace_rate": 0.1
},
"enabled": {
"mutate_rate": 0.01
}
},
"genome": {
"compatibility_disjoint_coefficient": 1.0,
"compatibility_weight_coefficient": 0.5,
"feedforward": "True",
"single_structural_mutation": "False",
"conn_add_prob": 0.5,
"conn_delete_prob": 0.5,
"node_add_prob": 0.2,
"node_delete_prob": 0.2
},
"species": {
"compatibility_threshold": 3.5,
"species_fitness_func": "max",
"max_stagnation": 20,
"species_elitism": 2,
"genome_elitism": 2,
"survival_threshold": 0.2,
"min_species_size": 1
}
},
"hyperneat": {
"substrate": {
"type": "feedforward",
"layers": [
3,
10,
10,
1
],
"x_lim": [
-5,
5
],
"y_lim": [
-5,
5
],
"threshold": 0.2,
"max_weight": 5.0
}
},
"es-hyperneat": {
}
}

61
utils/dotdict.py Normal file
View File

@@ -0,0 +1,61 @@
# DotDict For Config. Case Insensitive.
class DotDict(dict):
def __init__(self, name, *args, **kwargs):
super().__init__(*args, **kwargs)
self["name"] = name
def __getattr__(self, attr):
attr = attr.lower() # case insensitive
if attr in self:
return self[attr]
else:
raise AttributeError(f"'{self.__class__.__name__}-{self.name}' has no attribute '{attr}'")
def __setattr__(self, attr, value):
attr = attr.lower() # case insensitive
if attr not in self:
raise AttributeError(f"'{self.__class__.__name__}-{self.name}' has no attribute '{attr}'")
self[attr] = value
def __delattr__(self, attr):
attr = attr.lower() # case insensitive
if attr in self:
del self[attr]
else:
raise AttributeError(f"{self.__class__.__name__}-{self.name} object has no attribute '{attr}'")
@classmethod
def from_dict(cls, d, name):
if not isinstance(d, dict):
return d
dot_dict = cls(name)
for key, value in d.items():
key = key.lower() # case insensitive
if isinstance(value, dict):
dot_dict[key] = cls.from_dict(value, key)
else:
dot_dict[key] = value
if dot_dict[key] == "True": # Fuck! Json has no bool type!
dot_dict[key] = True
if dot_dict[key] == "False":
dot_dict[key] = False
if dot_dict[key] == "None":
dot_dict[key] = None
return dot_dict
if __name__ == '__main__':
nested_dict = {
"a": 1,
"b": {
"c": 2,
"ACDeef": {
"e": 3
}
}
}
dd = DotDict.from_dict(nested_dict, "root")
print(dd.b.acdeef.e) # 输出3