initial commit
This commit is contained in:
0
algorithms/__init__.py
Normal file
0
algorithms/__init__.py
Normal file
BIN
algorithms/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
algorithms/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
0
algorithms/hyper_neat/__init__.py
Normal file
0
algorithms/hyper_neat/__init__.py
Normal file
1
algorithms/neat/__init__.py
Normal file
1
algorithms/neat/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .pipeline import Pipeline
|
||||||
BIN
algorithms/neat/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
algorithms/neat/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
algorithms/neat/__pycache__/pipeline.cpython-39.pyc
Normal file
BIN
algorithms/neat/__pycache__/pipeline.cpython-39.pyc
Normal file
Binary file not shown.
BIN
algorithms/neat/__pycache__/species.cpython-39.pyc
Normal file
BIN
algorithms/neat/__pycache__/species.cpython-39.pyc
Normal file
Binary file not shown.
4
algorithms/neat/genome/__init__.py
Normal file
4
algorithms/neat/genome/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .genome import create_initialize_function
|
||||||
|
from .distance import distance
|
||||||
|
from .mutate import create_mutate_function
|
||||||
|
from .forward import create_forward_function
|
||||||
BIN
algorithms/neat/genome/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
algorithms/neat/genome/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
algorithms/neat/genome/__pycache__/activations.cpython-39.pyc
Normal file
BIN
algorithms/neat/genome/__pycache__/activations.cpython-39.pyc
Normal file
Binary file not shown.
BIN
algorithms/neat/genome/__pycache__/aggregations.cpython-39.pyc
Normal file
BIN
algorithms/neat/genome/__pycache__/aggregations.cpython-39.pyc
Normal file
Binary file not shown.
BIN
algorithms/neat/genome/__pycache__/distance.cpython-39.pyc
Normal file
BIN
algorithms/neat/genome/__pycache__/distance.cpython-39.pyc
Normal file
Binary file not shown.
BIN
algorithms/neat/genome/__pycache__/forward.cpython-39.pyc
Normal file
BIN
algorithms/neat/genome/__pycache__/forward.cpython-39.pyc
Normal file
Binary file not shown.
BIN
algorithms/neat/genome/__pycache__/genome.cpython-39.pyc
Normal file
BIN
algorithms/neat/genome/__pycache__/genome.cpython-39.pyc
Normal file
Binary file not shown.
BIN
algorithms/neat/genome/__pycache__/graph.cpython-39.pyc
Normal file
BIN
algorithms/neat/genome/__pycache__/graph.cpython-39.pyc
Normal file
Binary file not shown.
BIN
algorithms/neat/genome/__pycache__/mutate.cpython-39.pyc
Normal file
BIN
algorithms/neat/genome/__pycache__/mutate.cpython-39.pyc
Normal file
Binary file not shown.
BIN
algorithms/neat/genome/__pycache__/utils.cpython-39.pyc
Normal file
BIN
algorithms/neat/genome/__pycache__/utils.cpython-39.pyc
Normal file
Binary file not shown.
138
algorithms/neat/genome/activations.py
Normal file
138
algorithms/neat/genome/activations.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
import jax
import jax.numpy as jnp
from jax import jit


# Activation functions for NEAT genomes. Each is jit-compiled and is
# selected at runtime by an integer index through `act` (jax.lax.switch).

@jit
def sigmoid_act(z):
    """Sigmoid with the input scaled by 5; clipped to avoid exp overflow."""
    z = jnp.clip(5 * z, -60, 60)
    return 1 / (1 + jnp.exp(-z))


@jit
def tanh_act(z):
    """Tanh with the input scaled by 2.5; clipped to avoid overflow."""
    z = jnp.clip(2.5 * z, -60, 60)
    return jnp.tanh(z)


@jit
def sin_act(z):
    """Sine with the input scaled by 5; clipped for symmetry with the others."""
    z = jnp.clip(5 * z, -60, 60)
    return jnp.sin(z)


@jit
def gauss_act(z):
    """Gaussian bump exp(-5*z^2); input clipped to [-3.4, 3.4]."""
    z = jnp.clip(z, -3.4, 3.4)
    return jnp.exp(-5 * z ** 2)


@jit
def relu_act(z):
    """Rectified linear unit."""
    return jnp.maximum(z, 0)


@jit
def elu_act(z):
    """Exponential linear unit."""
    return jnp.where(z > 0, z, jnp.exp(z) - 1)


@jit
def lelu_act(z):
    """Leaky ReLU with slope 0.005 on the negative side."""
    leaky = 0.005
    return jnp.where(z > 0, z, leaky * z)


@jit
def selu_act(z):
    """Scaled exponential linear unit (standard lambda/alpha constants)."""
    lam = 1.0507009873554804934193349852946
    alpha = 1.6732632423543772848170429916717
    return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))


@jit
def softplus_act(z):
    """Softplus, scaled: 0.2 * log(1 + exp(5*z)); clipped to avoid overflow."""
    z = jnp.clip(5 * z, -60, 60)
    return 0.2 * jnp.log(1 + jnp.exp(z))


@jit
def identity_act(z):
    """Identity passthrough."""
    return z


@jit
def clamped_act(z):
    """Clamp the input to [-1, 1]."""
    return jnp.clip(z, -1, 1)


@jit
def inv_act(z):
    """Reciprocal 1/z (yields inf at z == 0, matching jnp semantics)."""
    return 1 / z


@jit
def log_act(z):
    """Natural log with the input floored at 1e-7 to stay finite."""
    z = jnp.maximum(z, 1e-7)
    return jnp.log(z)


@jit
def exp_act(z):
    """Exponential; input clipped to [-60, 60] to avoid overflow."""
    z = jnp.clip(z, -60, 60)
    return jnp.exp(z)


@jit
def abs_act(z):
    """Absolute value."""
    return jnp.abs(z)


@jit
def hat_act(z):
    """Triangular 'hat': 1 - |z| clipped below at 0."""
    return jnp.maximum(0, 1 - jnp.abs(z))


@jit
def square_act(z):
    """Square."""
    return z ** 2


@jit
def cube_act(z):
    """Cube."""
    return z ** 3


# index -> callable, consumed by jax.lax.switch in `act`
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act,
                  lelu_act, selu_act, softplus_act, identity_act, clamped_act,
                  inv_act, log_act, exp_act, abs_act, hat_act, square_act,
                  cube_act]

# name -> index into ACT_TOTAL_LIST (order must match the list above)
act_name2key = {
    name: i for i, name in enumerate(
        ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'elu', 'lelu', 'selu',
         'softplus', 'identity', 'clamped', 'inv', 'log', 'exp', 'abs', 'hat',
         'square', 'cube'])
}


@jit
def act(idx, z):
    """Apply the activation selected by integer index `idx` to `z`.

    `idx` may arrive as a float (genomes store attributes as floats), so it
    is cast to int32 before dispatching through jax.lax.switch.
    """
    idx = jnp.asarray(idx, dtype=jnp.int32)
    return jax.lax.switch(idx, ACT_TOTAL_LIST, z)


# batched version: one (idx, z) pair per leading-axis element
vectorized_act = jax.vmap(act, in_axes=(0, 0))
|
||||||
109
algorithms/neat/genome/aggregations.py
Normal file
109
algorithms/neat/genome/aggregations.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
"""
|
||||||
|
aggregations, two special case need to consider:
|
||||||
|
1. extra 0s
|
||||||
|
2. full of 0s
|
||||||
|
"""
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax import jit
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def sum_agg(z):
|
||||||
|
z = jnp.where(jnp.isnan(z), 0, z)
|
||||||
|
return jnp.sum(z, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def product_agg(z):
|
||||||
|
z = jnp.where(jnp.isnan(z), 1, z)
|
||||||
|
return jnp.prod(z, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def max_agg(z):
|
||||||
|
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
||||||
|
return jnp.max(z, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def min_agg(z):
|
||||||
|
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||||
|
return jnp.min(z, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def maxabs_agg(z):
|
||||||
|
z = jnp.where(jnp.isnan(z), 0, z)
|
||||||
|
abs_z = jnp.abs(z)
|
||||||
|
max_abs_index = jnp.argmax(abs_z)
|
||||||
|
return z[max_abs_index]
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def median_agg(z):
|
||||||
|
|
||||||
|
non_zero_mask = ~jnp.isnan(z)
|
||||||
|
n = jnp.sum(non_zero_mask, axis=0)
|
||||||
|
|
||||||
|
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||||
|
sorted_valid_values = jnp.sort(z)
|
||||||
|
|
||||||
|
def _even_case():
|
||||||
|
return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2
|
||||||
|
|
||||||
|
def _odd_case():
|
||||||
|
return sorted_valid_values[n // 2]
|
||||||
|
|
||||||
|
median = jax.lax.cond(n % 2 == 0, _even_case, _odd_case)
|
||||||
|
|
||||||
|
return median
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def mean_agg(z):
|
||||||
|
non_zero_mask = ~jnp.isnan(z)
|
||||||
|
valid_values_sum = sum_agg(z)
|
||||||
|
valid_values_count = jnp.sum(non_zero_mask, axis=0)
|
||||||
|
mean_without_zeros = valid_values_sum / valid_values_count
|
||||||
|
return mean_without_zeros
|
||||||
|
|
||||||
|
|
||||||
|
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
|
||||||
|
|
||||||
|
agg_name2key = {
|
||||||
|
'sum': 0,
|
||||||
|
'product': 1,
|
||||||
|
'max': 2,
|
||||||
|
'min': 3,
|
||||||
|
'maxabs': 4,
|
||||||
|
'median': 5,
|
||||||
|
'mean': 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def agg(idx, z):
|
||||||
|
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||||
|
|
||||||
|
def full_zero():
|
||||||
|
return 0.
|
||||||
|
|
||||||
|
def not_full_zero():
|
||||||
|
return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
|
||||||
|
|
||||||
|
return jax.lax.cond(jnp.all(z == 0.), full_zero, not_full_zero)
|
||||||
|
|
||||||
|
|
||||||
|
vectorized_agg = jax.vmap(agg, in_axes=(0, 0))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)
|
||||||
|
for names in agg_name2key.keys():
|
||||||
|
print(names, agg(agg_name2key[names], array))
|
||||||
|
|
||||||
|
array2 = jnp.asarray([0, 0, 0, 0], dtype=jnp.float32)
|
||||||
|
for names in agg_name2key.keys():
|
||||||
|
print(names, agg(agg_name2key[names], array2))
|
||||||
151
algorithms/neat/genome/crossover.py
Normal file
151
algorithms/neat/genome/crossover.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
from functools import partial
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import jit, vmap, Array
|
||||||
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
# from .utils import flatten_connections, unflatten_connections
|
||||||
|
from algorithms.neat.genome.utils import flatten_connections, unflatten_connections
|
||||||
|
|
||||||
|
|
||||||
|
@vmap
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
                    batch_connections2: Array) -> Tuple[Array, Array]:
    """Vectorized crossover over a whole population.

    Each argument carries a leading population axis; `crossover` is mapped
    over that axis, one random key and one parent pair per genome.

    :param randkeys: batch of PRNG keys, one per crossover
    :param batch_nodes1: fitter parents' node arrays
    :param batch_connections1: fitter parents' connection arrays
    :param batch_nodes2: weaker parents' node arrays
    :param batch_connections2: weaker parents' connection arrays
    :return: (new_nodes, new_connections) batches
    """
    return crossover(randkeys, batch_nodes1, batch_connections1, batch_nodes2, batch_connections2)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
        -> Tuple[Array, Array]:
    """Produce a child genome from two parent genomes.

    genome1 must be the fitter parent (the "winner"): genes that genome2
    lacks are inherited from genome1 unchanged; shared genes are mixed
    element-wise at random.

    :param randkey: PRNG key
    :param nodes1: fitter parent's nodes, shape (N, 5)
    :param connections1: fitter parent's connections, shape (2, N, N)
    :param nodes2: weaker parent's nodes
    :param connections2: weaker parent's connections
    :return: (new_nodes, new_connections)
    """
    node_key, con_key = jax.random.split(randkey)

    # nodes: line genome2's rows up with genome1's node keys, then mix
    keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
    aligned_nodes2 = align_array(keys1, keys2, nodes2, 'node')
    new_nodes = jnp.where(jnp.isnan(aligned_nodes2), nodes1,
                          crossover_gene(node_key, nodes1, aligned_nodes2))

    # connections: flatten to per-edge rows keyed by (in_key, out_key) first
    cons1 = flatten_connections(keys1, connections1)
    cons2 = flatten_connections(keys2, connections2)
    aligned_cons2 = align_array(cons1[:, :2], cons2[:, :2], cons2, 'connection')
    new_cons = jnp.where(jnp.isnan(aligned_cons2), cons1,
                         crossover_gene(con_key, cons1, aligned_cons2))
    new_cons = unflatten_connections(len(keys1), new_cons)

    return new_nodes, new_cons
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
    """Reorder `ar2` so its rows line up with the keys in `seq1`.

    A row of `ar2` whose key (from `seq2`) also appears in `seq1` is moved
    to that key's position in the output; every position of `seq1` with no
    match in `seq2` becomes a NaN row.

    :param seq1: reference keys — shape (M,) for nodes, (M, 2) for connections
    :param seq2: keys of `ar2`'s rows, same layout as seq1
    :param ar2: array to realign, one row per entry of seq2
    :param gene_type: 'node' or 'connection' (static for jit)
    :return: realigned copy of ar2 (NaN rows where seq1 has no match)
    """
    pairwise = (seq1[:, jnp.newaxis] == seq2[jnp.newaxis, :]) & (~jnp.isnan(seq1[:, jnp.newaxis]))

    # connection keys are (in, out) pairs: both components must match
    if gene_type == 'connection':
        pairwise = jnp.all(pairwise, axis=2)

    has_match = pairwise.any(axis=1)
    positions = jnp.arange(0, len(seq1))
    # each output row pulls from the (unique) matching source row of ar2
    src_rows = jnp.dot(pairwise, positions)

    return jnp.where(has_match[:, jnp.newaxis], ar2[src_rows], jnp.nan)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
    """Mix two aligned genes element-wise.

    Each element is taken from `g1` or `g2` with equal probability. Only
    genes with identical keys are crossed, so keys need no special care.

    :param rand_key: PRNG key
    :param g1: first gene array
    :param g2: second gene array, same shape as g1
    :return: element-wise mixture of g1 and g2
    """
    take_first = jax.random.uniform(rand_key, shape=g1.shape) > 0.5
    return jnp.where(take_first, g1, g2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    import numpy as np

    # smoke test: cross two hand-built genomes and print the child
    randkey = jax.random.PRNGKey(40)
    nodes1 = np.array([
        [4, 1, 1, 1, 1],
        [6, 2, 2, 2, 2],
        [1, 3, 3, 3, 3],
        [5, 4, 4, 4, 4],
        [np.nan] * 5,
    ])
    nodes2 = np.array([
        [4, 1.5, 1.5, 1.5, 1.5],
        [7, 3.5, 3.5, 3.5, 3.5],
        [5, 4.5, 4.5, 4.5, 4.5],
        [np.nan] * 5,
        [np.nan] * 5,
    ])
    weights1 = np.array([
        [  # weights
            [1, 2, 3, 4., np.nan],
            [5, np.nan, 7, 8, np.nan],
            [9, 10, 11, 12, np.nan],
            [13, 14, 15, 16, np.nan],
            [np.nan] * 5,
        ],
        [  # enabled flags
            [0, 1, 0, 1, np.nan],
            [0, np.nan, 0, 1, np.nan],
            [0, 1, 0, 1, np.nan],
            [0, 1, 0, 1, np.nan],
            [np.nan] * 5,
        ],
    ])
    weights2 = np.array([
        [  # weights
            [1.5, 2.5, 3.5, np.nan, np.nan],
            [3.5, 4.5, 5.5, np.nan, np.nan],
            [6.5, 7.5, 8.5, np.nan, np.nan],
            [np.nan] * 5,
            [np.nan] * 5,
        ],
        [  # enabled flags
            [1, 0, 1, np.nan, np.nan],
            [1, 0, 1, np.nan, np.nan],
            [1, 0, 1, np.nan, np.nan],
            [np.nan] * 5,
            [np.nan] * 5,
        ],
    ])

    res = crossover(randkey, nodes1, weights1, nodes2, weights2)
    print(*res, sep='\n')
|
||||||
71
algorithms/neat/genome/distance.py
Normal file
71
algorithms/neat/genome/distance.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from jax import jit, vmap, Array
|
||||||
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) -> Array:
    """Compatibility distance between two genomes.

    `nodes` has shape (N, 5) with columns [key, bias, response, act, agg];
    `connections` has shape (2, N, N) where axis 0 is [weights, enabled].
    The result is the node-gene distance plus the connection-gene distance.
    """
    d_nodes = gene_distance(nodes1, nodes2, 'node')

    # connections are flattened to per-edge rows keyed by node keys before
    # being compared as genes
    keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
    cons1 = flatten_connections(keys1, connections1)
    cons2 = flatten_connections(keys2, connections2)
    d_cons = gene_distance(cons1, cons2, 'connection')

    return d_nodes + d_cons
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnames=["gene_type"])
|
||||||
|
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
|
||||||
|
if gene_type == 'node':
|
||||||
|
keys1, keys2 = ar1[:, :1], ar2[:, :1]
|
||||||
|
else: # connection
|
||||||
|
keys1, keys2 = ar1[:, :2], ar2[:, :2]
|
||||||
|
|
||||||
|
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
|
||||||
|
nodes = jnp.concatenate((ar1, ar2), axis=0)
|
||||||
|
sorted_nodes = nodes[n_sorted_indices]
|
||||||
|
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
|
||||||
|
|
||||||
|
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
|
||||||
|
if gene_type == 'node':
|
||||||
|
node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
|
||||||
|
else: # connection
|
||||||
|
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
|
||||||
|
|
||||||
|
node_distance = jnp.where(jnp.isnan(node_distance), 0, node_distance)
|
||||||
|
homologous_distance = jnp.sum(node_distance * n_intersect_mask[:-1])
|
||||||
|
|
||||||
|
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1))
|
||||||
|
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1))
|
||||||
|
|
||||||
|
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
||||||
|
return val / jnp.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(vmap, in_axes=(0, 0))
def homologous_node_distance(n1, n2):
    """Attribute distance between two aligned node genes.

    Bias and response differ continuously (absolute difference); activation
    and aggregation indices contribute 1 each when they differ.
    """
    bias_d = jnp.abs(n1[1] - n2[1])
    response_d = jnp.abs(n1[2] - n2[2])
    act_d = n1[3] != n2[3]
    agg_d = n1[4] != n2[4]
    return bias_d + response_d + act_d + agg_d
