Current Progress: After final design presentation

This commit is contained in:
wls2002
2023-06-19 15:17:56 +08:00
parent acedd67617
commit 5cbe3c14bb
34 changed files with 533 additions and 558 deletions

1
configs/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .configer import Configer

32
configs/activations.py Normal file
View File

@@ -0,0 +1,32 @@
from neat.genome.activations import *
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
# Maps an activation name to its integer key (its position in the tuple below).
act_name2key = {
    name: key
    for key, name in enumerate((
        'sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'elu', 'lelu', 'selu', 'softplus',
        'identity', 'clamped', 'inv', 'log', 'exp', 'abs', 'hat', 'square', 'cube',
    ))
}


def refactor_act(config):
    """
    Replace the activation names stored in ``config`` with their integer keys, in place.
    :param config: mapping holding 'activation_default' (one name) and
        'activation_options' (an iterable of names)
    :return: None
    A name that is not a key of ``act_name2key`` raises KeyError.
    """
    to_key = act_name2key.__getitem__
    config['activation_default'] = to_key(config['activation_default'])
    config['activation_options'] = list(map(to_key, config['activation_options']))

20
configs/aggregations.py Normal file
View File

@@ -0,0 +1,20 @@
from neat.genome.aggregations import *
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
# Maps an aggregation name to its integer key (its position in the tuple below).
agg_name2key = {
    name: key
    for key, name in enumerate(('sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean'))
}


def refactor_agg(config):
    """
    Replace the aggregation names stored in ``config`` with their integer keys, in place.
    :param config: mapping holding 'aggregation_default' (one name) and
        'aggregation_options' (an iterable of names)
    :return: None
    A name that is not a key of ``agg_name2key`` raises KeyError.
    """
    to_key = agg_name2key.__getitem__
    config['aggregation_default'] = to_key(config['aggregation_default'])
    config['aggregation_options'] = list(map(to_key, config['aggregation_options']))

View File

@@ -2,8 +2,46 @@ import os
import warnings
import configparser
from .activations import refactor_act
from .aggregations import refactor_agg
# Configuration keys used in jit-able functions. Changing their values does not trigger JAX re-compilation.
jit_config_keys = [
"compatibility_disjoint",
"compatibility_weight",
"conn_add_prob",
"conn_add_trials",
"conn_delete_prob",
"node_add_prob",
"node_delete_prob",
"compatibility_threshold",
"bias_init_mean",
"bias_init_stdev",
"bias_mutate_power",
"bias_mutate_rate",
"bias_replace_rate",
"response_init_mean",
"response_init_stdev",
"response_mutate_power",
"response_mutate_rate",
"response_replace_rate",
"activation_default",
"activation_options",
"activation_replace_rate",
"aggregation_default",
"aggregation_options",
"aggregation_replace_rate",
"weight_init_mean",
"weight_init_stdev",
"weight_mutate_power",
"weight_mutate_rate",
"weight_replace_rate",
"enable_mutate_rate",
]
class Configer:
@classmethod
def __load_default_config(cls):
par_dir = os.path.dirname(os.path.abspath(__file__))
@@ -47,5 +85,13 @@ class Configer:
cls.__check_redundant_config(default_config, config)
cls.__complete_config(default_config, config)
# cls.__decorate_config(config)
refactor_act(config)
refactor_agg(config)
return config
@classmethod
def create_jit_config(cls, config):
    """
    Build the sub-configuration that is handed to jit-able functions.
    :param config: the full configuration mapping
    :return: a new dict containing only the entries named in ``jit_config_keys``
    A key listed in ``jit_config_keys`` but absent from ``config`` raises KeyError.
    """
    return dict((key, config[key]) for key in jit_config_keys)

View File

@@ -4,7 +4,7 @@ num_outputs = 1
init_maximum_nodes = 20
init_maximum_connections = 20
init_maximum_species = 10
expands_coe = 2
expands_coe = 2.0
forward_way = "pop_batch"
[population]

View File

@@ -1,45 +0,0 @@
import numpy as np
import jax
from utils import Configer
from neat import Pipeline
from neat import FunctionFactory
from problems import EnhanceLogic
import time
def evaluate(problem, func):
    """
    Score every member of the population on ``problem``.
    :param problem: object providing ask_for_inputs() and evaluate_predict(predict)
    :param func: callable mapping the problem inputs to one prediction per population member
    :return: numpy array with one fitness per prediction
    """
    # Fetch the predictions back to the host before scoring them one by one.
    pop_predict = jax.device_get(func(problem.ask_for_inputs()))
    return np.array([problem.evaluate_predict(predict) for predict in pop_predict])
# @using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
def main():
    """
    Run the NEAT pipeline on the 3-input xor EnhanceLogic problem and print timing statistics.
    """
    tic = time.time()

    config = Configer.load_config()
    problem = EnhanceLogic("xor", n=3)
    problem.refactor_config(config)
    function_factory = FunctionFactory(config)

    def evaluate_func(func):
        # Bind the problem so the pipeline only has to supply the forward function.
        return evaluate(problem, func)

    pipeline = Pipeline(config, function_factory, seed=33413)
    print("start run")
    pipeline.auto_run(evaluate_func)

    total_time = time.time() - tic
    compile_time = pipeline.function_factory.compile_time
    evaluate_time = pipeline.evaluate_time
    total_it = pipeline.generation
    # Compilation time is excluded from the per-iteration average.
    mean_time_per_it = (total_time - compile_time) / total_it
    print(
        f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
    print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")
if __name__ == '__main__':
main()

View File

@@ -1,37 +0,0 @@
import jax
import numpy as np
from neat import FunctionFactory
from neat.genome.debug.tools import check_array_valid
from utils import Configer
if __name__ == '__main__':
    # Debug stress loop: repeatedly mutate and cross over a population and
    # validate every resulting genome. Runs until interrupted.
    config = Configer.load_config()
    function_factory = FunctionFactory(config, debug=True)
    initialize_func = function_factory.create_initialize()
    pop_nodes, pop_connections, input_idx, output_idx = initialize_func()
    mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1])
    crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1])
    key = jax.random.PRNGKey(0)
    new_node_idx = 100
    while True:
        # Use a separate subkey for each random consumer. Splitting the same
        # subkey twice would hand mutation and crossover identical key arrays.
        key, mutate_key, crossover_key = jax.random.split(key, 3)
        mutate_keys = jax.random.split(mutate_key, len(pop_nodes))
        new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes))
        new_node_idx += len(pop_nodes)
        pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
        pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
        # Pair genomes at random for crossover.
        idx1 = np.random.permutation(len(pop_nodes))
        idx2 = np.random.permutation(len(pop_nodes))
        n1, c1 = pop_nodes[idx1], pop_connections[idx1]
        n2, c2 = pop_nodes[idx2], pop_connections[idx2]
        crossover_keys = jax.random.split(crossover_key, len(pop_nodes))
        pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
        for i in range(len(pop_nodes)):
            check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
        print(new_node_idx)