|
||||||
|
|
||||||
|
|
||||||
|
@partial(vmap, in_axes=(0, 0))
def homologous_connection_distance(c1, c2):
    """Attribute distance between two aligned connection genes.

    Weight differs continuously (absolute difference); the enabled flag
    contributes 1 when it differs.
    """
    weight_d = jnp.abs(c1[2] - c2[2])
    enabled_d = c1[3] != c2[3]
    return weight_d + enabled_d
|
||||||
171
algorithms/neat/genome/forward.py
Normal file
171
algorithms/neat/genome/forward.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import Array, numpy as jnp
|
||||||
|
from jax import jit, vmap
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from .aggregations import agg
|
||||||
|
from .activations import act
|
||||||
|
from .graph import topological_sort, batch_topological_sort, topological_sort_debug
|
||||||
|
from .utils import I_INT
|
||||||
|
|
||||||
|
|
||||||
|
def create_forward_function(nodes: NDArray, connections: NDArray,
                            N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False):
    """Build a forward-pass callable matched to the genome layout.

    Dispatches on whether `nodes` holds a single genome (ndim == 2) or a
    whole population (ndim == 3), and whether inputs arrive one at a time
    or as a batch.

    :param nodes: shape (N, 5) or (pop_size, N, 5)
    :param connections: shape (2, N, N) or (pop_size, 2, N, N)
    :param N: maximum node count
    :param input_idx: indices of input nodes
    :param output_idx: indices of output nodes
    :param batch: whether the returned callable takes batched inputs
    :param debug: use the un-jitted reference implementation
    :return: a callable mapping inputs to outputs
    """
    if debug:
        cal_seqs = topological_sort(nodes, connections)
        return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
                                                   cal_seqs, nodes, connections)

    if nodes.ndim == 2:  # single genome
        cal_seqs = topological_sort(nodes, connections)
        if batch:
            return lambda batch_inputs: forward_batch(batch_inputs, N, input_idx, output_idx,
                                                      cal_seqs, nodes, connections)
        return lambda inputs: forward_single(inputs, N, input_idx, output_idx,
                                             cal_seqs, nodes, connections)

    if nodes.ndim == 3:  # population of genomes
        pop_cal_seqs = batch_topological_sort(nodes, connections)
        if batch:
            return lambda batch_inputs: pop_forward_batch(batch_inputs, N, input_idx, output_idx,
                                                          pop_cal_seqs, nodes, connections)
        return lambda inputs: pop_forward_single(inputs, N, input_idx, output_idx,
                                                 pop_cal_seqs, nodes, connections)

    raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
|
||||||
|
|
||||||
|
|
||||||
|
# @partial(jit, static_argnames=['N', 'input_idx', 'output_idx'])
|
||||||
|
@partial(jit, static_argnames=['N'])
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
                   cal_seqs: Array, nodes: Array, connections: Array) -> Array:
    """Jitted forward pass for one input vector through one genome.

    Node values are evaluated in topological order via lax.scan; unused
    entries of `cal_seqs` hold the sentinel I_INT and are skipped.

    :argument inputs: (input_num, )
    :argument N: maximum node count (static for jit)
    :argument input_idx: (input_num, )
    :argument output_idx: (output_num, )
    :argument cal_seqs: (N, ) topological evaluation order
    :argument nodes: (N, 5)
    :argument connections: (2, N, N)
    :return: (output_num, )
    """
    initial = jnp.full((N,), jnp.nan).at[input_idx].set(inputs)

    def step(vals, i):
        def evaluate():
            incoming = vals * connections[0, :, i]
            z = agg(nodes[i, 4], incoming)
            z = z * nodes[i, 2] + nodes[i, 1]
            z = act(nodes[i, 3], z)
            # input nodes produce NaN here; keep their injected value instead
            return jnp.where(jnp.isnan(z), vals, vals.at[i].set(z))

        def skip():
            return vals

        return jax.lax.cond(i == I_INT, skip, evaluate), None

    final_vals, _ = jax.lax.scan(step, initial, cal_seqs)
    return final_vals[output_idx]
|
||||||
|
|
||||||
|
|
||||||
|
def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections):
    """Un-jitted reference forward pass (plain Python loop) for debugging.

    Mirrors `forward_single` exactly, but iterates `cal_seqs` eagerly and
    stops at the first sentinel (I_INT) entry.
    """
    vals = jnp.full((N,), jnp.nan).at[input_idx].set(inputs)

    for i in cal_seqs:
        if i == I_INT:
            break
        incoming = vals * connections[0, :, i]
        z = agg(nodes[i, 4], incoming)
        z = z * nodes[i, 2] + nodes[i, 1]
        z = act(nodes[i, 3], z)
        # input nodes produce NaN here; keep their injected value instead
        vals = jnp.where(jnp.isnan(z), vals, vals.at[i].set(z))

    return vals[output_idx]
|
||||||
|
|
||||||
|
|
||||||
|
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
                  cal_seqs: Array, nodes: Array, connections: Array) -> Array:
    """Forward pass of one genome over a batch of inputs (vmap over axis 0).

    :argument batch_inputs: (batch_size, input_num)
    :argument N: maximum node count
    :argument input_idx: (input_num, )
    :argument output_idx: (output_num, )
    :argument cal_seqs: (N, )
    :argument nodes: (N, 5)
    :argument connections: (2, N, N)
    :return: (batch_size, output_num)
    """
    return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
                       pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
    """One input vector through every genome in the population (vmap over genomes).

    :argument inputs: (input_num, )
    :argument N: maximum node count
    :argument input_idx: (input_num, )
    :argument output_idx: (output_num, )
    :argument pop_cal_seqs: (pop_size, N)
    :argument pop_nodes: (pop_size, N, 5)
    :argument pop_connections: (pop_size, 2, N, N)
    :return: (pop_size, output_num)
    """
    return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
                      pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
    """A batch of inputs through every genome in the population.

    :argument batch_inputs: (batch_size, input_num)
    :argument N: maximum node count
    :argument input_idx: (input_num, )
    :argument output_idx: (output_num, )
    :argument pop_cal_seqs: (pop_size, N)
    :argument pop_nodes: (pop_size, N, 5)
    :argument pop_connections: (pop_size, 2, N, N)
    :return: (pop_size, batch_size, output_num)
    """
    return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