View File

@@ -1,59 +1,21 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from time import time
import numpy as np
from jax import jit
@jit
def jax_mutate(seed, x):
    """
    Add Gaussian noise (standard normal scaled by 0.1) to ``x`` using JAX.
    :param seed: JAX PRNG key
    :param x: array to perturb
    :return: array of the same shape as ``x``
    """
    return x + jax.random.normal(key=seed, shape=x.shape) * 0.1
def numpy_mutate(x):
    """
    Add Gaussian noise (standard normal scaled by 0.1) to ``x`` using NumPy's global RNG.
    :param x: array to perturb
    :return: array of the same shape as ``x``
    """
    return x + np.random.normal(size=x.shape) * 0.1
def jax_mutate_population(seed, pop_x):
    """
    Apply jax_mutate to every member of ``pop_x``, one PRNG key per member.
    :param seed: JAX PRNG key that is split into len(pop_x) keys
    :param pop_x: array whose leading axis indexes population members
    :return: mutated population
    """
    member_seeds = jax.random.split(seed, len(pop_x))
    return vmap(jax_mutate, in_axes=(0, 0))(member_seeds, pop_x)
def numpy_mutate_population(pop_x):
    """
    Apply numpy_mutate to each member of ``pop_x`` in a Python loop and stack the results.
    :param pop_x: iterable of arrays
    :return: stacked array of mutated members
    """
    return np.stack(list(map(numpy_mutate, pop_x)))
def numpy_mutate_population_vmap(pop_x):
    """
    Mutate the whole population with a single vectorized NumPy noise draw.
    :param pop_x: population array
    :return: array of the same shape as ``pop_x``
    """
    return pop_x + np.random.normal(size=pop_x.shape) * 0.1
from configs import Configer
from neat.pipeline_ import Pipeline
def main():
seed = jax.random.PRNGKey(0)
i = 10
while i < 200000:
pop_x = jnp.ones((i, 100, 100))
jax_pop_func = jit(jax_mutate_population).lower(seed, pop_x).compile()
config = Configer.load_config("xor.ini")
print(config)
pipeline = Pipeline(config)
tic = time()
res = jax.device_get(jax_pop_func(seed, pop_x))
jax_time = time() - tic
tic = time()
res = numpy_mutate_population(pop_x)
numpy_time = time() - tic
tic = time()
res = numpy_mutate_population_vmap(pop_x)
numpy_time_vmap = time() - tic
# print(f'POP_SIZE: {i} | JAX: {jax_time:.4f} | Numpy: {numpy_time:.4f} | Speedup: {numpy_time / jax_time:.4f}')
print(f'POP_SIZE: {i} | JAX: {jax_time:.4f} | Numpy: {numpy_time:.4f} | Numpy Vmap: {numpy_time_vmap:.4f}')
i = int(i * 1.3)
@jit
def f(x, jit_config):
    """
    Add the 'bias_mutate_rate' entry of ``jit_config`` to ``x``.
    :param x: value or array
    :param jit_config: dict passed through jit as a regular (traced) argument
    :return: x shifted by jit_config["bias_mutate_rate"]
    """
    offset = jit_config["bias_mutate_rate"]
    return x + offset
if __name__ == '__main__':

2
examples/xor.ini Normal file
View File

@@ -0,0 +1,2 @@
[population]
fitness_threshold = -1e-2

View File

@@ -1,29 +1,43 @@
from neat import FunctionFactory
from utils import Configer
from neat import Pipeline
from problems import Xor
from typing import Callable, List
import time
import numpy as np
from configs import Configer
from neat import Pipeline
# The four xor samples and their expected outputs.
xor_inputs = np.array([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
])
xor_outputs = np.array([
    [0],
    [1],
    [1],
    [0],
])


def evaluate(forward_func: Callable) -> List[float]:
    """
    Compute the xor fitness of each population member.
    :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
    :return: list with one fitness per member: 4 minus the summed squared error over the batch
    """
    squared_error = (forward_func(xor_inputs) - xor_outputs) ** 2
    fitnesses = 4 - np.sum(squared_error, axis=(1, 2))
    return fitnesses.tolist()  # returns a list
# @using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
def main():
tic = time.time()
config = Configer.load_config()
config = Configer.load_config("xor.ini")
print(config)
assert False
problem = Xor()
problem.refactor_config(config)
function_factory = FunctionFactory(config)
pipeline = Pipeline(config, function_factory, seed=6)
nodes, cons = pipeline.auto_run(problem.evaluate)
nodes, cons = pipeline.auto_run(evaluate)
print(nodes, cons)
total_time = time.time() - tic
compile_time = pipeline.function_factory.compile_time
total_it = pipeline.generation
mean_time_per_it = (total_time - compile_time) / total_it
evaluate_time = pipeline.evaluate_time
print(f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
print(
f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")

View File

@@ -1,2 +0,0 @@
from .pipeline import Pipeline
from .function_factory import FunctionFactory

View File

@@ -1,9 +0,0 @@
from .genome import expand, expand_single, initialize_genomes
from .forward import forward_single
from .activations import act_name2key
from .aggregations import agg_name2key
from .crossover import crossover
from .mutate import mutate
from .distance import distance
from .graph import topological_sort
from .utils import unflatten_connections

View File

@@ -104,31 +104,6 @@ def cube_act(z):
return z ** 3
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
act_name2key = {
'sigmoid': 0,
'tanh': 1,
'sin': 2,
'gauss': 3,
'relu': 4,
'elu': 5,
'lelu': 6,
'selu': 7,
'softplus': 8,
'identity': 9,
'clamped': 10,
'inv': 11,
'log': 12,
'exp': 13,
'abs': 14,
'hat': 15,
'square': 16,
'cube': 17,
}
@jit
def act(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
@@ -137,4 +112,3 @@ def act(idx, z):
return jnp.where(jnp.isnan(res), jnp.nan, res)
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)

View File

@@ -44,7 +44,6 @@ def maxabs_agg(z):
@jit
def median_agg(z):
non_zero_mask = ~jnp.isnan(z)
n = jnp.sum(non_zero_mask, axis=0)
@@ -71,19 +70,6 @@ def mean_agg(z):
return mean_without_zeros
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
agg_name2key = {
'sum': 0,
'product': 1,
'max': 2,
'min': 3,
'maxabs': 4,
'median': 5,
'mean': 6,
}
@jit
def agg(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
@@ -97,7 +83,6 @@ def agg(idx, z):
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
vectorized_agg = jax.vmap(agg, in_axes=(0, 0))
if __name__ == '__main__':
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)

76
neat/genome/crossover_.py Normal file
View File

@@ -0,0 +1,76 @@
from functools import partial
from typing import Tuple
import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
@jit
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
        -> Tuple[Array, Array]:
    """
    Use genome1 and genome2 to generate a new genome.
    Genome1 is expected to be the winner (higher fitness): wherever a gene is
    missing in either parent, the entry of genome1 is kept.
    :param randkey: PRNG key, split into one key for nodes and one for connections
    :param nodes1: nodes of the winner
    :param cons1: connections of the winner
    :param nodes2: nodes of the other parent
    :param cons2: connections of the other parent
    :return: (new_nodes, new_cons)
    """
    node_key, con_key = jax.random.split(randkey)

    def inherit(key, winner, aligned_loser):
        # Mix only where both parents carry a value; otherwise keep the winner's entry.
        missing = jnp.isnan(winner) | jnp.isnan(aligned_loser)
        return jnp.where(missing, winner, crossover_gene(key, winner, aligned_loser))

    # Nodes are matched on their key column, connections on their first two key columns.
    aligned_nodes2 = align_array(nodes1[:, 0], nodes2[:, 0], nodes2, 'node')
    aligned_cons2 = align_array(cons1[:, :2], cons2[:, :2], cons2, 'connection')

    new_nodes = inherit(node_key, nodes1, aligned_nodes2)
    new_cons = inherit(con_key, cons1, aligned_cons2)
    return new_nodes, new_cons
# @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
    """
    Reorder ar2 so that its rows line up with the keys in seq1.
    The original author marked this as the most delicate part of the code; please do not change it.
    :param seq1: keys of the first genome (one key per row; two key columns for connections)
    :param seq2: keys of the second genome, laid out like seq1
    :param ar2: gene array of the second genome, indexed like seq2
    :param gene_type: 'connection' compares both key columns; any other value compares a single key
    :return: array in which row i holds the row of ar2 whose key equals seq1[i],
        or NaN if seq1[i] is NaN or has no match in seq2
    """
    # Broadcast so that mask[i, j] compares key i of seq1 with key j of seq2.
    seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
    mask = (seq1 == seq2) & (~jnp.isnan(seq1))

    if gene_type == 'connection':
        # A connection matches only if both of its key columns match.
        mask = jnp.all(mask, axis=2)

    # True for every key of seq1 that also appears in seq2.
    intersect_mask = mask.any(axis=1)
    idx = jnp.arange(0, len(seq1))
    # With at most one True per row, the dot product is the position of the match in seq2
    # (0 when there is none). NOTE(review): idx is sized from seq1 while it is multiplied
    # along the seq2 axis, so this appears to assume both genomes have the same number of
    # rows; confirm with callers.
    idx_fixed = jnp.dot(mask, idx)

    # Rows without a match are blanked out with NaN.
    refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)

    return refactor_ar2
# @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
    """
    Mix two gene arrays element-wise.
    :param rand_key: PRNG key
    :param g1: genes of the first parent
    :param g2: genes of the second parent, same shape as g1
    :return: array taking each element from g1 when its uniform draw exceeds 0.5, else from g2
    Only genes with the same key are crossed over, so the key itself never changes.
    """
    take_first = jax.random.uniform(rand_key, shape=g1.shape) > 0.5
    return jnp.where(take_first, g1, g2)

View File

@@ -1,3 +1,7 @@
"""
Calculate the distance between two genomes.
"""
from jax import jit, vmap, Array
from jax import numpy as jnp

105
neat/genome/distance_.py Normal file
View File

@@ -0,0 +1,105 @@
"""
Calculate the distance between two genomes.
The calculation method is the same as the distance calculation in NEAT-python.
"""
from typing import Dict
from jax import jit, vmap, Array
from jax import numpy as jnp
from .utils import EMPTY_NODE, EMPTY_CON
@jit
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
    """
    Calculate the distance between two genomes.
    :param nodes1: nodes of the first genome
    :param cons1: connections of the first genome
    :param nodes2: nodes of the second genome
    :param cons2: connections of the second genome
    :param jit_config: jit configuration; 'compatibility_disjoint' and
        'compatibility_weight' are read here
    :return: node distance plus connection distance
    """
    nd = node_distance(nodes1, nodes2, jit_config)  # node distance

    # connection_distance takes the two coefficients as separate arguments, not the
    # config dict, so unpack them instead of passing jit_config positionally.
    cd = connection_distance(cons1, cons2,
                             jit_config["compatibility_disjoint"],
                             jit_config["compatibility_weight"])  # connection distance
    return nd + cd
@jit
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
    """
    Calculate the node part of the distance between two genomes.
    :param nodes1: nodes of the first genome; column 0 is the node key, NaN marks an empty row
    :param nodes2: nodes of the second genome
    :param jit_config: jit configuration; 'compatibility_disjoint' weights the count of
        non-homologous nodes and 'compatibility_weight' weights the homologous distance
    :return: weighted distance divided by the larger node count (0 if both genomes are empty)
    """
    # These names were previously undefined in this function; read them from the config.
    disjoint_coe = jit_config["compatibility_disjoint"]
    compatibility_coe = jit_config["compatibility_weight"]

    node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
    node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
    max_cnt = jnp.maximum(node_cnt1, node_cnt2)

    # Sort the union of both genomes by key so that homologous nodes become neighbours.
    nodes = jnp.concatenate((nodes1, nodes2), axis=0)
    keys = nodes[:, 0]
    sorted_indices = jnp.argsort(keys, axis=0)
    nodes = nodes[sorted_indices]
    nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0)  # add a nan row to the end
    fr, sr = nodes[:-1], nodes[1:]  # first row, second row

    # Both genomes have this node when two neighbouring rows share a (non-NaN) key.
    intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(fr[:, 0])

    non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
    nd = batch_homologous_node_distance(fr, sr)
    nd = jnp.where(jnp.isnan(nd), 0, nd)
    homologous_distance = jnp.sum(nd * intersect_mask)

    val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
    # Avoid dividing by zero when neither genome has any node.
    return jnp.where(max_cnt == 0, 0, val / max_cnt)