|
||||||
195
algorithms/neat/genome/genome.py
Normal file
195
algorithms/neat/genome/genome.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""
|
||||||
|
Vectorization of genome representation.
|
||||||
|
|
||||||
|
Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where:
|
||||||
|
|
||||||
|
1. N is a pre-set value that determines the maximum number of nodes in the network, and will increase if the genome becomes
|
||||||
|
too large to be represented by the current value of N.
|
||||||
|
2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function
|
||||||
|
(act), and aggregation function (agg).
|
||||||
|
3. connections is an array of shape (2, N, N), dtype=float, with the first axis representing weight and connection enabled
|
||||||
|
status.
|
||||||
|
Empty nodes or connections are represented using np.nan.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Tuple
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from jax import numpy as jnp
|
||||||
|
from jax import jit
|
||||||
|
from jax import Array
|
||||||
|
|
||||||
|
from algorithms.neat.genome.utils import fetch_first, fetch_last
|
||||||
|
|
||||||
|
# placeholder row written into a vacated node slot (all five fields NaN)
EMPTY_NODE = np.full(5, np.nan)
|
||||||
|
|
||||||
|
|
||||||
|
def create_initialize_function(config):
    """Bind configuration values into an `initialize_genomes` closure.

    Reads population size, node capacity, I/O counts and gene defaults
    from `config` and returns a zero-argument initializer.
    """
    pop_size = config.neat.population.pop_size
    N = config.basic.init_maximum_nodes
    num_inputs = config.basic.num_inputs
    num_outputs = config.basic.num_outputs
    default_bias = config.neat.gene.bias.init_mean
    default_response = config.neat.gene.response.init_mean
    # activation/aggregation defaults are currently fixed to index 0;
    # config.neat.gene.activation.default / aggregation.default are unused
    default_act = 0
    default_agg = 0
    default_weight = config.neat.gene.weight.init_mean

    return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
                   default_act, default_agg, default_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_genomes(pop_size: int,
                       N: int,
                       num_inputs: int, num_outputs: int,
                       default_bias: float = 0.0,
                       default_response: float = 1.0,
                       default_act: int = 0,
                       default_agg: int = 0,
                       default_weight: float = 1.0) \
        -> Tuple[NDArray, NDArray, NDArray, NDArray]:
    """
    Initialize genomes with default values.

    Args:
        pop_size (int): Number of genomes to initialize.
        N (int): Maximum number of nodes in the network.
        num_inputs (int): Number of input nodes.
        num_outputs (int): Number of output nodes.
        default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0.
        default_response (float, optional): Default response value for output nodes. Defaults to 1.0.
        default_act (int, optional): Default activation function index for output nodes. Defaults to 0.
        default_agg (int, optional): Default aggregation function index for output nodes. Defaults to 0.
        default_weight (float, optional): Default weight value for connections. Defaults to 1.0.

    Raises:
        AssertionError: If the sum of num_inputs, num_outputs, and 1 is greater than N.

    Returns:
        Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays.
    """
    # Reserve one row for potential mutation adding an extra node
    assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \
                                              f"{num_inputs} and output_size: {num_outputs}!"

    # empty slots are NaN everywhere
    pop_nodes = np.full((pop_size, N, 5), np.nan)
    pop_connections = np.full((pop_size, 2, N, N), np.nan)
    input_idx = np.arange(num_inputs)
    output_idx = np.arange(num_inputs, num_inputs + num_outputs)

    # node keys: inputs occupy the first rows, outputs the rows after them
    pop_nodes[:, input_idx, 0] = input_idx
    pop_nodes[:, output_idx, 0] = output_idx

    # only output nodes carry bias/response/act/agg attributes initially
    pop_nodes[:, output_idx, 1] = default_bias
    pop_nodes[:, output_idx, 2] = default_response
    pop_nodes[:, output_idx, 3] = default_act
    pop_nodes[:, output_idx, 4] = default_agg

    # fully connect every input to every output; vectorized over the index
    # grid instead of the original O(num_inputs * num_outputs) Python loop
    rows, cols = np.ix_(input_idx, output_idx)
    pop_connections[:, 0, rows, cols] = default_weight  # weights
    pop_connections[:, 1, rows, cols] = 1  # enabled flags

    return pop_nodes, pop_connections, input_idx, output_idx