@jit
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
    """
    Calculate the connection part of the distance between two genomes.
    :param cons1: connections of the first genome; columns 0-1 are the keys, NaN marks an empty row
    :param cons2: connections of the second genome
    :param disjoint_coe: weight of the count of non-homologous connections
    :param compatibility_coe: weight of the summed distance between homologous connections
    :return: weighted distance divided by the larger connection count (0 if both are empty)
    """
    cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
    cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
    larger_cnt = jnp.maximum(cnt1, cnt2)

    # Sort the union of both genomes by (in_key, out_key) so homologous rows become neighbours.
    merged = jnp.concatenate((cons1, cons2), axis=0)
    order = jnp.lexsort(merged[:, :2].T[::-1])
    merged = jnp.concatenate([merged[order], EMPTY_CON], axis=0)  # add a nan row to the end
    upper, lower = merged[:-1], merged[1:]  # each row and the row after it

    # both genome has such connection
    shared = jnp.all(upper[:, :2] == lower[:, :2], axis=1) & ~jnp.isnan(upper[:, 0])
    unshared_cnt = cnt1 + cnt2 - 2 * jnp.sum(shared)

    pair_dist = batch_homologous_connection_distance(upper, lower)
    pair_dist = jnp.where(jnp.isnan(pair_dist), 0, pair_dist)
    shared_dist = jnp.sum(pair_dist * shared)

    total = unshared_cnt * disjoint_coe + shared_dist * compatibility_coe
    return jnp.where(larger_cnt == 0, 0, total / larger_cnt)
@vmap
def batch_homologous_node_distance(b_n1, b_n2):
    """
    Row-wise version of homologous_node_distance, built with vmap.
    :param b_n1: batch of node rows
    :param b_n2: batch of node rows, paired with b_n1 along the leading axis
    :return: one distance per pair of rows
    """
    return homologous_node_distance(b_n1, b_n2)
@vmap
def batch_homologous_connection_distance(b_c1, b_c2):
    """
    Row-wise version of homologous_connection_distance, built with vmap.
    :param b_c1: batch of connection rows
    :param b_c2: batch of connection rows, paired with b_c1 along the leading axis
    :return: one distance per pair of rows
    """
    return homologous_connection_distance(b_c1, b_c2)
@jit
def homologous_node_distance(n1, n2):
    """
    Distance between two node rows that share the same key.
    :param n1: node row; indices 1 to 4 are compared
    :param n2: node row laid out like n1
    :return: |bias diff| + |response diff| + 1 per differing entry at index 3 and index 4
    """
    bias_gap = jnp.abs(n1[1] - n2[1])  # bias
    response_gap = jnp.abs(n1[2] - n2[2])  # response
    act_differs = n1[3] != n2[3]  # activation
    idx4_differs = n1[4] != n2[4]
    return bias_gap + response_gap + act_differs + idx4_differs
@jit
def homologous_connection_distance(c1, c2):
    """
    Distance between two connection rows that share the same keys.
    :param c1: connection row; index 2 is the weight and index 3 the enable flag
    :param c2: connection row laid out like c1
    :return: |weight diff| + 1 if the enable flags differ
    """
    weight_gap = jnp.abs(c1[2] - c2[2])  # weight
    enable_differs = c1[3] != c2[3]  # enable
    return weight_gap + enable_differs

View File

@@ -6,6 +6,7 @@ from .aggregations import agg
from .activations import act
from .utils import I_INT
# TODO: enabled information doesn't influence forward. That is wrong!
@jit
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,

View File

@@ -1,201 +0,0 @@
"""
Vectorization of genome representation.
Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where:
1. N, C are pre-set values that determines the maximum number of nodes and connections in the network, and will increase if the genome becomes
too large to be represented by the current value of N and C.
2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function
(act), and aggregation function (agg).
3. connections is an array of shape (C, 4), dtype=float, with columns corresponding to: i_key, o_key, weight, enabled.
Empty nodes or connections are represented using np.nan.
"""
from typing import Tuple, Dict
import jax
import numpy as np
from numpy.typing import NDArray
from jax import numpy as jnp
from jax import jit
from jax import Array
from .utils import fetch_first
def initialize_genomes(pop_size: int,
                       N: int,
                       C: int,
                       num_inputs: int,
                       num_outputs: int,
                       default_bias: float = 0.0,
                       default_response: float = 1.0,
                       default_act: int = 0,
                       default_agg: int = 0,
                       default_weight: float = 0.0) \
        -> Tuple[NDArray, NDArray, NDArray, NDArray]:
    """
    Initialize genomes with default values.

    Every genome starts with the input nodes, the output nodes and one enabled
    connection from each input to each output. Unused rows are NaN.

    Args:
        pop_size (int): Number of genomes to initialize.
        N (int): Maximum number of nodes in the network.
        C (int): Maximum number of connections in the network.
        num_inputs (int): Number of input nodes.
        num_outputs (int): Number of output nodes.
        default_bias (float, optional): Bias of the output nodes. Defaults to 0.0.
        default_response (float, optional): Response of the output nodes. Defaults to 1.0.
        default_act (int, optional): Activation function index of the output nodes. Defaults to 0.
        default_agg (int, optional): Aggregation function index of the output nodes. Defaults to 0.
        default_weight (float, optional): Weight of the initial connections. Defaults to 0.0.

    Raises:
        AssertionError: If num_inputs + num_outputs + 1 > N or num_inputs * num_outputs + 1 > C.

    Returns:
        Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes (pop_size, N, 5),
        pop_connections (pop_size, C, 4), input_idx, and output_idx arrays.
    """
    # Reserve one row for potential mutation adding an extra node
    assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \
                                              f"{num_inputs} and output_size: {num_outputs}!"
    assert num_inputs * num_outputs + 1 <= C, f"Too small C: {C} for input_size: " \
                                              f"{num_inputs} and output_size: {num_outputs}!"

    input_idx = np.arange(num_inputs)
    output_idx = np.arange(num_inputs, num_inputs + num_outputs)

    # Build one genome and replicate it across the population.
    node_template = np.full((N, 5), np.nan)
    node_template[input_idx, 0] = input_idx
    node_template[output_idx, 0] = output_idx
    node_template[output_idx, 1:] = (default_bias, default_response, default_act, default_agg)

    # Fully connect inputs to outputs: the input key varies fastest.
    used = num_inputs * num_outputs
    con_template = np.full((C, 4), np.nan)
    con_template[:used, 0] = np.tile(input_idx, num_outputs)
    con_template[:used, 1] = np.repeat(output_idx, num_inputs)
    con_template[:used, 2] = default_weight
    con_template[:used, 3] = 1

    pop_nodes = np.tile(node_template, (pop_size, 1, 1))
    pop_cons = np.tile(con_template, (pop_size, 1, 1))
    return pop_nodes, pop_cons, input_idx, output_idx
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Expand the population to accommodate more nodes and connections.
    :param pop_nodes: (pop_size, N, 5)
    :param pop_cons: (pop_size, C, 4)
    :param new_N: new maximum number of nodes
    :param new_C: new maximum number of connections
    :return: (pop_size, new_N, 5), (pop_size, new_C, 4); the added rows are NaN
    """
    pop_size = pop_nodes.shape[0]
    old_N = pop_nodes.shape[1]
    old_C = pop_cons.shape[1]

    grown_nodes = np.full((pop_size, new_N, 5), np.nan)
    grown_nodes[:, :old_N] = pop_nodes

    grown_cons = np.full((pop_size, new_C, 4), np.nan)
    grown_cons[:, :old_C] = pop_cons
    return grown_nodes, grown_cons
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Expand a single genome to accommodate more nodes and connections.
    :param nodes: (N, 5)
    :param cons: (C, 4)
    :param new_N: new maximum number of nodes
    :param new_C: new maximum number of connections
    :return: (new_N, 5), (new_C, 4); the added rows are NaN
    """
    old_N = nodes.shape[0]
    old_C = cons.shape[0]

    grown_nodes = np.full((new_N, 5), np.nan)
    grown_nodes[:old_N] = nodes

    grown_cons = np.full((new_C, 4), np.nan)
    grown_cons[:old_C] = cons
    return grown_nodes, grown_cons
@jit
def count(nodes, cons):
    """
    Count the non-empty rows of a genome.
    :param nodes: node array; a row is empty when its column 0 is NaN
    :param cons: connection array; a row is empty when its column 0 is NaN
    :return: (node_cnt, cons_cnt)
    """
    node_used = ~jnp.isnan(nodes[:, 0])
    con_used = ~jnp.isnan(cons[:, 0])
    return jnp.sum(node_used), jnp.sum(con_used)