|
||||||
|
|
||||||
|
|
||||||
|
def expand(pop_nodes: NDArray, pop_connections: NDArray, new_N: int) -> Tuple[NDArray, NDArray]:
    """Grow the genome arrays to hold `new_N` nodes.

    Existing data is copied into the top-left corner of the new arrays;
    all newly added slots are NaN (empty).

    :param pop_nodes: (pop_size, old_N, 5)
    :param pop_connections: (pop_size, 2, old_N, old_N)
    :param new_N: new maximum node count
    :return: (grown_nodes, grown_connections)
    """
    pop_size, old_N = pop_nodes.shape[0], pop_nodes.shape[1]

    grown_nodes = np.full((pop_size, new_N, 5), np.nan)
    grown_nodes[:, :old_N, :] = pop_nodes

    grown_connections = np.full((pop_size, 2, new_N, new_N), np.nan)
    grown_connections[:, :, :old_N, :old_N] = pop_connections
    return grown_nodes, grown_connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def add_node(new_node_key: int, nodes: Array, connections: Array,
             bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
    """Insert a node into the first empty (NaN-keyed) row of `nodes`.

    Connections are returned unchanged.
    """
    free_slot = fetch_first(jnp.isnan(nodes[:, 0]))
    nodes = nodes.at[free_slot].set(jnp.array([new_node_key, bias, response, act, agg]))
    return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def delete_node(node_key: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
    """
    Remove the node whose key equals ``node_key``.

    Only the node entry itself is removed; connections touching it are left
    for the caller to clean up.
    """
    position = fetch_first(nodes[:, 0] == node_key)
    return delete_node_by_idx(position, nodes, connections)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
    """
    delete a node from the genome. only delete the node, regardless of connections.

    Keeps the node table compact: instead of leaving a NaN hole at ``idx``,
    the last real (non-NaN) node is moved into that slot and its old slot is
    cleared with EMPTY_NODE.

    NOTE(review): only the node row is moved — the moved node's rows/columns
    in ``connections`` are NOT relocated to ``idx``. Confirm callers either
    clear those connections afterwards or do not rely on index alignment.
    """
    node_keys = nodes[:, 0]
    # move the last node to the deleted node's position
    last_real_idx = fetch_last(~jnp.isnan(node_keys))
    nodes = nodes.at[idx].set(nodes[last_real_idx])
    nodes = nodes.at[last_real_idx].set(EMPTY_NODE)
    return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array,
                   weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
    """
    Add a connection between two nodes identified by their keys.

    Keys are first translated to row indices in the node table; the
    index-based variant performs the actual write.
    """
    keys = nodes[:, 0]
    return add_connection_by_idx(fetch_first(keys == from_node),
                                 fetch_first(keys == to_node),
                                 nodes, connections, weight, enabled)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def add_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array,
                          weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
    """
    Write a connection gene at position (from_idx, to_idx).

    Channel 0 stores the weight and channel 1 the enabled flag; the node
    table is passed through unchanged.
    """
    new_gene = jnp.array([weight, enabled])
    updated = connections.at[:, from_idx, to_idx].set(new_gene)
    return nodes, updated
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def delete_connection(from_node: int, to_node: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
    """
    Remove the connection between two nodes identified by their keys.
    """
    keys = nodes[:, 0]
    return delete_connection_by_idx(fetch_first(keys == from_node),
                                    fetch_first(keys == to_node),
                                    nodes, connections)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def delete_connection_by_idx(from_idx: int, to_idx: int, nodes: Array, connections: Array) -> Tuple[Array, Array]:
    """
    Clear the connection gene at (from_idx, to_idx) by writing NaN into both
    the weight and the enabled channel. Nodes are returned unchanged.
    """
    cleared = connections.at[:, from_idx, to_idx].set(np.nan)
    return nodes, cleared
|
||||||
|
|
||||||
|
# if __name__ == '__main__':
|
||||||
|
# pop_nodes, pop_connections, input_keys, output_keys = initialize_genomes(100, 5, 2, 1)
|
||||||
|
# print(pop_nodes, pop_connections)
|
||||||
198
algorithms/neat/genome/graph.py
Normal file
198
algorithms/neat/genome/graph.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""
|
||||||
|
Some graph algorithms implemented in jax.
|
||||||
|
Only used in feed-forward networks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import jit, vmap, Array
|
||||||
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
# from .utils import fetch_first, I_INT
|
||||||
|
from algorithms.neat.genome.utils import fetch_first, I_INT
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def topological_sort(nodes: Array, connections: Array) -> Array:
    """
    a jit-able version of topological_sort! that's crazy!

    Implements Kahn's algorithm (repeatedly remove a node whose in-degree is
    zero) with fixed-iteration `lax.scan` so it can be traced by jit. Unused
    trailing slots of the result stay I_INT; NaN-keyed (empty) node slots get
    a NaN in-degree and are never selected.

    :param nodes: nodes array
    :param connections: connections array
    :return: topological sorted sequence

    Example:
        nodes = jnp.array([
            [0],
            [1],
            [2],
            [3]
        ])
        connections = jnp.array([
            [
                [0, 0, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]
            ],
            [
                [0, 0, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]
            ]
        ])

        topological_sort(nodes, connections) -> [0, 1, 2, 3]
    """
    # adjacency of enabled connections only (channel 1 holds the enabled flag)
    connections_enable = connections[1, :, :] == 1
    # in-degree per node; NaN marks empty node slots so they never match "== 0."
    in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
    res = jnp.full(in_degree.shape, I_INT)
    idx = 0

    def scan_body(carry, _):
        res_, idx_, in_degree_ = carry
        # first node with in-degree 0 (I_INT if none — graph exhausted or cyclic)
        i = fetch_first(in_degree_ == 0.)

        def hit():
            # add to res and flag it is already in it
            new_res = res_.at[idx_].set(i)
            new_idx = idx_ + 1
            # -1 marks "already emitted" so it never tests equal to 0 again
            new_in_degree = in_degree_.at[i].set(-1)

            # decrease in_degree of all its children
            children = connections_enable[i, :]
            new_in_degree = jnp.where(children, new_in_degree - 1, new_in_degree)
            return new_res, new_idx, new_in_degree

        def miss():
            # nothing selectable this step: keep carry unchanged
            return res_, idx_, in_degree_

        return jax.lax.cond(i == I_INT, miss, hit), None

    # one scan step per node slot bounds the number of removals
    scan_res, _ = jax.lax.scan(scan_body, (res, idx, in_degree), None, length=in_degree.shape[0])
    res, _, _ = scan_res

    return res
|
||||||
|
|
||||||
|
|
||||||
|
# @jit  (deliberately not jitted: uses python control flow for debugging)
def topological_sort_debug(nodes: Array, connections: Array) -> Array:
    """
    Plain-python reference implementation of `topological_sort`.

    Mirrors the jitted version step for step so outputs can be compared;
    the early `break` replaces the jitted version's `lax.cond` miss branch.
    """
    connections_enable = connections[1, :, :] == 1
    in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
    res = jnp.full(in_degree.shape, I_INT)
    idx = 0

    for _ in range(in_degree.shape[0]):
        i = fetch_first(in_degree == 0.)
        if i == I_INT:
            break
        res = res.at[idx].set(i)
        idx += 1
        in_degree = in_degree.at[i].set(-1)
        children = connections_enable[i, :]
        in_degree = jnp.where(children, in_degree - 1, in_degree)

    return res
|
||||||
|
|
||||||
|
|
||||||
|
@vmap
def batch_topological_sort(nodes: Array, connections: Array) -> Array:
    """
    batch version of topological_sort

    `vmap` maps `topological_sort` over the leading (population) axis of
    both arrays.

    :param nodes: batched node arrays (leading axis = population)
    :param connections: batched connection arrays (leading axis = population)
    :return: batched topological orders, one row per genome
    """
    return topological_sort(nodes, connections)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
    """
    Check whether a new connection (from_idx -> to_idx) will cause a cycle.

    Works by tentatively enabling the new edge, then propagating reachability
    from ``to_idx`` for N steps via repeated vector-matrix products; a cycle
    exists iff ``from_idx`` becomes reachable from ``to_idx``.

    :param nodes: JAX array
        The array of nodes.
    :param connections: JAX array
        The array of connections.
    :param from_idx: int
        The index of the starting node.
    :param to_idx: int
        The index of the ending node.
    :return: JAX array
        An array indicating if there is a cycle caused by the new connection.

    Example:
        nodes = jnp.array([
            [0],
            [1],
            [2],
            [3]
        ])
        connections = jnp.array([
            [
                [0, 0, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]
            ],
            [
                [0, 0, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]
            ]
        ])

        check_cycles(nodes, connections, 3, 2) -> True
        check_cycles(nodes, connections, 2, 3) -> False
        check_cycles(nodes, connections, 0, 3) -> False
        check_cycles(nodes, connections, 1, 0) -> False
    """
    # adjacency of enabled connections, with the candidate edge switched on
    connections_enable = connections[1, :, :] == 1
    connections_enable = connections_enable.at[from_idx, to_idx].set(True)
    # start the reachability frontier at the candidate edge's target
    nodes_visited = jnp.full(nodes.shape[0], False)
    nodes_visited = nodes_visited.at[to_idx].set(True)

    def scan_body(visited, _):
        # one hop: nodes reachable from any currently-visited node
        new_visited = jnp.dot(visited, connections_enable)
        new_visited = jnp.logical_or(visited, new_visited)
        return new_visited, None

    # N steps suffice: any simple path has at most N-1 edges
    nodes_visited, _ = jax.lax.scan(scan_body, nodes_visited, None, length=nodes_visited.shape[0])

    return nodes_visited[from_idx]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # ad-hoc smoke test: 4 real nodes plus one empty (NaN) slot
    nodes = jnp.array([
        [0],
        [1],
        [2],
        [3],
        [jnp.nan]
    ])
    connections = jnp.array([
        [
            [0, 0, 1, 0, jnp.nan],
            [0, 0, 1, 1, jnp.nan],
            [0, 0, 0, 1, jnp.nan],
            [0, 0, 0, 0, jnp.nan],
            [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
        ],
        [
            [0, 0, 1, 0, jnp.nan],
            [0, 0, 1, 1, jnp.nan],
            [0, 0, 0, 1, jnp.nan],
            [0, 0, 0, 0, jnp.nan],
            [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
        ]
    ]
    )

    # the python-loop reference and the jitted version should print the same order
    print(topological_sort_debug(nodes, connections))
    print(topological_sort(nodes, connections))

    # expected per the docstring example: True, False, False, False
    print(check_cycles(nodes, connections, 3, 2))
    print(check_cycles(nodes, connections, 2, 3))
    print(check_cycles(nodes, connections, 0, 3))
    print(check_cycles(nodes, connections, 1, 0))
|
||||||
538
algorithms/neat/genome/mutate.py
Normal file
538
algorithms/neat/genome/mutate.py
Normal file
@@ -0,0 +1,538 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import numpy as jnp
|
||||||
|
from jax import jit, vmap, Array
|
||||||
|
|
||||||
|
from .utils import fetch_random, fetch_first, I_INT
|
||||||
|
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx
|
||||||
|
from .graph import check_cycles
|
||||||
|
|
||||||
|
|
||||||
|
def create_mutate_function(config, input_keys, output_keys, batch: bool):
    """
    create mutate function for different situations

    Unpacks all NEAT hyper-parameters from ``config`` once and closes over
    them, returning a function of only the per-call arguments
    (rand_key(s), nodes, connections, new_node_key(s)).

    :param output_keys: keys of output nodes (never deleted / always valid targets)
    :param input_keys: keys of input nodes
    :param config: config object exposing ``config.neat.gene.*`` and ``config.neat.genome``
    :param batch: mutate for population or not; if True the returned function
                  is vmapped over the leading (population) axis of its 4 arguments
    :return: a mutate callable closed over all hyper-parameters
    """
    # bias gene hyper-parameters
    bias = config.neat.gene.bias
    bias_default = bias.init_mean
    bias_mean = bias.init_mean
    bias_std = bias.init_stdev
    bias_mutate_strength = bias.mutate_power
    bias_mutate_rate = bias.mutate_rate
    bias_replace_rate = bias.replace_rate

    # response gene hyper-parameters
    response = config.neat.gene.response
    response_default = response.init_mean
    response_mean = response.init_mean
    response_std = response.init_stdev
    response_mutate_strength = response.mutate_power
    response_mutate_rate = response.mutate_rate
    response_replace_rate = response.replace_rate

    # connection-weight gene hyper-parameters
    weight = config.neat.gene.weight
    weight_mean = weight.init_mean
    weight_std = weight.init_stdev
    weight_mutate_strength = weight.mutate_power
    weight_mutate_rate = weight.mutate_rate
    weight_replace_rate = weight.replace_rate

    # activation function: stored as an index into activation.options
    activation = config.neat.gene.activation
    # act_default = activation.default
    act_default = 0  # NOTE(review): hard-coded to 0 instead of activation.default — confirm intentional
    act_range = len(activation.options)
    act_replace_rate = activation.mutate_rate

    # aggregation function: stored as an index into aggregation.options
    aggregation = config.neat.gene.aggregation
    # agg_default = aggregation.default
    agg_default = 0  # NOTE(review): hard-coded to 0 instead of aggregation.default — confirm intentional
    agg_range = len(aggregation.options)
    agg_replace_rate = aggregation.mutate_rate

    enabled = config.neat.gene.enabled
    enabled_reverse_rate = enabled.mutate_rate

    # structural mutation probabilities
    genome = config.neat.genome
    add_node_rate = genome.node_add_prob
    delete_node_rate = genome.node_delete_prob
    add_connection_rate = genome.conn_add_prob
    delete_connection_rate = genome.conn_delete_prob
    single_structure_mutate = genome.single_structural_mutation

    if not batch:
        # single-genome variant: forward the 4 per-call args plus all closed-over params
        return lambda rand_key, nodes, connections, new_node_key: \
            mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
                   bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
                   bias_replace_rate, response_default, response_mean, response_std,
                   response_mutate_strength, response_mutate_rate, response_replace_rate,
                   weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
                   weight_replace_rate, act_default, act_range, act_replace_rate,
                   agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
                   add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
                   single_structure_mutate)
    else:
        # population variant: batch the first 4 args, broadcast the remaining 31
        batched_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, *(None,) * 31))
        return lambda rand_keys, pop_nodes, pop_connections, new_node_keys: \
            batched_mutate(rand_keys, pop_nodes, pop_connections, new_node_keys, input_keys, output_keys,
                           bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
                           bias_replace_rate, response_default, response_mean, response_std,
                           response_mutate_strength, response_mutate_rate, response_replace_rate,
                           weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
                           weight_replace_rate, act_default, act_range, act_replace_rate,
                           agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
                           add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
                           single_structure_mutate)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnames=["single_structure_mutate"])
def mutate(rand_key: Array,
           nodes: Array,
           connections: Array,
           new_node_key: int,
           input_keys: Array,
           output_keys: Array,
           bias_default: float = 0,
           bias_mean: float = 0,
           bias_std: float = 1,
           bias_mutate_strength: float = 0.5,
           bias_mutate_rate: float = 0.7,
           bias_replace_rate: float = 0.1,
           response_default: float = 1,
           response_mean: float = 1.,
           response_std: float = 0.,
           response_mutate_strength: float = 0.,
           response_mutate_rate: float = 0.,
           response_replace_rate: float = 0.,
           weight_mean: float = 0.,
           weight_std: float = 1.,
           weight_mutate_strength: float = 0.5,
           weight_mutate_rate: float = 0.7,
           weight_replace_rate: float = 0.1,
           act_default: int = 0,
           act_range: int = 5,
           act_replace_rate: float = 0.1,
           agg_default: int = 0,
           agg_range: int = 5,
           agg_replace_rate: float = 0.1,
           enabled_reverse_rate: float = 0.1,
           add_node_rate: float = 0.2,
           delete_node_rate: float = 0.2,
           add_connection_rate: float = 0.4,
           delete_connection_rate: float = 0.4,
           single_structure_mutate: bool = True):
    """
    Mutate one genome: at most one (or up to four) structural mutations,
    followed by value mutations on all node and connection genes.

    :param rand_key: JAX PRNG key consumed by this call
    :param nodes: (N, 5) node genes [key, bias, response, act, agg]; NaN key = empty slot
    :param connections: (2, N, N) connection genes; channel 0 = weight, channel 1 = enabled
    :param new_node_key: key to assign if an add-node mutation fires
    :param input_keys: keys of input nodes (protected from deletion)
    :param output_keys: keys of output nodes (protected from deletion)
    :param bias_default..bias_replace_rate: bias gene defaults and mutation hyper-parameters
    :param response_default..response_replace_rate: response gene hyper-parameters
    :param weight_mean..weight_replace_rate: connection-weight hyper-parameters
    :param act_default, act_range, act_replace_rate: activation-index hyper-parameters
    :param agg_default, agg_range, agg_replace_rate: aggregation-index hyper-parameters
    :param enabled_reverse_rate: probability of flipping a connection's enabled flag
    :param add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate:
        structural mutation probabilities
    :param single_structure_mutate: a genome is structurally mutated at most once
    :return: (nodes, connections) after mutation
    """

    # --- structural mutation branches (uniform signature for lax.switch) ---
    def nothing(rk, n, c):
        return n, c

    def m_add_node(rk, n, c):
        return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default)

    def m_delete_node(rk, n, c):
        return mutate_delete_node(rk, n, c, input_keys, output_keys)

    def m_add_connection(rk, n, c):
        return mutate_add_connection(rk, n, c, input_keys, output_keys)

    def m_delete_connection(rk, n, c):
        return mutate_delete_connection(rk, n, c)

    mutate_structure_li = [nothing, m_add_node, m_delete_node, m_add_connection, m_delete_connection]

    if single_structure_mutate:
        r1, r2, rand_key = jax.random.split(rand_key, 3)
        # normalize the four rates so they form (at most) a probability partition
        d = jnp.maximum(1, add_node_rate + delete_node_rate + add_connection_rate + delete_connection_rate)

        # shorten variable names for beauty
        anr, dnr = add_node_rate / d, delete_node_rate / d
        acr, dcr = add_connection_rate / d, delete_connection_rate / d

        # pick exactly one branch by where r falls in the cumulative intervals
        r = rand(r1)
        branch = 0
        branch = jnp.where(r <= anr, 1, branch)
        branch = jnp.where((anr < r) & (r <= anr + dnr), 2, branch)
        branch = jnp.where((anr + dnr < r) & (r <= anr + dnr + acr), 3, branch)
        # BUG FIX: the original `a < r & r <= b` parsed as `a < (r & r) <= b`
        # because `&` binds tighter than comparisons; each side must be parenthesized.
        branch = jnp.where(((anr + dnr + acr) < r) & (r <= anr + dnr + acr + dcr), 4, branch)
        # BUG FIX: pass operands separately; wrapping them in a tuple made
        # lax.switch call each 3-parameter branch with a single tuple argument.
        nodes, connections = jax.lax.switch(branch, mutate_structure_li, r2, nodes, connections)
    else:
        r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)

        # mutate add node (NOTE: rand(r_i) reuses the same key as the mutation itself,
        # so the acceptance draw is correlated with the mutation's randomness)
        aux_nodes, aux_connections = m_add_node(r1, nodes, connections)
        nodes = jnp.where(rand(r1) < add_node_rate, aux_nodes, nodes)
        connections = jnp.where(rand(r1) < add_node_rate, aux_connections, connections)

        # mutate delete node
        aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
        nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
        connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)

        # mutate add connection
        aux_nodes, aux_connections = m_add_connection(r3, nodes, connections)
        nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
        connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)

        # mutate delete connection
        aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
        nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
        connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)

    # value mutation always runs, after any structural change
    nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
                                       bias_mutate_rate, bias_replace_rate, response_mean, response_std,
                                       response_mutate_strength, response_mutate_rate, response_replace_rate,
                                       weight_mean, weight_std, weight_mutate_strength,
                                       weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range,
                                       agg_replace_rate, enabled_reverse_rate)

    return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def mutate_values(rand_key: Array,
                  nodes: Array,
                  connections: Array,
                  bias_mean: float = 0,
                  bias_std: float = 1,
                  bias_mutate_strength: float = 0.5,
                  bias_mutate_rate: float = 0.7,
                  bias_replace_rate: float = 0.1,
                  response_mean: float = 1.,
                  response_std: float = 0.,
                  response_mutate_strength: float = 0.,
                  response_mutate_rate: float = 0.,
                  response_replace_rate: float = 0.,
                  weight_mean: float = 0.,
                  weight_std: float = 1.,
                  weight_mutate_strength: float = 0.5,
                  weight_mutate_rate: float = 0.7,
                  weight_replace_rate: float = 0.1,
                  act_range: int = 5,
                  act_replace_rate: float = 0.1,
                  agg_range: int = 5,
                  agg_replace_rate: float = 0.1,
                  enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]:
    """
    Mutate the value genes of a single genome.

    Float genes (bias, response, weight) are perturbed/replaced via
    `mutate_float_values`; integer genes (activation, aggregation indices)
    via `mutate_int_values`; each connection's enabled flag is flipped with
    probability ``enabled_reverse_rate``. NaN entries (empty slots) remain NaN.

    Args:
        rand_key: A random key for generating random values.
        nodes: (N, 5) node genes [key, bias, response, act, agg].
        connections: (2, N, N) connection genes [weight, enabled].
        bias_* / response_* / weight_*: mean, std, mutation strength,
            mutation rate and replacement rate of the respective float gene.
        act_range / act_replace_rate: number of activation options and
            replacement probability of the activation index.
        agg_range / agg_replace_rate: same for the aggregation index.
        enabled_reverse_rate: probability of flipping an enabled flag.

    Returns:
        A tuple containing mutated nodes and connections.
    """
    k1, k2, k3, k4, k5, rand_key = jax.random.split(rand_key, num=6)

    # float genes: column 1 = bias, column 2 = response, channel 0 = weight
    new_bias = mutate_float_values(k1, nodes[:, 1], bias_mean, bias_std,
                                   bias_mutate_strength, bias_mutate_rate, bias_replace_rate)
    new_response = mutate_float_values(k2, nodes[:, 2], response_mean, response_std,
                                       response_mutate_strength, response_mutate_rate, response_replace_rate)
    new_weight = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std,
                                     weight_mutate_strength, weight_mutate_rate, weight_replace_rate)

    # integer genes: column 3 = activation index, column 4 = aggregation index
    new_act = mutate_int_values(k4, nodes[:, 3], act_range, act_replace_rate)
    new_agg = mutate_int_values(k5, nodes[:, 4], agg_range, agg_replace_rate)

    # enabled flags: flip with probability enabled_reverse_rate, keep NaN
    # wherever no connection exists (weight channel is NaN)
    flip_draw = jax.random.uniform(rand_key, connections[1, :, :].shape)
    new_enabled = connections[1, :, :] == 1
    new_enabled = jnp.where(flip_draw < enabled_reverse_rate, ~new_enabled, new_enabled)
    new_enabled = jnp.where(~jnp.isnan(connections[0, :, :]), new_enabled, jnp.nan)

    nodes = nodes.at[:, 1].set(new_bias)
    nodes = nodes.at[:, 2].set(new_response)
    nodes = nodes.at[:, 3].set(new_act)
    nodes = nodes.at[:, 4].set(new_agg)
    connections = connections.at[0, :, :].set(new_weight)
    connections = connections.at[1, :, :].set(new_enabled)
    return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
                        mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
    """
    Perturb or replace the float genes in ``old_vals``.

    For each entry a uniform draw ``r`` decides the outcome:
      * ``r < mutate_rate`` — add Gaussian noise scaled by ``mutate_strength``;
      * ``mutate_rate < r < mutate_rate + replace_rate`` — redraw from N(mean, std);
      * otherwise the value is kept unchanged.
    NaN entries (empty slots) always stay NaN.
    """
    k1, k2, k3, rand_key = jax.random.split(rand_key, num=4)
    perturbation = jax.random.normal(k1, old_vals.shape) * mutate_strength
    redraw = jax.random.normal(k2, old_vals.shape) * std + mean
    r = jax.random.uniform(k3, old_vals.shape)

    mutated = jnp.where(r < mutate_rate, old_vals + perturbation, old_vals)
    in_replace_window = jnp.logical_and(mutate_rate < r, r < mutate_rate + replace_rate)
    mutated = jnp.where(in_replace_window, redraw, mutated)
    return jnp.where(~jnp.isnan(old_vals), mutated, jnp.nan)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def mutate_int_values(rand_key: Array, old_vals: Array, range: int, replace_rate: float) -> Array:
    """
    Randomly replace integer genes (activation / aggregation indices).

    Each entry is replaced with a uniform draw from [0, range) with
    probability ``replace_rate``; NaN entries (empty slots) stay NaN.
    """
    k1, k2, rand_key = jax.random.split(rand_key, num=3)
    candidates = jax.random.randint(k1, old_vals.shape, 0, range)
    r = jax.random.uniform(k2, old_vals.shape)

    mutated = jnp.where(r < replace_rate, candidates, old_vals)
    return jnp.where(~jnp.isnan(old_vals), mutated, jnp.nan)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
                    default_bias: float = 0, default_response: float = 1,
                    default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
    """
    Randomly add a new node by splitting an existing connection.

    The chosen connection is disabled and replaced by two new ones routed
    through the new node.

    :param rand_key: PRNG key used to pick the connection to split
    :param new_node_key: key assigned to the inserted node
    :param nodes: (N, 5) node genes
    :param connections: (2, N, N) connection genes
    :param default_bias: bias of the new node
    :param default_response: response of the new node
    :param default_act: activation index of the new node
    :param default_agg: aggregation index of the new node
    :return: (nodes, connections) with the node inserted
    """
    # randomly choose a connection to split
    from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)

    # remember the split connection's weight before anything is modified
    weight = connections[0, from_idx, to_idx]

    # disable the original connection
    connections = connections.at[1, from_idx, to_idx].set(False)

    # add the new node
    nodes, connections = add_node(new_node_key, nodes, connections,
                                  bias=default_bias, response=default_response, act=default_act, agg=default_agg)
    new_idx = fetch_first(nodes[:, 0] == new_node_key)

    # BUG FIX: the in-connection previously got weight 0, which silenced the
    # split path entirely. Standard NEAT gives the in-connection weight 1 and
    # the out-connection the original weight, approximately preserving behavior.
    nodes, connections = add_connection_by_idx(from_idx, new_idx, nodes, connections, weight=1, enabled=True)
    nodes, connections = add_connection_by_idx(new_idx, to_idx, nodes, connections, weight=weight, enabled=True)

    return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
                       input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
    """
    Randomly delete a node. Input and output nodes are not allowed to be deleted.

    If no deletable node exists, `choice_node_key` returns a NaN key and the
    genome is returned unchanged.

    :param rand_key: PRNG key used to pick the node
    :param nodes: (N, 5) node genes
    :param connections: (2, N, N) connection genes
    :param input_keys: keys of input nodes (excluded from deletion)
    :param output_keys: keys of output nodes (excluded from deletion)
    :return: (nodes, connections) after the deletion (or unchanged)
    """
    # randomly choose a node (inputs/outputs excluded)
    node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
                                         allow_input_keys=False, allow_output_keys=False)

    # delete the node (compacts the table by moving the last real node into node_idx)
    aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)

    # delete all connections touching the deleted slot
    # NOTE(review): delete_node_by_idx moved another node into node_idx, but that
    # node's connection rows/columns are not relocated here — verify intended.
    aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
    aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)

    # check node_key valid
    nodes = jnp.where(jnp.isnan(node_key), nodes, aux_nodes)  # if node_key is nan, do not delete the node
    connections = jnp.where(jnp.isnan(node_key), connections, aux_connections)

    return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
                          input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
    """
    Randomly add a new connection. The output node is not allowed to be an input node. If in feedforward networks,
    cycles are not allowed.

    Outcomes: if the chosen pair already has a connection, it is re-enabled;
    if adding it would create a cycle, nothing changes; otherwise a fresh
    connection is written (with the default weight of add_connection_by_idx).

    :param rand_key: PRNG key used to pick the two endpoints
    :param nodes: (N, 5) node genes
    :param connections: (2, N, N) connection genes
    :param input_keys: keys of input nodes (invalid as connection targets)
    :param output_keys: keys of output nodes
    :return: (nodes, connections) after the attempted insertion
    """
    # randomly choose two nodes: any node as source, non-input nodes as target
    k1, k2 = jax.random.split(rand_key, num=2)
    from_key, from_idx = choice_node_key(k1, nodes, input_keys, output_keys,
                                         allow_input_keys=True, allow_output_keys=True)
    to_key, to_idx = choice_node_key(k2, nodes, input_keys, output_keys,
                                     allow_input_keys=False, allow_output_keys=True)

    def successful():
        # no existing connection and no cycle: write a fresh connection gene
        new_nodes, new_connections = add_connection_by_idx(from_idx, to_idx, nodes, connections)
        return new_nodes, new_connections

    def already_exist():
        # connection exists: just make sure it is enabled
        new_connections = connections.at[1, from_idx, to_idx].set(True)
        return nodes, new_connections

    def cycle():
        # adding the edge would create a cycle: leave the genome unchanged
        return nodes, connections

    is_already_exist = ~jnp.isnan(connections[0, from_idx, to_idx])
    is_cycle = check_cycles(nodes, connections, from_idx, to_idx)

    # branch index matches list order: 0=already_exist, 1=cycle, 2=successful
    choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
    nodes, connections = jax.lax.switch(choice, [already_exist, cycle, successful])
    return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
    """
    Remove one randomly chosen existing connection from the genome.

    :param rand_key: PRNG key used to pick the connection
    :param nodes: (N, 5) node genes (returned unchanged)
    :param connections: (2, N, N) connection genes
    :return: (nodes, connections) with the chosen connection cleared
    """
    _, _, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
    return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array,
                    input_keys: Array, output_keys: Array,
                    allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
    """
    Uniformly pick one node, optionally excluding input and/or output nodes.

    :param rand_key: PRNG key for the draw
    :param nodes: (N, 5) node genes; NaN key marks an empty slot
    :param input_keys: keys of input nodes
    :param output_keys: keys of output nodes
    :param allow_input_keys: whether input nodes are eligible
    :param allow_output_keys: whether output nodes are eligible
    :return: (key, idx); key is NaN and idx is I_INT when no node is eligible
    """
    keys = nodes[:, 0]
    eligible = ~jnp.isnan(keys)

    if not allow_input_keys:
        eligible = jnp.logical_and(eligible, ~jnp.isin(keys, input_keys))
    if not allow_output_keys:
        eligible = jnp.logical_and(eligible, ~jnp.isin(keys, output_keys))

    idx = fetch_random(rand_key, eligible)
    chosen_key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
    return chosen_key, idx
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
    """
    Randomly pick one existing connection of the genome.

    :param rand_key: jax random key
    :param nodes: nodes array; column 0 holds node keys
    :param connection: (2, N, N) connections array; nan weight marks "no connection"
    :return: from_key, to_key, from_idx, to_idx
    """
    row_rand_key, col_rand_key = jax.random.split(rand_key, num=2)

    # first draw a source row that owns at least one connection
    row_has_connection = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1)
    from_idx = fetch_random(row_rand_key, row_has_connection)

    # then draw a target among that row's existing connections
    weights_row = connection[0, from_idx, :]
    to_idx = fetch_random(col_rand_key, ~jnp.isnan(weights_row))

    from_key = nodes[from_idx, 0]
    to_key = nodes[to_idx, 0]
    return from_key, to_key, from_idx, to_idx