@jit
def add_node(nodes: Array, cons: Array, new_key: int,
             bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
    """
    Add a new node to the genome.
    The row is chosen with fetch_first over the rows whose key is NaN.
    :return: (updated nodes, cons unchanged)
    """
    free_idx = fetch_first(jnp.isnan(nodes[:, 0]))
    new_row = jnp.array([new_key, bias, response, act, agg])
    return nodes.at[free_idx].set(new_row), cons
@jit
def delete_node(nodes: Array, cons: Array, node_key: int) -> Tuple[Array, Array]:
    """
    Delete the node with key ``node_key`` from the genome.
    Only the node row is cleared; connections are left untouched.
    :return: (updated nodes, cons unchanged)
    """
    key_matches = nodes[:, 0] == node_key
    return delete_node_by_idx(nodes, cons, fetch_first(key_matches))
@jit
def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
    """
    Delete the node stored at row ``idx`` by setting that row to NaN.
    Only the node row is cleared; connections are left untouched.
    :return: (updated nodes, cons unchanged)
    """
    return nodes.at[idx].set(np.nan), cons
@jit
def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
                   weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
    """
    Add a new connection to the genome.
    The row is chosen with fetch_first over the rows whose first key is NaN.
    :return: (nodes unchanged, updated cons)
    """
    free_idx = fetch_first(jnp.isnan(cons[:, 0]))
    return add_connection_by_idx(nodes, cons, free_idx, i_key, o_key, weight, enabled)
@jit
def add_connection_by_idx(nodes: Array, cons: Array, idx: int, i_key: int, o_key: int,
                          weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
    """
    Write a new connection into row ``idx`` of the connection array.
    The row becomes [i_key, o_key, weight, enabled].
    :return: (nodes unchanged, updated cons)
    """
    new_row = jnp.array([i_key, o_key, weight, enabled])
    return nodes, cons.at[idx].set(new_row)
@jit
def delete_connection(nodes: Array, cons: Array, i_key: int, o_key: int) -> Tuple[Array, Array]:
    """
    Delete the connection identified by its input and output node keys.
    :return: (nodes unchanged, updated cons)
    """
    in_matches = cons[:, 0] == i_key
    out_matches = cons[:, 1] == o_key
    return delete_connection_by_idx(nodes, cons, fetch_first(in_matches & out_matches))
@jit
def delete_connection_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
    """
    Delete the connection stored at row ``idx`` by setting that row to NaN.
    :return: (nodes unchanged, updated cons)
    """
    return nodes, cons.at[idx].set(np.nan)

180
neat/genome/genome_.py Normal file
View File

@@ -0,0 +1,180 @@
"""
Vectorization of genome representation.
Utilizes Tuple[nodes: Array(N, 5), connections: Array(C, 4)] to encode the genome, where:
nodes: [key, bias, response, act, agg]
connections: [in_key, out_key, weight, enable]
N: Maximum number of nodes in the network.
C: Maximum number of connections in the network.
"""
from typing import Tuple, Dict
import numpy as np
from numpy.typing import NDArray
from jax import jit, numpy as jnp
from .utils import fetch_first
def initialize_genomes(N: int,
                       C: int,
                       config: Dict) \
        -> Tuple[NDArray, NDArray, NDArray, NDArray]:
    """
    Initialize genomes with default values.

    Every genome starts with the input nodes, the output nodes and one enabled
    connection from each input to each output. Unused rows are NaN.

    Args:
        N (int): Maximum number of nodes in the network.
        C (int): Maximum number of connections in the network.
        config (Dict): Configuration dictionary. Reads 'num_inputs', 'num_outputs',
            'pop_size', 'bias_init_mean', 'response_init_mean', 'activation_default',
            'aggregation_default' and 'weight_init_mean'.

    Raises:
        AssertionError: If N or C is too small for the requested inputs and outputs.

    Returns:
        Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays.
    """
    num_inputs, num_outputs = config['num_inputs'], config['num_outputs']

    # Reserve one row for potential mutation adding an extra node
    assert num_inputs + num_outputs + 1 <= N, \
        f"Too small N: {N} for input_size: {num_inputs} and output_size: {num_outputs}!"

    assert num_inputs * num_outputs + 1 <= C, \
        f"Too small C: {C} for input_size: {num_inputs} and output_size: {num_outputs}!"

    pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
    pop_cons = np.full((config['pop_size'], C, 4), np.nan)
    input_idx = np.arange(num_inputs)
    output_idx = np.arange(num_inputs, num_inputs + num_outputs)

    # Input nodes only get a key; output nodes also get the default attributes.
    pop_nodes[:, input_idx, 0] = input_idx
    pop_nodes[:, output_idx, 0] = output_idx
    pop_nodes[:, output_idx, 1] = config['bias_init_mean']
    pop_nodes[:, output_idx, 2] = config['response_init_mean']
    pop_nodes[:, output_idx, 3] = config['activation_default']
    pop_nodes[:, output_idx, 4] = config['aggregation_default']

    # Fully connect every input to every output.
    grid_a, grid_b = np.meshgrid(input_idx, output_idx)
    grid_a, grid_b = grid_a.flatten(), grid_b.flatten()

    p = num_inputs * num_outputs
    pop_cons[:, :p, 0] = grid_a
    pop_cons[:, :p, 1] = grid_b
    pop_cons[:, :p, 2] = config['weight_init_mean']
    pop_cons[:, :p, 3] = 1  # enabled

    return pop_nodes, pop_cons, input_idx, output_idx
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Grow a single genome so it can hold more nodes or connections.
    :param nodes: (N, 5)
    :param cons: (C, 4)
    :param new_N: new maximum number of nodes
    :param new_C: new maximum number of connections
    :return: (new_N, 5), (new_C, 4); the added rows are NaN
    """
    node_rows = nodes.shape[0]
    con_rows = cons.shape[0]

    padded_nodes = np.full((new_N, 5), np.nan)
    padded_nodes[:node_rows] = nodes

    padded_cons = np.full((new_C, 4), np.nan)
    padded_cons[:con_rows] = cons
    return padded_nodes, padded_cons
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
    """
    Grow the whole population so each genome can hold more nodes or connections.
    :param pop_nodes: (pop_size, N, 5)
    :param pop_cons: (pop_size, C, 4)
    :param new_N: new maximum number of nodes
    :param new_C: new maximum number of connections
    :return: (pop_size, new_N, 5), (pop_size, new_C, 4); the added rows are NaN
    """
    pop_size = pop_nodes.shape[0]
    node_rows = pop_nodes.shape[1]
    con_rows = pop_cons.shape[1]

    padded_nodes = np.full((pop_size, new_N, 5), np.nan)
    padded_nodes[:, :node_rows] = pop_nodes

    padded_cons = np.full((pop_size, new_C, 4), np.nan)
    padded_cons[:, :con_rows] = pop_cons
    return padded_nodes, padded_cons
@jit
def count(nodes: NDArray, cons: NDArray) -> Tuple[NDArray, NDArray]:
    """
    Count how many nodes and connections are in the genome.
    A row counts as used when its column 0 is not NaN.
    :return: (node_cnt, cons_cnt)
    """
    node_used = ~jnp.isnan(nodes[:, 0])
    con_used = ~jnp.isnan(cons[:, 0])
    return jnp.sum(node_used), jnp.sum(con_used)
@jit
def add_node(nodes: NDArray, cons: NDArray, new_key: int,
             bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]:
    """
    Add a new node to the genome.
    The row is chosen with fetch_first over the rows whose key is NaN.
    :return: (updated nodes, cons unchanged)
    """
    free_idx = fetch_first(jnp.isnan(nodes[:, 0]))
    new_row = jnp.array([new_key, bias, response, act, agg])
    return nodes.at[free_idx].set(new_row), cons
@jit
def delete_node(nodes: NDArray, cons: NDArray, node_key: int) -> Tuple[NDArray, NDArray]:
    """
    Delete the node with key ``node_key`` from the genome.
    Only the node row is cleared; connections are left untouched.
    :return: (updated nodes, cons unchanged)
    """
    key_matches = nodes[:, 0] == node_key
    return delete_node_by_idx(nodes, cons, fetch_first(key_matches))
@jit
def delete_node_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]:
    """
    Delete the node stored at row ``idx`` by setting that row to NaN.
    Only the node row is cleared; connections are left untouched.
    :return: (updated nodes, cons unchanged)
    """
    return nodes.at[idx].set(np.nan), cons
@jit
def add_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int,
                   weight: float = 1.0, enabled: bool = True) -> Tuple[NDArray, NDArray]:
    """
    Add a new connection to the genome.
    The row is chosen with fetch_first over the rows whose first key is NaN
    and becomes [i_key, o_key, weight, enabled].
    :return: (nodes unchanged, updated cons)
    """
    free_idx = fetch_first(jnp.isnan(cons[:, 0]))
    new_row = jnp.array([i_key, o_key, weight, enabled])
    return nodes, cons.at[free_idx].set(new_row)
@jit
def delete_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int) -> Tuple[NDArray, NDArray]:
    """
    Delete the connection identified by its input and output node keys.
    :return: (nodes unchanged, updated cons)
    """
    in_matches = cons[:, 0] == i_key
    out_matches = cons[:, 1] == o_key
    return delete_connection_by_idx(nodes, cons, fetch_first(in_matches & out_matches))
@jit
def delete_connection_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]:
    """
    Delete the connection stored at row ``idx`` by setting that row to NaN.
    :return: (nodes unchanged, updated cons)
    """
    return nodes, cons.at[idx].set(np.nan)

View File

@@ -7,7 +7,7 @@ import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
# from .utils import fetch_first, I_INT
# from .configs import fetch_first, I_INT
from neat.genome.utils import fetch_first, I_INT

View File

@@ -32,9 +32,6 @@ def unflatten_connections(nodes, cons):
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
# (2, N, N), (2, N, N), (2, N, N)
# res = jnp.where(res[1, :, :] == 0, jnp.nan, res)
return res
@@ -88,6 +85,7 @@ def argmin_with_mask(arr: Array, mask: Array) -> Array:
min_idx = jnp.argmin(masked_arr)
return min_idx
if __name__ == '__main__':
a = jnp.array([1, 2, 3, 4, 5])

27
neat/pipeline_.py Normal file
View File

@@ -0,0 +1,27 @@
import jax
from configs.configer import Configer
from .genome.genome_ import initialize_genomes
class Pipeline:
    """
    Neat algorithm pipeline.

    Holds the random key, the global and jit configurations, the initial array
    capacities (N nodes, C connections, S species) and the initial population.
    """

    def __init__(self, config, seed=42):
        """
        :param config: global configuration mapping
        :param seed: integer seed for the JAX PRNG key
        """
        self.randkey = jax.random.PRNGKey(seed)

        self.config = config  # global config
        self.jit_config = Configer.create_jit_config(config)  # config used in jit-able functions

        # Initial capacities of the node, connection and species arrays.
        self.N = config["init_maximum_nodes"]
        self.C = config["init_maximum_connections"]
        self.S = config["init_maximum_species"]

        self.generation = 0
        self.best_genome = None

        population = initialize_genomes(self.N, self.C, self.config)
        self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = population

        # Dump the initial state to stdout.
        print(self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx, sep='\n')
        print(self.jit_config)

View File

@@ -1,3 +0,0 @@
from .problem import Problem
from .function_fitting import *
from .gym import *

View File

@@ -1,5 +0,0 @@
from .function_fitting_problem import FunctionFittingProblem
from .xor import *
from .sin import *
from .diy import *
from .enhance_logic import *

View File

@@ -1,14 +0,0 @@
import numpy as np
from . import FunctionFittingProblem
class DIY(FunctionFittingProblem):
    """
    Function-fitting problem for a user-supplied one-dimensional function on [0, 1].
    """

    def __init__(self, func, size=100):
        """
        :param func: callable applied to the (size, 1) input column to produce the targets
        :param size: number of evenly spaced sample points
        """
        self.num_inputs, self.num_outputs = 1, 1
        self.batch = size
        sample_points = np.linspace(0, 1, self.batch)
        self.inputs = sample_points[:, None]
        self.target = func(self.inputs)
        print(self.inputs, self.target)
        super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)

View File

@@ -1,54 +0,0 @@
"""
xor problem in multiple dimensions
"""
from itertools import product
import numpy as np
class EnhanceLogic:
    """
    n-input logic problem (xor, and, or) evaluated over all 2 ** n binary inputs.
    """

    def __init__(self, name="xor", n=2):
        """
        :param name: "xor", "and" or "or"
        :param n: number of binary inputs
        :raises NotImplementedError: for any other name
        """
        self.name = name
        self.n = n
        self.num_inputs = n
        self.num_outputs = 1
        self.batch = 2 ** n
        self.forward_way = 'pop_batch'
        self.inputs = np.array(generate_permutations(n), dtype=np.float32)

        if self.name == "xor":
            # Parity of the inputs.
            labels = np.sum(self.inputs, axis=1) % 2
        elif self.name == "and":
            labels = np.all(self.inputs == 1, axis=1)
        elif self.name == "or":
            labels = np.any(self.inputs == 1, axis=1)
        else:
            raise NotImplementedError("Only support xor, and, or")
        # One output column per sample.
        self.outputs = labels[:, np.newaxis]

    def refactor_config(self, config):
        """Copy the problem's sizes and forward mode into ``config.basic``."""
        basic = config.basic
        basic.forward_way = self.forward_way
        basic.num_inputs = self.num_inputs
        basic.num_outputs = self.num_outputs
        basic.problem_batch = self.batch

    def ask_for_inputs(self):
        """Return the input array covering every binary combination."""
        return self.inputs

    def evaluate_predict(self, predict):
        """Return the negative mean squared error of ``predict`` against the outputs."""
        squared_error = (predict - self.outputs) ** 2
        return -np.mean(squared_error)
def generate_permutations(n):
    """
    List every length-n combination of 0 and 1, in itertools.product order.
    :param n: number of positions
    :return: list of 2 ** n lists
    """
    return list(map(list, product([0, 1], repeat=n)))
if __name__ == '__main__':
    # Smoke test: build the 4-input problem. ``n`` must be passed by keyword,
    # because the first positional parameter of EnhanceLogic is ``name``.
    _ = EnhanceLogic(n=4)

View File

@@ -1,39 +0,0 @@
import numpy as np
import jax
from problems import Problem
class FunctionFittingProblem(Problem):
    """
    Problem that scores a population by how well it fits ``target`` on ``inputs``.
    """

    def __init__(self, num_inputs, num_outputs, batch, inputs, target, loss='MSE'):
        """
        :param num_inputs: number of network inputs
        :param num_outputs: number of network outputs
        :param batch: number of samples
        :param inputs: sample inputs
        :param target: expected outputs for ``inputs``
        :param loss: loss name; stored on the instance
        """
        self.forward_way = 'pop_batch'
        self.num_inputs, self.num_outputs, self.batch = num_inputs, num_outputs, batch
        self.inputs, self.target = inputs, target
        self.loss = loss
        super().__init__(self.forward_way, self.num_inputs, self.num_outputs, self.batch)

    def _host_outputs(self, forward_func):
        # Run the forward function on the stored inputs and fetch the result to the host.
        return jax.device_get(forward_func(self.inputs))

    def evaluate(self, pop_batch_forward):
        """Return one fitness per member: the negative mean squared error over axes 1 and 2."""
        outs = self._host_outputs(pop_batch_forward)
        squared_error = (self.target - outs) ** 2
        return -np.mean(squared_error, axis=(1, 2))

    def draw(self, batch_func):
        """Print the predictions and plot them against the target with matplotlib."""
        outs = self._host_outputs(batch_func)
        print(outs)
        # Imported lazily so matplotlib is only needed when drawing.
        from matplotlib import pyplot as plt
        plt.xlabel('x')
        plt.ylabel('y')
        plt.plot(self.inputs, self.target, color='red', label='target')
        plt.plot(self.inputs, outs, color='blue', label='predict')
        plt.legend()
        plt.show()

    def print(self, batch_func):
        """Print the predictions of ``batch_func`` on the stored inputs."""
        print(self._host_outputs(batch_func))

View File

@@ -1,14 +0,0 @@
import numpy as np
from . import FunctionFittingProblem
class Sin(FunctionFittingProblem):
    """
    Function-fitting problem for sin(x) on [0, 2*pi].
    """

    def __init__(self, size=100):
        """
        :param size: number of evenly spaced sample points
        """
        self.num_inputs, self.num_outputs = 1, 1
        self.batch = size
        sample_points = np.linspace(0, 2 * np.pi, self.batch)
        self.inputs = sample_points[:, None]
        self.target = np.sin(self.inputs)
        print(self.inputs, self.target)
        super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)

View File

@@ -1,13 +0,0 @@
import numpy as np
from . import FunctionFittingProblem
class Xor(FunctionFittingProblem):
    """
    Two-input xor as a function-fitting problem with four samples.
    """

    def __init__(self):
        self.num_inputs, self.num_outputs, self.batch = 2, 1, 4
        # The four binary input pairs and their xor values.
        self.inputs = np.array(
            [[0, 0],
             [0, 1],
             [1, 0],
             [1, 1]],
            dtype=np.float32)
        self.target = np.array([[0], [1], [1], [0]], dtype=np.float32)
        super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)

View File

@@ -1,15 +0,0 @@
class Problem:
    """
    Base class for problems: stores the forward mode and the input/output/batch sizes.
    """

    def __init__(self, forward_way, num_inputs, num_outputs, batch):
        """
        :param forward_way: forward mode written to config.basic.forward_way
        :param num_inputs: number of network inputs
        :param num_outputs: number of network outputs
        :param batch: written to config.basic.problem_batch
        """
        self.forward_way, self.batch = forward_way, batch
        self.num_inputs, self.num_outputs = num_inputs, num_outputs

    def refactor_config(self, config):
        """Copy this problem's settings onto ``config.basic``."""
        basic = config.basic
        basic.forward_way = self.forward_way
        basic.num_inputs = self.num_inputs
        basic.num_outputs = self.num_outputs
        basic.problem_batch = self.batch

    def evaluate(self, batch_forward_func):
        """Placeholder for subclasses; the base implementation does nothing and returns None."""
        return None

View File

@@ -1 +0,0 @@
from .config import Configer