|
||||||
|
|
||||||
|
@jit
def rand(rand_key):
    """Draw a single uniform sample from [0, 1)."""
    return jax.random.uniform(rand_key, shape=())
||||||
134
algorithms/neat/genome/utils.py
Normal file
134
algorithms/neat/genome/utils.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
from functools import partial
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import numpy as jnp, Array
|
||||||
|
from jax import jit
|
||||||
|
|
||||||
|
I_INT = jnp.iinfo(jnp.int32).max # infinite int
|
||||||
|
|
||||||
|
|
||||||
|
@jit
def flatten_connections(keys, connections):
    """
    flatten the (2, N, N) connections to (N * N, 4)

    :param keys: (N,) node keys
    :param connections: (2, N, N) array; channel 0 = weights, channel 1 = enabled flags
    :return:
        the first two columns are the index of the node
        the 3rd column is the weight, and the 4th column is the enabled status
    """
    from_keys, to_keys = jnp.meshgrid(keys, keys, indexing='ij')
    key_pairs = jnp.stack((from_keys, to_keys), axis=-1).reshape(-1, 2)

    # (2, N, N) -> (N, N, 2) -> (N * N, 2)
    attrs = jnp.reshape(jnp.transpose(connections, (1, 2, 0)), (-1, 2))

    return jnp.concatenate((key_pairs, attrs), axis=1)
|
||||||
|
|
||||||
|
@partial(jit, static_argnames=['N'])
def unflatten_connections(N, cons):
    """
    restore the (N * N, 4) connections to (2, N, N)

    :param N: node-slot count (static for jit)
    :param cons: (N * N, 4) table produced by flatten_connections
    :return: (2, N, N) array; channel 0 = weights, channel 1 = enabled flags
    """
    # drop the two key columns, keep (weight, enabled)
    attrs = cons[:, 2:]
    return jnp.moveaxis(attrs.reshape(N, N, 2), -1, 0)
||||||
|
|
||||||
|
|
||||||
|
@jit
def set_operation_analysis(ar1: Array, ar2: Array) -> Tuple[Array, Array, Array]:
    """
    Analyze the intersection and union of two row sets.

    Rows containing nan are treated as padding and excluded from both masks.

    :param ar1: JAX array of shape (N, M); same shape as ar2
    :param ar2: JAX array of shape (N, M); same shape as ar1
    :return: tuple of 3 arrays
        - sorted_indices: indices that sort the concatenation of ar1 and ar2
          lexicographically by row (stable, so ties keep input order)
        - intersect_mask: True at sorted positions whose row also occurs in the
          other array (one True per duplicated pair)
        - union_mask: True at sorted positions holding the last copy of each
          distinct non-nan row

    Example:
        a = jnp.array([[1, 2], [3, 4], [5, 6]])
        b = jnp.array([[1, 2], [7, 8], [9, 10]])
        _, intersect_mask, union_mask = set_operation_analysis(a, b)
        intersect_mask -> [True, False, False, False, False, False]
        union_mask     -> [False, True, True, True, True, True]
    """
    combined = jnp.concatenate((ar1, ar2), axis=0)
    order = jnp.lexsort(combined.T[::-1])

    # append a nan sentinel row so each row can be compared with its successor
    padded = jnp.concatenate((combined[order], jnp.full((1, ar1.shape[1]), jnp.nan)), axis=0)
    has_nan = jnp.any(jnp.isnan(padded), axis=1)

    cur, nxt = padded[:-1], padded[1:]  # each sorted row and its successor
    intersect_mask = jnp.all(cur == nxt, axis=1) & ~has_nan[:-1]
    union_mask = jnp.any(cur != nxt, axis=1) & ~has_nan[:-1]
    return order, intersect_mask, union_mask
||||||
|
|
||||||
|
|
||||||
|
@jit
def fetch_first(mask, default=I_INT) -> Array:
    """
    fetch the first True index

    :param mask: array of bool
    :param default: value returned when no element satisfies the condition
    :return: index of the first True element, or ``default`` when there is none

    example:
    >>> a = jnp.array([1, 2, 3, 4, 5])
    >>> fetch_first(a > 3)
    3
    >>> fetch_first(a > 30)
    I_INT
    """
    candidate = jnp.argmax(mask)
    # argmax yields 0 on an all-False mask, so confirm the hit is genuine
    return jnp.where(mask[candidate], candidate, default)
||||||
|
|
||||||
|
|
||||||
|
@jit
def fetch_last(mask, default=I_INT) -> Array:
    """
    similar to fetch_first, but fetch the last True index

    :param mask: array of bool
    :param default: value returned when no element is True
    :return: index of the last True element, or ``default`` when there is none
    """
    reversed_idx = fetch_first(mask[::-1], default)
    # Bug fix: fetch_first signals "not found" with `default` (I_INT), not -1.
    # The old comparison against -1 let the sentinel fall through and produced
    # a bogus index (mask.shape[0] - I_INT - 1).
    return jnp.where(reversed_idx == default, default, mask.shape[0] - reversed_idx - 1)
||||||
|
|
||||||
|
|
||||||
|
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
    """
    similar to fetch_first, but fetch a random True index

    :param rand_key: jax random key
    :param mask: array of bool
    :param default: value returned when no element is True
    :return: index of a uniformly random True element, or ``default``
    """
    true_cnt = jnp.sum(mask)
    # choose the k-th True element, with k uniform in [1, true_cnt]
    target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
    running = jnp.cumsum(mask)
    return fetch_first(running >= target, default)
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # smoke test for the fetch_* helpers
    a = jnp.array([1, 2, 3, 4, 5])
    print(fetch_first(a > 3))
    print(fetch_first(a > 30))

    print(fetch_last(a > 3))
    # NOTE(review): fetch_last compares against -1 but fetch_first signals
    # "not found" with I_INT, so this prints a bogus index — confirm.
    print(fetch_last(a > 30))

    # repeated random draws from an all-True mask
    rand_key = jax.random.PRNGKey(0)
    for _ in range(100):
        rand_key, _ = jax.random.split(rand_key)
        print(fetch_random(rand_key, a > 0))
||||||
41
algorithms/neat/pipeline.py
Normal file
41
algorithms/neat/pipeline.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import jax
|
||||||
|
|
||||||
|
from .species import SpeciesController
|
||||||
|
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
    """
    Neat algorithm pipeline.

    Drives the ask/tell loop: ``ask`` hands the environment a forward
    function for the current population, ``tell`` feeds the resulting
    fitnesses back into the speciation bookkeeping.
    """

    def __init__(self, config):
        """
        :param config: global config object; reads config.basic.init_maximum_nodes
        """
        self.config = config
        self.N = config.basic.init_maximum_nodes  # node-slot count per genome

        self.species_controller = SpeciesController(config)
        self.initialize_func = create_initialize_function(config)
        self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
        self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)

        self.generation = 0

        # assign every genome of the initial population to a species
        self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)

    def ask(self, batch: bool):
        """
        Create a forward function for the population.

        :param batch: whether the returned function processes a batch of inputs at once
        :return:
            Algorithm gives the population a forward function, then environment gives back the fitnesses.
        """
        func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx,
                                       batch=batch)
        return func

    def tell(self, fitnesses):
        """
        Receive the fitnesses of the current population and advance one generation.

        :param fitnesses: per-genome fitness values, indexed like pop_nodes
        """
        self.generation += 1
        # bug fix: removed a leftover debug print of the raw fitnesses
        self.species_controller.update_species_fitnesses(fitnesses)
||||||
|
|
||||||
|
|
||||||
190
algorithms/neat/species.py
Normal file
190
algorithms/neat/species.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
from typing import List, Tuple, Dict
|
||||||
|
from itertools import count
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from .genome import distance
|
||||||
|
|
||||||
|
|
||||||
|
class Species:
    """One species: a representative genome plus the indices of its members."""

    def __init__(self, key, generation):
        """
        :param key: unique species id
        :param generation: generation in which the species was created
        """
        self.key = key
        self.created = generation
        self.last_improved = generation
        # representative genome as (nodes, connections) arrays
        self.representative: Tuple[NDArray, NDArray] = (None, None)  # (nodes, connections)
        # member positions within pop_nodes / pop_connections
        self.members: List[int] = []  # idx in pop_nodes, pop_connections
        self.fitness = None
        self.member_fitnesses = None
        self.adjusted_fitness = None
        self.fitness_history: List[float] = []

    def update(self, representative, members):
        """Replace the representative genome and the member index list."""
        self.representative = representative
        self.members = members

    def get_fitnesses(self, fitnesses):
        """Select this species' members' fitnesses from the population-wide list."""
        return [fitnesses[member_idx] for member_idx in self.members]
|
|
||||||
|
|
||||||
|
class SpeciesController:
    """
    A class to control the species

    Maintains the id -> Species mapping, re-partitions the population each
    generation (speciate), and tracks per-species fitness / stagnation.
    """

    def __init__(self, config):
        self.config = config
        self.compatibility_threshold = self.config.neat.species.compatibility_threshold
        self.species_elitism = self.config.neat.species.species_elitism
        self.max_stagnation = self.config.neat.species.max_stagnation

        # monotonically increasing id source for newly created species
        self.species_idxer = count(0)
        self.species: Dict[int, Species] = {}  # species_id -> species
        self.genome_to_species: Dict[int, int] = {}  # genome idx -> species_id

        # distance from one representative genome to every genome of a population
        self.o2m_distance_func = jax.vmap(distance, in_axes=(None, None, 0, 0))  # one to many
        # self.o2o_distance_func = np_distance  # one to one
        self.o2o_distance_func = distance

    def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
        """
        Partition the population into species by genetic distance.

        :param pop_nodes: nodes of the whole population
        :param pop_connections: connections of the whole population
        :param generation: use to flag the created time of new species
        :return:
        """
        unspeciated = np.full((pop_nodes.shape[0],), True, dtype=bool)
        previous_species_list = list(self.species.keys())

        # Find the best representatives for each existing species.
        new_representatives = {}  # species_id -> idx of its new representative genome
        new_members = {}  # species_id -> list of member genome idx

        for sid, species in self.species.items():
            # calculate the distance between the representative and the population
            r_nodes, r_connections = species.representative
            distances = self.o2m_distance_func(r_nodes, r_connections, pop_nodes, pop_connections)
            distances = jax.device_get(distances)  # fetch the data from gpu
            min_idx = find_min_with_mask(distances, unspeciated)  # find the min un-specified distance

            # the closest still-unassigned genome becomes the new representative
            new_representatives[sid] = min_idx
            new_members[sid] = [min_idx]
            unspeciated[min_idx] = False

        # Partition population into species based on genetic similarity.

        # First, fast match the population to previous species
        rid_list = [new_representatives[sid] for sid in previous_species_list]
        res_pop_distance = [
            jax.device_get(
                [
                    self.o2m_distance_func(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
                    for rid in rid_list
                ]
            )
        ]
        # rows index genomes, columns index the previous species' representatives
        # (the extra wrapping list above adds a trailing singleton axis, which
        # argmin below flattens away)
        pop_res_distance = np.stack(res_pop_distance, axis=0).T
        for i in range(pop_res_distance.shape[0]):
            if not unspeciated[i]:
                continue
            # attach the genome to its closest previous species when compatible
            min_idx = np.argmin(pop_res_distance[i])
            min_val = pop_res_distance[i, min_idx]
            if min_val <= self.compatibility_threshold:
                species_id = previous_species_list[min_idx]
                new_members[species_id].append(i)
                unspeciated[i] = False

        # Second, slowly match the lonely population to new-created species.
        # lonely genome is proved to be not compatible with any previous species, so they only need to be compared with
        # the new representatives.
        new_species_list = []
        for i in range(pop_nodes.shape[0]):
            if not unspeciated[i]:
                continue
            unspeciated[i] = False
            # Bug fix: compare only against the representatives of species created
            # in this pass. The previous code iterated over ALL of
            # new_representatives (previous species included) but then indexed
            # new_species_list with the resulting argmin, misaligning species
            # ids (or raising IndexError) once previous and new species coexist.
            if len(new_species_list) != 0:
                rid = [new_representatives[sid] for sid in new_species_list]  # the representatives of new species
                distances = [
                    self.o2o_distance_func(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
                    for r in rid
                ]
                distances = np.array(distances)
                min_idx = np.argmin(distances)
                min_val = distances[min_idx]
                if min_val <= self.compatibility_threshold:
                    species_id = new_species_list[min_idx]
                    new_members[species_id].append(i)
                    continue
            # create a new species
            species_id = next(self.species_idxer)
            new_species_list.append(species_id)
            new_representatives[species_id] = i
            new_members[species_id] = [i]

        assert np.all(~unspeciated)
        # Update species collection based on new speciation.
        self.genome_to_species = {}
        for sid, rid in new_representatives.items():
            s = self.species.get(sid)
            if s is None:
                s = Species(sid, generation)
                self.species[sid] = s

            members = new_members[sid]
            for gid in members:
                self.genome_to_species[gid] = sid

            s.update((pop_nodes[rid], pop_connections[rid]), members)

    def update_species_fitnesses(self, fitnesses):
        """
        update the fitness of each species

        :param fitnesses: per-genome fitness values, indexed like pop_nodes
        :return:
        """
        for sid, s in self.species.items():
            # TODO: here use mean to measure the fitness of a species, but it may be other functions
            s.member_fitnesses = s.get_fitnesses(fitnesses)
            s.fitness = np.mean(s.member_fitnesses)
            s.fitness_history.append(s.fitness)
            s.adjusted_fitness = None

    def stagnation(self, generation):
        """
        code modified from neat-python!

        :param generation:
        :return: list of (species_id, species, is_stagnant) tuples
        """
        species_data = []
        for sid, s in self.species.items():
            # best fitness seen before this generation
            if s.fitness_history:
                prev_fitness = max(s.fitness_history)
            else:
                prev_fitness = float('-inf')

            # prev_fitness is never None here; the comparison marks improvement
            if prev_fitness is None or s.fitness > prev_fitness:
                s.last_improved = generation

            species_data.append((sid, s))

        # Sort in descending fitness order.
        species_data.sort(key=lambda x: x[1].fitness, reverse=True)

        result = []
        for idx, (sid, s) in enumerate(species_data):

            if idx < self.species_elitism:  # elitism species never stagnate!
                is_stagnant = False
            else:
                stagnant_time = generation - s.last_improved
                is_stagnant = stagnant_time > self.max_stagnation

            result.append((sid, s, is_stagnant))
        return result
||||||
|
|
||||||
|
|
||||||
|
def find_min_with_mask(arr: NDArray, mask: NDArray) -> int:
    """
    Return the index of the smallest value of ``arr`` among positions where
    ``mask`` is True.
    """
    # positions excluded by the mask can never win the argmin
    candidates = np.where(mask, arr, np.inf)
    return np.argmin(candidates)
||||||
62
algorithms/neat/stagnation.py
Normal file
62
algorithms/neat/stagnation.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
Code modified from NEAT-Python library
|
||||||
|
Keeps track of whether species are making progress and helps remove those which are not.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Stagnation:
    """Keeps track of whether species are making progress and helps remove ones that are not."""

    def __init__(self, config):
        # config.stagnation.species_elitism and config.stagnation.max_stagnation
        # are read in update()
        self.config = config

    def update(self, species_set, generation):
        """
        Required interface method. Updates species fitness history information,
        checking for ones that have not improved in max_stagnation generations,
        and - unless it would result in the number of species dropping below the configured
        species_elitism parameter if they were removed,
        in which case the highest-fitness species are spared -
        returns a list with stagnant species marked for removal.

        :param species_set: object exposing a ``species`` dict of id -> species
        :param generation: current generation number
        :return: list of (species_id, species, is_stagnant) tuples
        """
        species_data = []
        for sid, s in species_set.species.items():
            # best fitness the species has ever reached (before this update)
            if s.fitness_history:
                prev_fitness = max(s.fitness_history)
            else:
                prev_fitness = float('-inf')

            # NOTE(review): get_fitnesses() is called with no argument here, but
            # Species.get_fitnesses in this repo takes a `fitnesses` parameter —
            # confirm which Species implementation this legacy class targets.
            s.fitness = max(s.get_fitnesses())
            s.fitness_history.append(s.fitness)
            s.adjusted_fitness = None
            # prev_fitness is never None here; the comparison marks improvement
            if prev_fitness is None or s.fitness > prev_fitness:
                s.last_improved = generation

            species_data.append((sid, s))

        # Sort in ascending fitness order.
        species_data.sort(key=lambda x: x[1].fitness)

        result = []
        species_fitnesses = []
        num_non_stagnant = len(species_data)
        for idx, (sid, s) in enumerate(species_data):
            # Override stagnant state if marking this species as stagnant would
            # result in the total number of species dropping below the limit.
            # Because species are in ascending fitness order, less fit species
            # will be marked as stagnant first.
            stagnant_time = generation - s.last_improved
            is_stagnant = False
            if num_non_stagnant > self.config.stagnation.species_elitism:
                is_stagnant = stagnant_time >= self.config.stagnation.max_stagnation

            if (len(species_data) - idx) <= self.config.stagnation.species_elitism:
                is_stagnant = False

            if is_stagnant:
                num_non_stagnant -= 1

            result.append((sid, s, is_stagnant))
            species_fitnesses.append(s.fitness)

        return result
||||||
5
algorithms/numpy/__init__.py
Normal file
5
algorithms/numpy/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""
|
||||||
|
numpy version of functions in genome
|
||||||
|
"""
|
||||||
|
from .distance import distance
|
||||||
|
from .utils import *
|
||||||
58
algorithms/numpy/distance.py
Normal file
58
algorithms/numpy/distance.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .utils import flatten_connections, set_operation_analysis
|
||||||
|
|
||||||
|
|
||||||
|
def distance(nodes1, connections1, nodes2, connections2):
    """
    Total genomic distance between two genomes: node-gene distance plus
    connection-gene distance.
    """
    d_nodes = gene_distance(nodes1, nodes2, 'node')

    # connection genes are identified by (from_key, to_key) pairs, so flatten
    # the (2, N, N) tensors into per-gene rows first
    cons1 = flatten_connections(nodes1[:, 0], connections1)
    cons2 = flatten_connections(nodes2[:, 0], connections2)
    d_cons = gene_distance(cons1, cons2, 'connection')

    return d_nodes + d_cons
||||||
|
|
||||||
|
|
||||||
|
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
    """
    NEAT-style distance between two gene tables of the same kind.

    Homologous genes (same key, present in both tables) contribute their
    attribute distance scaled by ``compatibility_coe``; non-homologous genes
    contribute ``disjoint_coe`` each. The sum is normalized by the larger
    non-nan gene count of the two tables.

    :param ar1: gene table; rows are genes, nan rows are padding
    :param ar2: gene table of the same kind as ar1
    :param gene_type: 'node' (key = column 0) or 'connection' (key = columns 0-1)
    :param compatibility_coe: weight of homologous attribute differences
    :param disjoint_coe: weight of each disjoint/excess gene
    :return: scalar distance
    """
    # key columns identify a gene: 1 column for nodes, 2 for connections
    if gene_type == 'node':
        keys1, keys2 = ar1[:, :1], ar2[:, :1]
    else:  # connection
        keys1, keys2 = ar1[:, :2], ar2[:, :2]

    n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
    nodes = np.concatenate((ar1, ar2), axis=0)
    # sort so that genes sharing a key become adjacent rows
    sorted_nodes = nodes[n_sorted_indices]
    fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]

    # genes present in only one table = union minus intersection
    non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask)
    if gene_type == 'node':
        node_distance = homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
    else:  # connection
        node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)

    # pairs involving nan padding produce nan distances; zero them out, then
    # keep only the adjacent pairs that are true homologous matches
    node_distance = np.where(np.isnan(node_distance), 0, node_distance)
    homologous_distance = np.sum(node_distance * n_intersect_mask[:-1])

    # count real (fully non-nan) genes in each table for normalization
    gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1))
    gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1))

    val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
    return val / np.where(gene_cnt1 > gene_cnt2, gene_cnt1, gene_cnt2)
||||||
|
|
||||||
|
|
||||||
|
def homologous_node_distance(n1, n2):
    """
    Per-gene distance between aligned (homologous) node genes.

    Columns: 1 = bias, 2 = response, 3 = activation; column 4 is compared for
    inequality as well (presumably aggregation — confirm against genome layout).
    """
    bias_diff = np.abs(n1[:, 1] - n2[:, 1])
    response_diff = np.abs(n1[:, 2] - n2[:, 2])
    act_diff = n1[:, 3] != n2[:, 3]
    col4_diff = n1[:, 4] != n2[:, 4]
    return bias_diff + response_diff + act_diff + col4_diff
||||||
|
|
||||||
|
|
||||||
|
def homologous_connection_distance(c1, c2):
    """
    Per-gene distance between aligned (homologous) connection genes.

    Columns: 2 = weight, 3 = enabled flag.
    """
    weight_diff = np.abs(c1[:, 2] - c2[:, 2])
    enabled_diff = c1[:, 3] != c2[:, 3]
    return weight_diff + enabled_diff
||||||
55
algorithms/numpy/utils.py
Normal file
55
algorithms/numpy/utils.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
I_INT = np.iinfo(np.int32).max # infinite int
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_connections(keys, connections):
    """
    Flatten a (2, N, N) connections tensor into an (N * N, 4) table whose rows
    are (from_key, to_key, weight, enabled). numpy twin of the jax version.
    """
    from_keys, to_keys = np.meshgrid(keys, keys, indexing='ij')
    key_pairs = np.stack((from_keys, to_keys), axis=-1).reshape(-1, 2)

    # (2, N, N) -> (N, N, 2) -> (N * N, 2)
    attrs = np.reshape(np.transpose(connections, (1, 2, 0)), (-1, 2))

    return np.concatenate((key_pairs, attrs), axis=1)
||||||
|
|
||||||
|
|
||||||
|
def unflatten_connections(N, cons):
    """
    Restore an (N * N, 4) connection table to the (2, N, N) tensor layout.
    numpy twin of the jax version.
    """
    # drop the two key columns, keep (weight, enabled)
    attrs = cons[:, 2:]
    return np.moveaxis(attrs.reshape(N, N, 2), -1, 0)
||||||
|
|
||||||
|
|
||||||
|
def set_operation_analysis(ar1, ar2):
    """
    Row-set intersection/union analysis; numpy twin of the jax version.

    Returns (sorted_indices, intersect_mask, union_mask) over the
    lexicographically sorted concatenation of ar1 and ar2; rows containing
    nan are treated as padding and excluded from both masks.
    """
    combined = np.concatenate((ar1, ar2), axis=0)
    order = np.lexsort(combined.T[::-1])

    # append a nan sentinel row so each row can be compared with its successor
    padded = np.concatenate((combined[order], np.full((1, ar1.shape[1]), np.nan)), axis=0)
    has_nan = np.any(np.isnan(padded), axis=1)

    cur, nxt = padded[:-1], padded[1:]  # each sorted row and its successor
    intersect_mask = np.all(cur == nxt, axis=1) & ~has_nan[:-1]
    union_mask = np.any(cur != nxt, axis=1) & ~has_nan[:-1]
    return order, intersect_mask, union_mask
||||||
|
|
||||||
|
|
||||||
|
def fetch_first(mask, default=I_INT):
    """
    Return the index of the first True element of ``mask``, or ``default``
    when every element is False.
    """
    candidate = np.argmax(mask)
    # argmax yields 0 on an all-False mask, so confirm the hit is genuine
    return np.where(mask[candidate], candidate, default)
||||||
|
|
||||||
|
|
||||||
|
def fetch_last(mask, default=I_INT):
    """
    similar to fetch_first, but fetch the last True index

    Returns ``default`` when no element is True.
    """
    reversed_idx = fetch_first(mask[::-1], default)
    # Bug fix: fetch_first signals "not found" with `default` (I_INT), not -1.
    # Comparing against -1 let the sentinel fall through and produced a bogus
    # negative index (mask.shape[0] - I_INT - 1).
    return np.where(reversed_idx == default, default, mask.shape[0] - reversed_idx - 1)
||||||
|
|
||||||
|
|
||||||
|
def fetch_random(rand_key, mask, default=np.iinfo(np.int32).max):  # default equals module-level I_INT
    """
    similar to fetch_first, but fetch a uniformly random True index

    :param rand_key: integer seed for numpy's Generator (name kept for
        interface parity with the jax version — confirm callers pass an int)
    :param mask: array of bool
    :param default: value returned when no element is True
    :return: index of a random True element, or ``default``

    Bug fix: the previous body called ``np.random.randint`` with jax-style
    keyword arguments (``minval``/``maxval``/``shape``) and passed the key as
    ``low``, which raises TypeError in numpy; it also started the draw at 0
    instead of 1, unlike the jax twin.
    """
    true_cnt = np.sum(mask)
    if true_cnt == 0:
        return default
    rng = np.random.default_rng(rand_key)
    # choose the k-th True element, with k uniform in [1, true_cnt]
    target = rng.integers(1, true_cnt + 1)
    cumsum = np.cumsum(mask)
    return np.argmax(cumsum >= target)
||||||
0
examples/__init__.py
Normal file
0
examples/__init__.py
Normal file
71
examples/genome_test.py
Normal file
71
examples/genome_test.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import jax.random
|
||||||
|
|
||||||
|
from utils import Configer
|
||||||
|
from algorithms.neat.genome.genome import *
|
||||||
|
|
||||||
|
from algorithms.neat.species import SpeciesController
|
||||||
|
from algorithms.neat.genome.forward import create_forward_function
|
||||||
|
from algorithms.neat.genome.mutate import create_mutate_function
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # smoke test: build a population, run it forward, speciate it, mutate it,
    # then exercise the single-genome graph-edit helpers
    N = 10
    pop_nodes, pop_connections, input_idx, output_idx = initialize_genomes(10000, N, 2, 1,
                                                                           default_act=9, default_agg=0)
    inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])  # XOR-style input batch
    # forward = create_forward_function(pop_nodes, pop_connections, 5, input_idx, output_idx, batch=True)
    nodes, connections = pop_nodes[0], pop_connections[0]

    # batched forward pass over the whole population
    forward = create_forward_function(pop_nodes, pop_connections, N, input_idx, output_idx, batch=True)
    out = forward(inputs)
    print(out.shape)
    print(out)

    # NOTE(review): SpeciesController reads self.config.neat.species internally,
    # so passing config.neat here would resolve to config.neat.neat.species,
    # while Pipeline passes the full config — confirm which is intended.
    config = Configer.load_config()
    s_c = SpeciesController(config.neat)
    s_c.speciate(pop_nodes, pop_connections, 0)
    s_c.speciate(pop_nodes, pop_connections, 0)
    print(s_c.genome_to_species)

    # time 100 rounds of speciation
    start = time.time()
    for i in range(100):
        print(i)
        s_c.speciate(pop_nodes, pop_connections, i)
    print(time.time() - start)

    # single-genome mutation
    seed = jax.random.PRNGKey(42)
    mutate_func = create_mutate_function(config, input_idx, output_idx, batch=False)
    print(nodes, connections, sep='\n')
    print(*mutate_func(seed, nodes, connections, 100), sep='\n')

    # batched mutation over the population, then time 100 rounds of it
    randseeds = jax.random.split(seed, 10000)
    new_node_keys = jax.random.randint(randseeds[0], minval=0, maxval=10000, shape=(10000,))
    batch_mutate_func = create_mutate_function(config, input_idx, output_idx, batch=True)
    pop_nodes, pop_connections = batch_mutate_func(randseeds, pop_nodes, pop_connections, new_node_keys)
    print(pop_nodes, pop_connections, sep='\n')

    start = time.time()
    for i in range(100):
        print(i)
        pop_nodes, pop_connections = batch_mutate_func(randseeds, pop_nodes, pop_connections, new_node_keys)
    print(time.time() - start)

    # manual graph edits on a single genome: add two hidden nodes ...
    print(nodes, connections, sep='\n')
    nodes, connections = add_node(6, nodes, connections)
    nodes, connections = add_node(7, nodes, connections)
    print(nodes, connections, sep='\n')

    # ... wire them up ...
    nodes, connections = add_connection(6, 7, nodes, connections)
    nodes, connections = add_connection(0, 7, nodes, connections)
    nodes, connections = add_connection(1, 7, nodes, connections)
    print(nodes, connections, sep='\n')

    # ... and tear the edits down again
    nodes, connections = delete_connection(6, 7, nodes, connections)
    print(nodes, connections, sep='\n')

    nodes, connections = delete_node(6, nodes, connections)
    print(nodes, connections, sep='\n')

    nodes, connections = delete_node(7, nodes, connections)
    print(nodes, connections, sep='\n')
37
examples/jax_playground.py
Normal file
37
examples/jax_playground.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax import random
|
||||||
|
from jax import vmap, jit
|
||||||
|
|
||||||
|
|
||||||
|
def plus1(x):
    """Return x incremented by one."""
    return 1 + x
||||||
|
|
||||||
|
|
||||||
|
def minus1(x):
    """Return x decremented by one."""
    return x + (-1)
||||||
|
|
||||||
|
|
||||||
|
def func(rand_key, x):
    """Randomly apply plus1 or minus1 to x (branch chosen by a uniform draw)."""
    draw = jax.random.uniform(rand_key, shape=())
    # lax.cond keeps the data-dependent branch traceable under jit/vmap
    return jax.lax.cond(draw > 0.5, plus1, minus1, x)
||||||
|
|
||||||
|
|
||||||
|
def func2(rand_key):
    """Map a uniform draw in [0, 1) to a bucket label 1, 2 or 3."""
    draw = jax.random.uniform(rand_key, ())
    if draw < 0.3:
        return 1
    if draw < 0.5:
        return 2
    return 3
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# single call with a concrete key
key = random.PRNGKey(0)
print(func(key, 0))

# vmap over 100 independent keys; jit compiles the lax.cond branching
batch_func = vmap(jit(func))
keys = random.split(key, 100)
print(batch_func(keys, jnp.zeros(100)))
||||||
40
examples/xor.py
Normal file
40
examples/xor.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from typing import Callable, List
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from utils import Configer
|
||||||
|
from algorithms.neat import Pipeline
|
||||||
|
|
||||||
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||||
|
xor_outputs = np.array([[0], [1], [1], [0]])
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(forward_func: Callable) -> List[float]:
    """
    Score every genome in the population on the XOR task.

    :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
    :return: one score per genome, as a plain list

    NOTE(review): the score is the raw mean squared error (lower is better) —
    confirm the pipeline interprets fitnesses with that orientation.
    """
    predictions = jax.device_get(forward_func(xor_inputs))
    squared_errors = (predictions - xor_outputs) ** 2
    per_genome = np.mean(squared_errors, axis=(1, 2))
    return per_genome.tolist()
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run one ask/tell generation of NEAT on the XOR task."""
    config = Configer.load_config()
    pipeline = Pipeline(config)
    # ask: get a batched forward function; evaluate: score it; tell: feed back
    forward_func = pipeline.ask(batch=True)
    fitnesses = evaluate(forward_func)
    pipeline.tell(fitnesses)



# for i in range(100):
#     forward_func = pipeline.ask(batch=True)
#     fitnesses = evaluate(forward_func)
#     pipeline.tell(fitnesses)


if __name__ == '__main__':
    main()
||||||
1
utils/__init__.py
Normal file
1
utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .config import Configer
|
||||||
BIN
utils/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
utils/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/config.cpython-39.pyc
Normal file
BIN
utils/__pycache__/config.cpython-39.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/dotdict.cpython-39.pyc
Normal file
BIN
utils/__pycache__/dotdict.cpython-39.pyc
Normal file
Binary file not shown.
78
utils/config.py
Normal file
78
utils/config.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from .dotdict import DotDict
|
||||||
|
|
||||||
|
|
||||||
|
class Configer:
    """Loads, validates, and completes JSON configuration files.

    A user config is merged with ``default_config.json`` (which lives next
    to this module): keys unknown to the defaults trigger a warning, and
    keys missing from the user config are filled in from the defaults.
    """

    @classmethod
    def __load_default_config(cls):
        # The default config ships alongside this source file.
        par_dir = os.path.dirname(os.path.abspath(__file__))
        default_config_path = os.path.join(par_dir, "default_config.json")
        return cls.__load_config(default_config_path)

    @classmethod
    def __load_config(cls, config_path):
        """Parse ``config_path`` as JSON and wrap it in a DotDict."""
        with open(config_path, "r") as f:
            text = f.read()  # read the whole file at once
        try:
            j = json.loads(text)
        except ValueError as e:
            # Chain the original JSONDecodeError so line/column info
            # of the syntax error is not lost.
            raise Exception(f"Invalid config: {config_path}") from e
        return DotDict.from_dict(j, "root")

    @classmethod
    def __check_redundant_config(cls, default_config, config):
        """Recursively warn about keys in ``config`` unknown to the defaults."""
        for key in config:
            if key not in default_config:
                warnings.warn(f"Redundant config: {key} in {config.name}")
                continue
            if isinstance(default_config[key], DotDict):
                cls.__check_redundant_config(default_config[key], config[key])

    @classmethod
    def __complete_config(cls, default_config, config):
        """Recursively fill keys missing from ``config`` with default values."""
        for key in default_config:
            if key not in config:
                config[key] = default_config[key]
                continue
            if isinstance(default_config[key], DotDict):
                cls.__complete_config(default_config[key], config[key])

    @classmethod
    def __decorate_config(cls, config):
        """Normalize option fields: expand 'all' and wrap bare strings in lists."""
        activation = config.neat.gene.activation
        if activation.options == 'all':
            activation.options = [
                "sigmoid", "tanh", "sin", "gauss", "relu", "elu", "lelu",
                "selu", "softplus", "identity", "clamped", "inv", "log",
                "exp", "abs", "hat", "square", "cube",
            ]
        if isinstance(activation.options, str):
            activation.options = [activation.options]

        aggregation = config.neat.gene.aggregation
        if aggregation.options == 'all':
            aggregation.options = ["product", "sum", "max", "min", "median", "mean"]
        if isinstance(aggregation.options, str):
            aggregation.options = [aggregation.options]

    @classmethod
    def load_config(cls, config_path=None):
        """Load a config file, merge it with the defaults, and return it.

        :param config_path: optional path to a user JSON config.  ``None``
            or a non-existent path (which only warns) falls back to the
            defaults alone.
        :return: the fully completed DotDict configuration.
        """
        default_config = cls.__load_default_config()
        if config_path is None:
            config = DotDict("root")
        elif not os.path.exists(config_path):
            warnings.warn(f"config file {config_path} not exist!")
            config = DotDict("root")
        else:
            config = cls.__load_config(config_path)

        cls.__check_redundant_config(default_config, config)
        cls.__complete_config(default_config, config)
        cls.__decorate_config(config)
        return config

    @classmethod
    def write_config(cls, config, write_path):
        """Serialize ``config`` to ``write_path`` as pretty-printed JSON."""
        text = json.dumps(config, indent=2)
        with open(write_path, "w") as f:
            f.write(text)
|
||||||
108
utils/default_config.json
Normal file
108
utils/default_config.json
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
{
|
||||||
|
"basic": {
|
||||||
|
"num_inputs": 2,
|
||||||
|
"num_outputs": 1,
|
||||||
|
"init_maximum_nodes": 20,
|
||||||
|
"expands_coe": 1.5
|
||||||
|
},
|
||||||
|
"neat": {
|
||||||
|
"population": {
|
||||||
|
"fitness_criterion": "max",
|
||||||
|
"fitness_threshold": 43.9999,
|
||||||
|
"generation_limit": 100,
|
||||||
|
"pop_size": 1000,
|
||||||
|
"reset_on_extinction": "False"
|
||||||
|
},
|
||||||
|
"gene": {
|
||||||
|
"bias": {
|
||||||
|
"init_mean": 0.0,
|
||||||
|
"init_stdev": 1.0,
|
||||||
|
"max_value": 30.0,
|
||||||
|
"min_value": -30.0,
|
||||||
|
"mutate_power": 0.5,
|
||||||
|
"mutate_rate": 0.7,
|
||||||
|
"replace_rate": 0.1
|
||||||
|
},
|
||||||
|
"response": {
|
||||||
|
"init_mean": 1.0,
|
||||||
|
"init_stdev": 0.0,
|
||||||
|
"max_value": 30.0,
|
||||||
|
"min_value": -30.0,
|
||||||
|
"mutate_power": 0.0,
|
||||||
|
"mutate_rate": 0.0,
|
||||||
|
"replace_rate": 0.0
|
||||||
|
},
|
||||||
|
"activation": {
|
||||||
|
"default": "sigmoid",
|
||||||
|
"options": "sigmoid",
|
||||||
|
"mutate_rate": 0.01
|
||||||
|
},
|
||||||
|
"aggregation": {
|
||||||
|
"default": "sum",
|
||||||
|
"options": [
|
||||||
|
"product",
|
||||||
|
"sum",
|
||||||
|
"max",
|
||||||
|
"min",
|
||||||
|
"median",
|
||||||
|
"mean"
|
||||||
|
],
|
||||||
|
"mutate_rate": 0.01
|
||||||
|
},
|
||||||
|
"weight": {
|
||||||
|
"init_mean": 0.0,
|
||||||
|
"init_stdev": 1.0,
|
||||||
|
"max_value": 30.0,
|
||||||
|
"min_value": -30.0,
|
||||||
|
"mutate_power": 0.5,
|
||||||
|
"mutate_rate": 0.8,
|
||||||
|
"replace_rate": 0.1
|
||||||
|
},
|
||||||
|
"enabled": {
|
||||||
|
"mutate_rate": 0.01
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"genome": {
|
||||||
|
"compatibility_disjoint_coefficient": 1.0,
|
||||||
|
"compatibility_weight_coefficient": 0.5,
|
||||||
|
"feedforward": "True",
|
||||||
|
"single_structural_mutation": "False",
|
||||||
|
"conn_add_prob": 0.5,
|
||||||
|
"conn_delete_prob": 0.5,
|
||||||
|
"node_add_prob": 0.2,
|
||||||
|
"node_delete_prob": 0.2
|
||||||
|
},
|
||||||
|
"species": {
|
||||||
|
"compatibility_threshold": 3.5,
|
||||||
|
"species_fitness_func": "max",
|
||||||
|
"max_stagnation": 20,
|
||||||
|
"species_elitism": 2,
|
||||||
|
"genome_elitism": 2,
|
||||||
|
"survival_threshold": 0.2,
|
||||||
|
"min_species_size": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"hyperneat": {
|
||||||
|
"substrate": {
|
||||||
|
"type": "feedforward",
|
||||||
|
"layers": [
|
||||||
|
3,
|
||||||
|
10,
|
||||||
|
10,
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"x_lim": [
|
||||||
|
-5,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"y_lim": [
|
||||||
|
-5,
|
||||||
|
5
|
||||||
|
],
|
||||||
|
"threshold": 0.2,
|
||||||
|
"max_weight": 5.0
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"es-hyperneat": {
|
||||||
|
}
|
||||||
|
}
|
||||||
61
utils/dotdict.py
Normal file
61
utils/dotdict.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# DotDict For Config. Case Insensitive.
|
||||||
|
|
||||||
|
class DotDict(dict):
    """A dict allowing attribute-style access to its items.

    Keys are case insensitive (lower-cased on every access).  Each DotDict
    carries a ``name`` entry, used to build readable error messages.
    """

    def __init__(self, name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored as a regular item so attribute lookup via __getattr__ works.
        self["name"] = name

    def __getattr__(self, attr):
        attr = attr.lower()  # case insensitive
        if attr in self:
            return self[attr]
        else:
            raise AttributeError(f"'{self.__class__.__name__}-{self.name}' has no attribute '{attr}'")

    def __setattr__(self, attr, value):
        attr = attr.lower()  # case insensitive
        # Only existing keys may be assigned; guards against typos silently
        # creating new config entries.
        if attr not in self:
            raise AttributeError(f"'{self.__class__.__name__}-{self.name}' has no attribute '{attr}'")
        self[attr] = value

    def __delattr__(self, attr):
        attr = attr.lower()  # case insensitive
        if attr in self:
            del self[attr]
        else:
            raise AttributeError(f"{self.__class__.__name__}-{self.name} object has no attribute '{attr}'")

    @classmethod
    def from_dict(cls, d, name):
        """Recursively convert a plain dict into a DotDict.

        Keys are lower-cased.  The string literals ``"True"``, ``"False"``
        and ``"None"`` are decoded to their Python equivalents, because the
        project's JSON configs quote booleans/None as strings.

        :param d: the value to convert; non-dicts are returned unchanged.
        :param name: name assigned to the resulting DotDict.
        """
        if not isinstance(d, dict):
            return d

        dot_dict = cls(name)
        for key, value in d.items():
            key = key.lower()  # case insensitive
            if isinstance(value, dict):
                dot_dict[key] = cls.from_dict(value, key)
            else:
                dot_dict[key] = value
                # elif chain: a value can match at most one literal, so
                # avoid re-comparing after a conversion has been applied.
                if dot_dict[key] == "True":
                    dot_dict[key] = True
                elif dot_dict[key] == "False":
                    dot_dict[key] = False
                elif dot_dict[key] == "None":
                    dot_dict[key] = None
        return dot_dict
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Manual smoke test: mixed-case nested keys resolve case-insensitively.
    sample = {
        "a": 1,
        "b": {
            "c": 2,
            "ACDeef": {"e": 3},
        },
    }

    wrapped = DotDict.from_dict(sample, "root")
    print(wrapped.b.acdeef.e)  # prints: 3
||||||
Reference in New Issue
Block a user