Current Progress: After final design presentation
This commit is contained in:
1
configs/__init__.py
Normal file
1
configs/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .configer import Configer
|
||||||
32
configs/activations.py
Normal file
32
configs/activations.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from neat.genome.activations import *
|
||||||
|
|
||||||
|
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
|
||||||
|
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
|
||||||
|
|
||||||
|
act_name2key = {
|
||||||
|
'sigmoid': 0,
|
||||||
|
'tanh': 1,
|
||||||
|
'sin': 2,
|
||||||
|
'gauss': 3,
|
||||||
|
'relu': 4,
|
||||||
|
'elu': 5,
|
||||||
|
'lelu': 6,
|
||||||
|
'selu': 7,
|
||||||
|
'softplus': 8,
|
||||||
|
'identity': 9,
|
||||||
|
'clamped': 10,
|
||||||
|
'inv': 11,
|
||||||
|
'log': 12,
|
||||||
|
'exp': 13,
|
||||||
|
'abs': 14,
|
||||||
|
'hat': 15,
|
||||||
|
'square': 16,
|
||||||
|
'cube': 17,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def refactor_act(config):
|
||||||
|
config['activation_default'] = act_name2key[config['activation_default']]
|
||||||
|
config['activation_options'] = [
|
||||||
|
act_name2key[act_name] for act_name in config['activation_options']
|
||||||
|
]
|
||||||
20
configs/aggregations.py
Normal file
20
configs/aggregations.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from neat.genome.aggregations import *
|
||||||
|
|
||||||
|
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
|
||||||
|
|
||||||
|
agg_name2key = {
|
||||||
|
'sum': 0,
|
||||||
|
'product': 1,
|
||||||
|
'max': 2,
|
||||||
|
'min': 3,
|
||||||
|
'maxabs': 4,
|
||||||
|
'median': 5,
|
||||||
|
'mean': 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def refactor_agg(config):
|
||||||
|
config['aggregation_default'] = agg_name2key[config['aggregation_default']]
|
||||||
|
config['aggregation_options'] = [
|
||||||
|
agg_name2key[act_name] for act_name in config['aggregation_options']
|
||||||
|
]
|
||||||
@@ -2,8 +2,46 @@ import os
|
|||||||
import warnings
|
import warnings
|
||||||
import configparser
|
import configparser
|
||||||
|
|
||||||
|
from .activations import refactor_act
|
||||||
|
from .aggregations import refactor_agg
|
||||||
|
|
||||||
|
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||||
|
jit_config_keys = [
|
||||||
|
"compatibility_disjoint",
|
||||||
|
"compatibility_weight",
|
||||||
|
"conn_add_prob",
|
||||||
|
"conn_add_trials",
|
||||||
|
"conn_delete_prob",
|
||||||
|
"node_add_prob",
|
||||||
|
"node_delete_prob",
|
||||||
|
"compatibility_threshold",
|
||||||
|
"bias_init_mean",
|
||||||
|
"bias_init_stdev",
|
||||||
|
"bias_mutate_power",
|
||||||
|
"bias_mutate_rate",
|
||||||
|
"bias_replace_rate",
|
||||||
|
"response_init_mean",
|
||||||
|
"response_init_stdev",
|
||||||
|
"response_mutate_power",
|
||||||
|
"response_mutate_rate",
|
||||||
|
"response_replace_rate",
|
||||||
|
"activation_default",
|
||||||
|
"activation_options",
|
||||||
|
"activation_replace_rate",
|
||||||
|
"aggregation_default",
|
||||||
|
"aggregation_options",
|
||||||
|
"aggregation_replace_rate",
|
||||||
|
"weight_init_mean",
|
||||||
|
"weight_init_stdev",
|
||||||
|
"weight_mutate_power",
|
||||||
|
"weight_mutate_rate",
|
||||||
|
"weight_replace_rate",
|
||||||
|
"enable_mutate_rate",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Configer:
|
class Configer:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __load_default_config(cls):
|
def __load_default_config(cls):
|
||||||
par_dir = os.path.dirname(os.path.abspath(__file__))
|
par_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
@@ -47,5 +85,13 @@ class Configer:
|
|||||||
|
|
||||||
cls.__check_redundant_config(default_config, config)
|
cls.__check_redundant_config(default_config, config)
|
||||||
cls.__complete_config(default_config, config)
|
cls.__complete_config(default_config, config)
|
||||||
# cls.__decorate_config(config)
|
|
||||||
|
refactor_act(config)
|
||||||
|
refactor_agg(config)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_jit_config(cls, config):
|
||||||
|
jit_config = {k: config[k] for k in jit_config_keys}
|
||||||
|
return jit_config
|
||||||
@@ -4,7 +4,7 @@ num_outputs = 1
|
|||||||
init_maximum_nodes = 20
|
init_maximum_nodes = 20
|
||||||
init_maximum_connections = 20
|
init_maximum_connections = 20
|
||||||
init_maximum_species = 10
|
init_maximum_species = 10
|
||||||
expands_coe = 2
|
expands_coe = 2.0
|
||||||
forward_way = "pop_batch"
|
forward_way = "pop_batch"
|
||||||
|
|
||||||
[population]
|
[population]
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import jax
|
|
||||||
from utils import Configer
|
|
||||||
from neat import Pipeline
|
|
||||||
from neat import FunctionFactory
|
|
||||||
from problems import EnhanceLogic
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(problem, func):
|
|
||||||
inputs = problem.ask_for_inputs()
|
|
||||||
pop_predict = jax.device_get(func(inputs))
|
|
||||||
# print(pop_predict)
|
|
||||||
fitnesses = []
|
|
||||||
for predict in pop_predict:
|
|
||||||
f = problem.evaluate_predict(predict)
|
|
||||||
fitnesses.append(f)
|
|
||||||
return np.array(fitnesses)
|
|
||||||
|
|
||||||
|
|
||||||
# @using_cprofile
|
|
||||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
|
|
||||||
def main():
|
|
||||||
tic = time.time()
|
|
||||||
config = Configer.load_config()
|
|
||||||
problem = EnhanceLogic("xor", n=3)
|
|
||||||
problem.refactor_config(config)
|
|
||||||
function_factory = FunctionFactory(config)
|
|
||||||
evaluate_func = lambda func: evaluate(problem, func)
|
|
||||||
pipeline = Pipeline(config, function_factory, seed=33413)
|
|
||||||
print("start run")
|
|
||||||
pipeline.auto_run(evaluate_func)
|
|
||||||
|
|
||||||
total_time = time.time() - tic
|
|
||||||
compile_time = pipeline.function_factory.compile_time
|
|
||||||
total_it = pipeline.generation
|
|
||||||
mean_time_per_it = (total_time - compile_time) / total_it
|
|
||||||
evaluate_time = pipeline.evaluate_time
|
|
||||||
print(
|
|
||||||
f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
|
|
||||||
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
import jax
|
|
||||||
import numpy as np
|
|
||||||
from neat import FunctionFactory
|
|
||||||
from neat.genome.debug.tools import check_array_valid
|
|
||||||
from utils import Configer
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
config = Configer.load_config()
|
|
||||||
function_factory = FunctionFactory(config, debug=True)
|
|
||||||
initialize_func = function_factory.create_initialize()
|
|
||||||
pop_nodes, pop_connections, input_idx, output_idx = initialize_func()
|
|
||||||
mutate_func = function_factory.create_mutate(pop_nodes.shape[1], pop_connections.shape[1])
|
|
||||||
crossover_func = function_factory.create_crossover(pop_nodes.shape[1], pop_connections.shape[1])
|
|
||||||
key = jax.random.PRNGKey(0)
|
|
||||||
new_node_idx = 100
|
|
||||||
while True:
|
|
||||||
key, subkey = jax.random.split(key)
|
|
||||||
mutate_keys = jax.random.split(subkey, len(pop_nodes))
|
|
||||||
new_nodes = np.arange(new_node_idx, new_node_idx + len(pop_nodes))
|
|
||||||
new_node_idx += len(pop_nodes)
|
|
||||||
pop_nodes, pop_connections = mutate_func(mutate_keys, pop_nodes, pop_connections, new_nodes)
|
|
||||||
pop_nodes, pop_connections = jax.device_get([pop_nodes, pop_connections])
|
|
||||||
idx1 = np.random.permutation(len(pop_nodes))
|
|
||||||
idx2 = np.random.permutation(len(pop_nodes))
|
|
||||||
|
|
||||||
n1, c1 = pop_nodes[idx1], pop_connections[idx1]
|
|
||||||
n2, c2 = pop_nodes[idx2], pop_connections[idx2]
|
|
||||||
crossover_keys = jax.random.split(subkey, len(pop_nodes))
|
|
||||||
|
|
||||||
pop_nodes, pop_connections = crossover_func(crossover_keys, n1, c1, n2, c2)
|
|
||||||
|
|
||||||
for i in range(len(pop_nodes)):
|
|
||||||
check_array_valid(pop_nodes[i], pop_connections[i], input_idx, output_idx)
|
|
||||||
|
|
||||||
print(new_node_idx)
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,59 +1,21 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
from jax import jit
|
||||||
from jax import jit, vmap
|
|
||||||
from time import time
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
from configs import Configer
|
||||||
@jit
|
from neat.pipeline_ import Pipeline
|
||||||
def jax_mutate(seed, x):
|
|
||||||
noise = jax.random.normal(seed, x.shape) * 0.1
|
|
||||||
return x + noise
|
|
||||||
|
|
||||||
|
|
||||||
def numpy_mutate(x):
|
|
||||||
noise = np.random.normal(size=x.shape) * 0.1
|
|
||||||
return x + noise
|
|
||||||
|
|
||||||
|
|
||||||
def jax_mutate_population(seed, pop_x):
|
|
||||||
seeds = jax.random.split(seed, len(pop_x))
|
|
||||||
func = vmap(jax_mutate, in_axes=(0, 0))
|
|
||||||
return func(seeds, pop_x)
|
|
||||||
|
|
||||||
|
|
||||||
def numpy_mutate_population(pop_x):
|
|
||||||
return np.stack([numpy_mutate(x) for x in pop_x])
|
|
||||||
|
|
||||||
|
|
||||||
def numpy_mutate_population_vmap(pop_x):
|
|
||||||
noise = np.random.normal(size=pop_x.shape) * 0.1
|
|
||||||
return pop_x + noise
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
seed = jax.random.PRNGKey(0)
|
config = Configer.load_config("xor.ini")
|
||||||
i = 10
|
print(config)
|
||||||
while i < 200000:
|
pipeline = Pipeline(config)
|
||||||
pop_x = jnp.ones((i, 100, 100))
|
|
||||||
jax_pop_func = jit(jax_mutate_population).lower(seed, pop_x).compile()
|
|
||||||
|
|
||||||
tic = time()
|
|
||||||
res = jax.device_get(jax_pop_func(seed, pop_x))
|
|
||||||
jax_time = time() - tic
|
|
||||||
|
|
||||||
tic = time()
|
@jit
|
||||||
res = numpy_mutate_population(pop_x)
|
def f(x, jit_config):
|
||||||
numpy_time = time() - tic
|
return x + jit_config["bias_mutate_rate"]
|
||||||
|
|
||||||
tic = time()
|
|
||||||
res = numpy_mutate_population_vmap(pop_x)
|
|
||||||
numpy_time_vmap = time() - tic
|
|
||||||
|
|
||||||
# print(f'POP_SIZE: {i} | JAX: {jax_time:.4f} | Numpy: {numpy_time:.4f} | Speedup: {numpy_time / jax_time:.4f}')
|
|
||||||
print(f'POP_SIZE: {i} | JAX: {jax_time:.4f} | Numpy: {numpy_time:.4f} | Numpy Vmap: {numpy_time_vmap:.4f}')
|
|
||||||
|
|
||||||
i = int(i * 1.3)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
2
examples/xor.ini
Normal file
2
examples/xor.ini
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
[population]
|
||||||
|
fitness_threshold = -1e-2
|
||||||
@@ -1,29 +1,43 @@
|
|||||||
from neat import FunctionFactory
|
from typing import Callable, List
|
||||||
from utils import Configer
|
|
||||||
from neat import Pipeline
|
|
||||||
from problems import Xor
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from configs import Configer
|
||||||
|
from neat import Pipeline
|
||||||
|
|
||||||
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||||
|
xor_outputs = np.array([[0], [1], [1], [0]])
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(forward_func: Callable) -> List[float]:
|
||||||
|
"""
|
||||||
|
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
outs = forward_func(xor_inputs)
|
||||||
|
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||||
|
# print(fitnesses)
|
||||||
|
return fitnesses.tolist() # returns a list
|
||||||
|
|
||||||
|
|
||||||
# @using_cprofile
|
# @using_cprofile
|
||||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
|
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
|
||||||
def main():
|
def main():
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
config = Configer.load_config()
|
config = Configer.load_config("xor.ini")
|
||||||
print(config)
|
print(config)
|
||||||
assert False
|
|
||||||
problem = Xor()
|
|
||||||
problem.refactor_config(config)
|
|
||||||
function_factory = FunctionFactory(config)
|
function_factory = FunctionFactory(config)
|
||||||
pipeline = Pipeline(config, function_factory, seed=6)
|
pipeline = Pipeline(config, function_factory, seed=6)
|
||||||
nodes, cons = pipeline.auto_run(problem.evaluate)
|
nodes, cons = pipeline.auto_run(evaluate)
|
||||||
print(nodes, cons)
|
print(nodes, cons)
|
||||||
total_time = time.time() - tic
|
total_time = time.time() - tic
|
||||||
compile_time = pipeline.function_factory.compile_time
|
compile_time = pipeline.function_factory.compile_time
|
||||||
total_it = pipeline.generation
|
total_it = pipeline.generation
|
||||||
mean_time_per_it = (total_time - compile_time) / total_it
|
mean_time_per_it = (total_time - compile_time) / total_it
|
||||||
evaluate_time = pipeline.evaluate_time
|
evaluate_time = pipeline.evaluate_time
|
||||||
print(f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
|
print(
|
||||||
|
f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
|
||||||
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")
|
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,2 +0,0 @@
|
|||||||
from .pipeline import Pipeline
|
|
||||||
from .function_factory import FunctionFactory
|
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
from .genome import expand, expand_single, initialize_genomes
|
|
||||||
from .forward import forward_single
|
|
||||||
from .activations import act_name2key
|
|
||||||
from .aggregations import agg_name2key
|
|
||||||
from .crossover import crossover
|
|
||||||
from .mutate import mutate
|
|
||||||
from .distance import distance
|
|
||||||
from .graph import topological_sort
|
|
||||||
from .utils import unflatten_connections
|
|
||||||
@@ -104,31 +104,6 @@ def cube_act(z):
|
|||||||
return z ** 3
|
return z ** 3
|
||||||
|
|
||||||
|
|
||||||
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
|
|
||||||
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
|
|
||||||
|
|
||||||
act_name2key = {
|
|
||||||
'sigmoid': 0,
|
|
||||||
'tanh': 1,
|
|
||||||
'sin': 2,
|
|
||||||
'gauss': 3,
|
|
||||||
'relu': 4,
|
|
||||||
'elu': 5,
|
|
||||||
'lelu': 6,
|
|
||||||
'selu': 7,
|
|
||||||
'softplus': 8,
|
|
||||||
'identity': 9,
|
|
||||||
'clamped': 10,
|
|
||||||
'inv': 11,
|
|
||||||
'log': 12,
|
|
||||||
'exp': 13,
|
|
||||||
'abs': 14,
|
|
||||||
'hat': 15,
|
|
||||||
'square': 16,
|
|
||||||
'cube': 17,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def act(idx, z):
|
def act(idx, z):
|
||||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||||
@@ -137,4 +112,3 @@ def act(idx, z):
|
|||||||
return jnp.where(jnp.isnan(res), jnp.nan, res)
|
return jnp.where(jnp.isnan(res), jnp.nan, res)
|
||||||
|
|
||||||
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ def maxabs_agg(z):
|
|||||||
|
|
||||||
@jit
|
@jit
|
||||||
def median_agg(z):
|
def median_agg(z):
|
||||||
|
|
||||||
non_zero_mask = ~jnp.isnan(z)
|
non_zero_mask = ~jnp.isnan(z)
|
||||||
n = jnp.sum(non_zero_mask, axis=0)
|
n = jnp.sum(non_zero_mask, axis=0)
|
||||||
|
|
||||||
@@ -71,19 +70,6 @@ def mean_agg(z):
|
|||||||
return mean_without_zeros
|
return mean_without_zeros
|
||||||
|
|
||||||
|
|
||||||
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
|
|
||||||
|
|
||||||
agg_name2key = {
|
|
||||||
'sum': 0,
|
|
||||||
'product': 1,
|
|
||||||
'max': 2,
|
|
||||||
'min': 3,
|
|
||||||
'maxabs': 4,
|
|
||||||
'median': 5,
|
|
||||||
'mean': 6,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def agg(idx, z):
|
def agg(idx, z):
|
||||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||||
@@ -97,7 +83,6 @@ def agg(idx, z):
|
|||||||
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
|
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
|
||||||
|
|
||||||
|
|
||||||
vectorized_agg = jax.vmap(agg, in_axes=(0, 0))
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)
|
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)
|
||||||
|
|||||||
76
neat/genome/crossover_.py
Normal file
76
neat/genome/crossover_.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from functools import partial
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import jit, vmap, Array
|
||||||
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def crossover(randkey: Array, nodes1: Array, cons1: Array, nodes2: Array, cons2: Array) \
|
||||||
|
-> Tuple[Array, Array]:
|
||||||
|
"""
|
||||||
|
use genome1 and genome2 to generate a new genome
|
||||||
|
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||||
|
:param randkey:
|
||||||
|
:param nodes1:
|
||||||
|
:param cons1:
|
||||||
|
:param nodes2:
|
||||||
|
:param cons2:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
randkey_1, randkey_2 = jax.random.split(randkey)
|
||||||
|
|
||||||
|
# crossover nodes
|
||||||
|
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||||
|
nodes2 = align_array(keys1, keys2, nodes2, 'node')
|
||||||
|
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
|
||||||
|
|
||||||
|
# crossover connections
|
||||||
|
con_keys1, con_keys2 = cons1[:, :2], cons2[:, :2]
|
||||||
|
cons2 = align_array(con_keys1, con_keys2, cons2, 'connection')
|
||||||
|
new_cons = jnp.where(jnp.isnan(cons1) | jnp.isnan(cons2), cons1, crossover_gene(randkey_2, cons1, cons2))
|
||||||
|
|
||||||
|
return new_nodes, new_cons
|
||||||
|
|
||||||
|
|
||||||
|
# @partial(jit, static_argnames=['gene_type'])
|
||||||
|
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||||
|
"""
|
||||||
|
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||||
|
make ar2 align with ar1.
|
||||||
|
:param seq1:
|
||||||
|
:param seq2:
|
||||||
|
:param ar2:
|
||||||
|
:param gene_type:
|
||||||
|
:return:
|
||||||
|
align means to intersect part of ar2 will be at the same position as ar1,
|
||||||
|
non-intersect part of ar2 will be set to Nan
|
||||||
|
"""
|
||||||
|
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
|
||||||
|
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
|
||||||
|
|
||||||
|
if gene_type == 'connection':
|
||||||
|
mask = jnp.all(mask, axis=2)
|
||||||
|
|
||||||
|
intersect_mask = mask.any(axis=1)
|
||||||
|
idx = jnp.arange(0, len(seq1))
|
||||||
|
idx_fixed = jnp.dot(mask, idx)
|
||||||
|
|
||||||
|
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
|
||||||
|
|
||||||
|
return refactor_ar2
|
||||||
|
|
||||||
|
|
||||||
|
# @jit
|
||||||
|
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||||
|
"""
|
||||||
|
crossover two genes
|
||||||
|
:param rand_key:
|
||||||
|
:param g1:
|
||||||
|
:param g2:
|
||||||
|
:return:
|
||||||
|
only gene with the same key will be crossover, thus don't need to consider change key
|
||||||
|
"""
|
||||||
|
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||||
|
return jnp.where(r > 0.5, g1, g2)
|
||||||
@@ -1,3 +1,7 @@
|
|||||||
|
"""
|
||||||
|
Calculate the distance between two genomes.
|
||||||
|
"""
|
||||||
|
|
||||||
from jax import jit, vmap, Array
|
from jax import jit, vmap, Array
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
|||||||
105
neat/genome/distance_.py
Normal file
105
neat/genome/distance_.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""
|
||||||
|
Calculate the distance between two genomes.
|
||||||
|
The calculation method is the same as the distance calculation in NEAT-python.
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from jax import jit, vmap, Array
|
||||||
|
from jax import numpy as jnp
|
||||||
|
|
||||||
|
from .utils import EMPTY_NODE, EMPTY_CON
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def distance(nodes1: Array, cons1: Array, nodes2: Array, cons2: Array, jit_config: Dict) -> Array:
|
||||||
|
"""
|
||||||
|
Calculate the distance between two genomes.
|
||||||
|
"""
|
||||||
|
nd = node_distance(nodes1, nodes2, jit_config) # node distance
|
||||||
|
cd = connection_distance(cons1, cons2, jit_config) # connection distance
|
||||||
|
return nd + cd
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def node_distance(nodes1: Array, nodes2: Array, jit_config: Dict):
|
||||||
|
"""
|
||||||
|
Calculate the distance between two nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||||
|
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||||
|
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||||
|
|
||||||
|
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||||
|
keys = nodes[:, 0]
|
||||||
|
sorted_indices = jnp.argsort(keys, axis=0)
|
||||||
|
nodes = nodes[sorted_indices]
|
||||||
|
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
|
||||||
|
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||||
|
|
||||||
|
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||||
|
|
||||||
|
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||||
|
nd = batch_homologous_node_distance(fr, sr)
|
||||||
|
nd = jnp.where(jnp.isnan(nd), 0, nd)
|
||||||
|
homologous_distance = jnp.sum(nd * intersect_mask)
|
||||||
|
|
||||||
|
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
||||||
|
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
||||||
|
"""
|
||||||
|
Calculate the distance between two connections.
|
||||||
|
"""
|
||||||
|
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
|
||||||
|
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
|
||||||
|
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||||
|
|
||||||
|
cons = jnp.concatenate((cons1, cons2), axis=0)
|
||||||
|
keys = cons[:, :2]
|
||||||
|
sorted_indices = jnp.lexsort(keys.T[::-1])
|
||||||
|
cons = cons[sorted_indices]
|
||||||
|
cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end
|
||||||
|
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||||
|
|
||||||
|
# both genome has such connection
|
||||||
|
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||||
|
|
||||||
|
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||||
|
cd = batch_homologous_connection_distance(fr, sr)
|
||||||
|
cd = jnp.where(jnp.isnan(cd), 0, cd)
|
||||||
|
homologous_distance = jnp.sum(cd * intersect_mask)
|
||||||
|
|
||||||
|
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
||||||
|
|
||||||
|
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||||
|
|
||||||
|
|
||||||
|
@vmap
|
||||||
|
def batch_homologous_node_distance(b_n1, b_n2):
|
||||||
|
return homologous_node_distance(b_n1, b_n2)
|
||||||
|
|
||||||
|
|
||||||
|
@vmap
|
||||||
|
def batch_homologous_connection_distance(b_c1, b_c2):
|
||||||
|
return homologous_connection_distance(b_c1, b_c2)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def homologous_node_distance(n1, n2):
|
||||||
|
d = 0
|
||||||
|
d += jnp.abs(n1[1] - n2[1]) # bias
|
||||||
|
d += jnp.abs(n1[2] - n2[2]) # response
|
||||||
|
d += n1[3] != n2[3] # activation
|
||||||
|
d += n1[4] != n2[4]
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def homologous_connection_distance(c1, c2):
|
||||||
|
d = 0
|
||||||
|
d += jnp.abs(c1[2] - c2[2]) # weight
|
||||||
|
d += c1[3] != c2[3] # enable
|
||||||
|
return d
|
||||||
@@ -6,6 +6,7 @@ from .aggregations import agg
|
|||||||
from .activations import act
|
from .activations import act
|
||||||
from .utils import I_INT
|
from .utils import I_INT
|
||||||
|
|
||||||
|
|
||||||
# TODO: enabled information doesn't influence forward. That is wrong!
|
# TODO: enabled information doesn't influence forward. That is wrong!
|
||||||
@jit
|
@jit
|
||||||
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
|
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
|
||||||
|
|||||||
@@ -1,201 +0,0 @@
|
|||||||
"""
|
|
||||||
Vectorization of genome representation.
|
|
||||||
|
|
||||||
Utilizes Tuple[nodes: Array, connections: Array] to encode the genome, where:
|
|
||||||
|
|
||||||
1. N, C are pre-set values that determines the maximum number of nodes and connections in the network, and will increase if the genome becomes
|
|
||||||
too large to be represented by the current value of N and C.
|
|
||||||
2. nodes is an array of shape (N, 5), dtype=float, with columns corresponding to: key, bias, response, activation function
|
|
||||||
(act), and aggregation function (agg).
|
|
||||||
3. connections is an array of shape (C, 4), dtype=float, with columns corresponding to: i_key, o_key, weight, enabled.
|
|
||||||
Empty nodes or connections are represented using np.nan.
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Tuple, Dict
|
|
||||||
|
|
||||||
import jax
|
|
||||||
import numpy as np
|
|
||||||
from numpy.typing import NDArray
|
|
||||||
from jax import numpy as jnp
|
|
||||||
from jax import jit
|
|
||||||
from jax import Array
|
|
||||||
|
|
||||||
from .utils import fetch_first
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_genomes(pop_size: int,
|
|
||||||
N: int,
|
|
||||||
C: int,
|
|
||||||
num_inputs: int,
|
|
||||||
num_outputs: int,
|
|
||||||
default_bias: float = 0.0,
|
|
||||||
default_response: float = 1.0,
|
|
||||||
default_act: int = 0,
|
|
||||||
default_agg: int = 0,
|
|
||||||
default_weight: float = 0.0) \
|
|
||||||
-> Tuple[NDArray, NDArray, NDArray, NDArray]:
|
|
||||||
"""
|
|
||||||
Initialize genomes with default values.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pop_size (int): Number of genomes to initialize.
|
|
||||||
N (int): Maximum number of nodes in the network.
|
|
||||||
C (int): Maximum number of connections in the network.
|
|
||||||
num_inputs (int): Number of input nodes.
|
|
||||||
num_outputs (int): Number of output nodes.
|
|
||||||
default_bias (float, optional): Default bias value for output nodes. Defaults to 0.0.
|
|
||||||
default_response (float, optional): Default response value for output nodes. Defaults to 1.0.
|
|
||||||
default_act (int, optional): Default activation function index for output nodes. Defaults to 1.
|
|
||||||
default_agg (int, optional): Default aggregation function index for output nodes. Defaults to 0.
|
|
||||||
default_weight (float, optional): Default weight value for connections. Defaults to 0.0.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AssertionError: If the sum of num_inputs, num_outputs, and 1 is greater than N.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays.
|
|
||||||
"""
|
|
||||||
# Reserve one row for potential mutation adding an extra node
|
|
||||||
assert num_inputs + num_outputs + 1 <= N, f"Too small N: {N} for input_size: " \
|
|
||||||
f"{num_inputs} and output_size: {num_outputs}!"
|
|
||||||
assert num_inputs * num_outputs + 1 <= C, f"Too small C: {C} for input_size: " \
|
|
||||||
f"{num_inputs} and output_size: {num_outputs}!"
|
|
||||||
|
|
||||||
pop_nodes = np.full((pop_size, N, 5), np.nan)
|
|
||||||
pop_cons = np.full((pop_size, C, 4), np.nan)
|
|
||||||
input_idx = np.arange(num_inputs)
|
|
||||||
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
|
||||||
|
|
||||||
pop_nodes[:, input_idx, 0] = input_idx
|
|
||||||
pop_nodes[:, output_idx, 0] = output_idx
|
|
||||||
|
|
||||||
pop_nodes[:, output_idx, 1] = default_bias
|
|
||||||
pop_nodes[:, output_idx, 2] = default_response
|
|
||||||
pop_nodes[:, output_idx, 3] = default_act
|
|
||||||
pop_nodes[:, output_idx, 4] = default_agg
|
|
||||||
|
|
||||||
grid_a, grid_b = np.meshgrid(input_idx, output_idx)
|
|
||||||
grid_a, grid_b = grid_a.flatten(), grid_b.flatten()
|
|
||||||
|
|
||||||
pop_cons[:, :num_inputs * num_outputs, 0] = grid_a
|
|
||||||
pop_cons[:, :num_inputs * num_outputs, 1] = grid_b
|
|
||||||
pop_cons[:, :num_inputs * num_outputs, 2] = default_weight
|
|
||||||
pop_cons[:, :num_inputs * num_outputs, 3] = 1
|
|
||||||
|
|
||||||
return pop_nodes, pop_cons, input_idx, output_idx
|
|
||||||
|
|
||||||
|
|
||||||
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
|
|
||||||
"""
|
|
||||||
Expand the genome to accommodate more nodes.
|
|
||||||
:param pop_nodes: (pop_size, N, 5)
|
|
||||||
:param pop_cons: (pop_size, C, 4)
|
|
||||||
:param new_N:
|
|
||||||
:param new_C:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1]
|
|
||||||
|
|
||||||
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
|
|
||||||
new_pop_nodes[:, :old_N, :] = pop_nodes
|
|
||||||
|
|
||||||
new_pop_cons = np.full((pop_size, new_C, 4), np.nan)
|
|
||||||
new_pop_cons[:, :old_C, :] = pop_cons
|
|
||||||
|
|
||||||
return new_pop_nodes, new_pop_cons
|
|
||||||
|
|
||||||
|
|
||||||
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
|
|
||||||
"""
|
|
||||||
Expand a single genome to accommodate more nodes.
|
|
||||||
:param nodes: (N, 5)
|
|
||||||
:param cons: (2, N, N)
|
|
||||||
:param new_N:
|
|
||||||
:param new_C:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
old_N, old_C = nodes.shape[0], cons.shape[0]
|
|
||||||
new_nodes = np.full((new_N, 5), np.nan)
|
|
||||||
new_nodes[:old_N, :] = nodes
|
|
||||||
|
|
||||||
new_cons = np.full((new_C, 4), np.nan)
|
|
||||||
new_cons[:old_C, :] = cons
|
|
||||||
|
|
||||||
return new_nodes, new_cons
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def count(nodes, cons):
|
|
||||||
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
|
|
||||||
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
|
|
||||||
return node_cnt, cons_cnt
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def add_node(nodes: Array, cons: Array, new_key: int,
|
|
||||||
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
|
|
||||||
"""
|
|
||||||
add a new node to the genome.
|
|
||||||
"""
|
|
||||||
exist_keys = nodes[:, 0]
|
|
||||||
idx = fetch_first(jnp.isnan(exist_keys))
|
|
||||||
nodes = nodes.at[idx].set(jnp.array([new_key, bias, response, act, agg]))
|
|
||||||
return nodes, cons
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def delete_node(nodes: Array, cons: Array, node_key: int) -> Tuple[Array, Array]:
|
|
||||||
"""
|
|
||||||
delete a node from the genome. only delete the node, regardless of connections.
|
|
||||||
"""
|
|
||||||
node_keys = nodes[:, 0]
|
|
||||||
idx = fetch_first(node_keys == node_key)
|
|
||||||
return delete_node_by_idx(nodes, cons, idx)
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
|
|
||||||
"""
|
|
||||||
use idx to delete a node from the genome. only delete the node, regardless of connections.
|
|
||||||
"""
|
|
||||||
nodes = nodes.at[idx].set(np.nan)
|
|
||||||
return nodes, cons
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
|
|
||||||
weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
|
|
||||||
"""
|
|
||||||
add a new connection to the genome.
|
|
||||||
"""
|
|
||||||
con_keys = cons[:, 0]
|
|
||||||
idx = fetch_first(jnp.isnan(con_keys))
|
|
||||||
return add_connection_by_idx(nodes, cons, idx, i_key, o_key, weight, enabled)
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def add_connection_by_idx(nodes: Array, cons: Array, idx: int, i_key: int, o_key: int,
|
|
||||||
weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
|
|
||||||
"""
|
|
||||||
use idx to add a new connection to the genome.
|
|
||||||
"""
|
|
||||||
cons = cons.at[idx].set(jnp.array([i_key, o_key, weight, enabled]))
|
|
||||||
return nodes, cons
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def delete_connection(nodes: Array, cons: Array, i_key: int, o_key: int) -> Tuple[Array, Array]:
|
|
||||||
"""
|
|
||||||
delete a connection from the genome.
|
|
||||||
"""
|
|
||||||
idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
|
||||||
return delete_connection_by_idx(nodes, cons, idx)
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def delete_connection_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Array]:
|
|
||||||
"""
|
|
||||||
use idx to delete a connection from the genome.
|
|
||||||
"""
|
|
||||||
cons = cons.at[idx].set(np.nan)
|
|
||||||
return nodes, cons
|
|
||||||
180
neat/genome/genome_.py
Normal file
180
neat/genome/genome_.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
"""
|
||||||
|
Vectorization of genome representation.
|
||||||
|
|
||||||
|
Utilizes Tuple[nodes: Array(N, 5), connections: Array(C, 4)] to encode the genome, where:
|
||||||
|
nodes: [key, bias, response, act, agg]
|
||||||
|
connections: [in_key, out_key, weight, enable]
|
||||||
|
N: Maximum number of nodes in the network.
|
||||||
|
C: Maximum number of connections in the network.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Tuple, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from jax import jit, numpy as jnp
|
||||||
|
|
||||||
|
from .utils import fetch_first
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_genomes(N: int,
|
||||||
|
C: int,
|
||||||
|
config: Dict) \
|
||||||
|
-> Tuple[NDArray, NDArray, NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Initialize genomes with default values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
N (int): Maximum number of nodes in the network.
|
||||||
|
C (int): Maximum number of connections in the network.
|
||||||
|
config (Dict): Configuration dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[NDArray, NDArray, NDArray, NDArray]: pop_nodes, pop_connections, input_idx, and output_idx arrays.
|
||||||
|
"""
|
||||||
|
# Reserve one row for potential mutation adding an extra node
|
||||||
|
assert config['num_inputs'] + config['num_outputs'] + 1 <= N, \
|
||||||
|
f"Too small N: {N} for input_size: {config['num_inputs']} and output_size: {config['num_inputs']}!"
|
||||||
|
|
||||||
|
assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
|
||||||
|
f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"
|
||||||
|
|
||||||
|
pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
|
||||||
|
pop_cons = np.full((config['pop_size'], C, 4), np.nan)
|
||||||
|
input_idx = np.arange(config['num_inputs'])
|
||||||
|
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
|
||||||
|
|
||||||
|
pop_nodes[:, input_idx, 0] = input_idx
|
||||||
|
pop_nodes[:, output_idx, 0] = output_idx
|
||||||
|
|
||||||
|
pop_nodes[:, output_idx, 1] = config['bias_init_mean']
|
||||||
|
pop_nodes[:, output_idx, 2] = config['response_init_mean']
|
||||||
|
pop_nodes[:, output_idx, 3] = config['activation_default']
|
||||||
|
pop_nodes[:, output_idx, 4] = config['aggregation_default']
|
||||||
|
|
||||||
|
grid_a, grid_b = np.meshgrid(input_idx, output_idx)
|
||||||
|
grid_a, grid_b = grid_a.flatten(), grid_b.flatten()
|
||||||
|
|
||||||
|
p = config['num_inputs'] * config['num_outputs']
|
||||||
|
pop_cons[:, :p, 0] = grid_a
|
||||||
|
pop_cons[:, :p, 1] = grid_b
|
||||||
|
pop_cons[:, :p, 2] = config['weight_init_mean']
|
||||||
|
pop_cons[:, :p, 3] = 1
|
||||||
|
|
||||||
|
return pop_nodes, pop_cons, input_idx, output_idx
|
||||||
|
|
||||||
|
|
||||||
|
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Expand a single genome to accommodate more nodes or connections.
|
||||||
|
:param nodes: (N, 5)
|
||||||
|
:param cons: (C, 4)
|
||||||
|
:param new_N:
|
||||||
|
:param new_C:
|
||||||
|
:return: (new_N, 5), (new_C, 4)
|
||||||
|
"""
|
||||||
|
old_N, old_C = nodes.shape[0], cons.shape[0]
|
||||||
|
new_nodes = np.full((new_N, 5), np.nan)
|
||||||
|
new_nodes[:old_N, :] = nodes
|
||||||
|
|
||||||
|
new_cons = np.full((new_C, 4), np.nan)
|
||||||
|
new_cons[:old_C, :] = cons
|
||||||
|
|
||||||
|
return new_nodes, new_cons
|
||||||
|
|
||||||
|
|
||||||
|
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Expand the population to accommodate more nodes or connections.
|
||||||
|
:param pop_nodes: (pop_size, N, 5)
|
||||||
|
:param pop_cons: (pop_size, C, 4)
|
||||||
|
:param new_N:
|
||||||
|
:param new_C:
|
||||||
|
:return: (pop_size, new_N, 5), (pop_size, new_C, 4)
|
||||||
|
"""
|
||||||
|
pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1]
|
||||||
|
|
||||||
|
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
|
||||||
|
new_pop_nodes[:, :old_N, :] = pop_nodes
|
||||||
|
|
||||||
|
new_pop_cons = np.full((pop_size, new_C, 4), np.nan)
|
||||||
|
new_pop_cons[:, :old_C, :] = pop_cons
|
||||||
|
|
||||||
|
return new_pop_nodes, new_pop_cons
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def count(nodes: NDArray, cons: NDArray) -> Tuple[NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Count how many nodes and connections are in the genome.
|
||||||
|
"""
|
||||||
|
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
|
||||||
|
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
|
||||||
|
return node_cnt, cons_cnt
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def add_node(nodes: NDArray, cons: NDArray, new_key: int,
|
||||||
|
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Add a new node to the genome.
|
||||||
|
The new node will place at the first NaN row.
|
||||||
|
"""
|
||||||
|
exist_keys = nodes[:, 0]
|
||||||
|
idx = fetch_first(jnp.isnan(exist_keys))
|
||||||
|
nodes = nodes.at[idx].set(jnp.array([new_key, bias, response, act, agg]))
|
||||||
|
return nodes, cons
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def delete_node(nodes: NDArray, cons: NDArray, node_key: int) -> Tuple[NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Delete a node from the genome. Only delete the node, regardless of connections.
|
||||||
|
Delete the node by its key.
|
||||||
|
"""
|
||||||
|
node_keys = nodes[:, 0]
|
||||||
|
idx = fetch_first(node_keys == node_key)
|
||||||
|
return delete_node_by_idx(nodes, cons, idx)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def delete_node_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Delete a node from the genome. Only delete the node, regardless of connections.
|
||||||
|
Delete the node by its idx.
|
||||||
|
"""
|
||||||
|
nodes = nodes.at[idx].set(np.nan)
|
||||||
|
return nodes, cons
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def add_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int,
|
||||||
|
weight: float = 1.0, enabled: bool = True) -> Tuple[NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Add a new connection to the genome.
|
||||||
|
The new connection will place at the first NaN row.
|
||||||
|
"""
|
||||||
|
con_keys = cons[:, 0]
|
||||||
|
idx = fetch_first(jnp.isnan(con_keys))
|
||||||
|
cons = cons.at[idx].set(jnp.array([i_key, o_key, weight, enabled]))
|
||||||
|
return nodes, cons
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def delete_connection(nodes: NDArray, cons: NDArray, i_key: int, o_key: int) -> Tuple[NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Delete a connection from the genome.
|
||||||
|
Delete the connection by its input and output node keys.
|
||||||
|
"""
|
||||||
|
idx = fetch_first((cons[:, 0] == i_key) & (cons[:, 1] == o_key))
|
||||||
|
return delete_connection_by_idx(nodes, cons, idx)
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def delete_connection_by_idx(nodes: NDArray, cons: NDArray, idx: int) -> Tuple[NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Delete a connection from the genome.
|
||||||
|
Delete the connection by its idx.
|
||||||
|
"""
|
||||||
|
cons = cons.at[idx].set(np.nan)
|
||||||
|
return nodes, cons
|
||||||
@@ -7,7 +7,7 @@ import jax
|
|||||||
from jax import jit, vmap, Array
|
from jax import jit, vmap, Array
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
|
|
||||||
# from .utils import fetch_first, I_INT
|
# from .configs import fetch_first, I_INT
|
||||||
from neat.genome.utils import fetch_first, I_INT
|
from neat.genome.utils import fetch_first, I_INT
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -32,9 +32,6 @@ def unflatten_connections(nodes, cons):
|
|||||||
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
||||||
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
||||||
|
|
||||||
# (2, N, N), (2, N, N), (2, N, N)
|
|
||||||
# res = jnp.where(res[1, :, :] == 0, jnp.nan, res)
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@@ -88,6 +85,7 @@ def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
|||||||
min_idx = jnp.argmin(masked_arr)
|
min_idx = jnp.argmin(masked_arr)
|
||||||
return min_idx
|
return min_idx
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
a = jnp.array([1, 2, 3, 4, 5])
|
a = jnp.array([1, 2, 3, 4, 5])
|
||||||
|
|||||||
27
neat/pipeline_.py
Normal file
27
neat/pipeline_.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import jax
|
||||||
|
|
||||||
|
from configs.configer import Configer
|
||||||
|
from .genome.genome_ import initialize_genomes
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
|
||||||
|
"""
|
||||||
|
Neat algorithm pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config, seed=42):
|
||||||
|
self.randkey = jax.random.PRNGKey(seed)
|
||||||
|
|
||||||
|
self.config = config # global config
|
||||||
|
self.jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||||
|
self.N = self.config["init_maximum_nodes"]
|
||||||
|
self.C = self.config["init_maximum_connections"]
|
||||||
|
self.S = self.config["init_maximum_species"]
|
||||||
|
|
||||||
|
self.generation = 0
|
||||||
|
self.best_genome = None
|
||||||
|
|
||||||
|
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = initialize_genomes(self.N, self.C, self.config)
|
||||||
|
|
||||||
|
print(self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx, sep='\n')
|
||||||
|
print(self.jit_config)
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
from .problem import Problem
|
|
||||||
from .function_fitting import *
|
|
||||||
from .gym import *
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
from .function_fitting_problem import FunctionFittingProblem
|
|
||||||
from .xor import *
|
|
||||||
from .sin import *
|
|
||||||
from .diy import *
|
|
||||||
from .enhance_logic import *
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
|
|
||||||
from . import FunctionFittingProblem
|
|
||||||
|
|
||||||
|
|
||||||
class DIY(FunctionFittingProblem):
|
|
||||||
def __init__(self, func, size=100):
|
|
||||||
self.num_inputs = 1
|
|
||||||
self.num_outputs = 1
|
|
||||||
self.batch = size
|
|
||||||
self.inputs = np.linspace(0, 1, self.batch)[:, None]
|
|
||||||
self.target = func(self.inputs)
|
|
||||||
print(self.inputs, self.target)
|
|
||||||
super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
"""
|
|
||||||
xor problem in multiple dimensions
|
|
||||||
"""
|
|
||||||
|
|
||||||
from itertools import product
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class EnhanceLogic:
|
|
||||||
def __init__(self, name="xor", n=2):
|
|
||||||
self.name = name
|
|
||||||
self.n = n
|
|
||||||
self.num_inputs = n
|
|
||||||
self.num_outputs = 1
|
|
||||||
self.batch = 2 ** n
|
|
||||||
self.forward_way = 'pop_batch'
|
|
||||||
|
|
||||||
self.inputs = np.array(generate_permutations(n), dtype=np.float32)
|
|
||||||
|
|
||||||
if self.name == "xor":
|
|
||||||
self.outputs = np.sum(self.inputs, axis=1) % 2
|
|
||||||
elif self.name == "and":
|
|
||||||
self.outputs = np.all(self.inputs==1, axis=1)
|
|
||||||
elif self.name == "or":
|
|
||||||
self.outputs = np.any(self.inputs==1, axis=1)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Only support xor, and, or")
|
|
||||||
self.outputs = self.outputs[:, np.newaxis]
|
|
||||||
|
|
||||||
|
|
||||||
def refactor_config(self, config):
|
|
||||||
config.basic.forward_way = self.forward_way
|
|
||||||
config.basic.num_inputs = self.num_inputs
|
|
||||||
config.basic.num_outputs = self.num_outputs
|
|
||||||
config.basic.problem_batch = self.batch
|
|
||||||
|
|
||||||
|
|
||||||
def ask_for_inputs(self):
|
|
||||||
return self.inputs
|
|
||||||
|
|
||||||
def evaluate_predict(self, predict):
|
|
||||||
# print((predict - self.outputs) ** 2)
|
|
||||||
return -np.mean((predict - self.outputs) ** 2)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_permutations(n):
|
|
||||||
permutations = [list(i) for i in product([0, 1], repeat=n)]
|
|
||||||
|
|
||||||
return permutations
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
_ = EnhanceLogic(4)
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import jax
|
|
||||||
|
|
||||||
from problems import Problem
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionFittingProblem(Problem):
|
|
||||||
def __init__(self, num_inputs, num_outputs, batch, inputs, target, loss='MSE'):
|
|
||||||
self.forward_way = 'pop_batch'
|
|
||||||
self.num_inputs = num_inputs
|
|
||||||
self.num_outputs = num_outputs
|
|
||||||
self.batch = batch
|
|
||||||
self.inputs = inputs
|
|
||||||
self.target = target
|
|
||||||
self.loss = loss
|
|
||||||
super().__init__(self.forward_way, self.num_inputs, self.num_outputs, self.batch)
|
|
||||||
|
|
||||||
def evaluate(self, pop_batch_forward):
|
|
||||||
outs = pop_batch_forward(self.inputs)
|
|
||||||
outs = jax.device_get(outs)
|
|
||||||
fitnesses = -np.mean((self.target - outs) ** 2, axis=(1, 2))
|
|
||||||
return fitnesses
|
|
||||||
|
|
||||||
def draw(self, batch_func):
|
|
||||||
outs = batch_func(self.inputs)
|
|
||||||
outs = jax.device_get(outs)
|
|
||||||
print(outs)
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
plt.xlabel('x')
|
|
||||||
plt.ylabel('y')
|
|
||||||
plt.plot(self.inputs, self.target, color='red', label='target')
|
|
||||||
plt.plot(self.inputs, outs, color='blue', label='predict')
|
|
||||||
plt.legend()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def print(self, batch_func):
|
|
||||||
outs = batch_func(self.inputs)
|
|
||||||
outs = jax.device_get(outs)
|
|
||||||
print(outs)
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
|
|
||||||
from . import FunctionFittingProblem
|
|
||||||
|
|
||||||
|
|
||||||
class Sin(FunctionFittingProblem):
|
|
||||||
def __init__(self, size=100):
|
|
||||||
self.num_inputs = 1
|
|
||||||
self.num_outputs = 1
|
|
||||||
self.batch = size
|
|
||||||
self.inputs = np.linspace(0, 2 * np.pi, self.batch)[:, None]
|
|
||||||
self.target = np.sin(self.inputs)
|
|
||||||
print(self.inputs, self.target)
|
|
||||||
super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
|
|
||||||
from . import FunctionFittingProblem
|
|
||||||
|
|
||||||
|
|
||||||
class Xor(FunctionFittingProblem):
|
|
||||||
def __init__(self):
|
|
||||||
self.num_inputs = 2
|
|
||||||
self.num_outputs = 1
|
|
||||||
self.batch = 4
|
|
||||||
self.inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
|
||||||
self.target = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
|
||||||
super().__init__(self.num_inputs, self.num_outputs, self.batch, self.inputs, self.target)
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
class Problem:
|
|
||||||
def __init__(self, forward_way, num_inputs, num_outputs, batch):
|
|
||||||
self.forward_way = forward_way
|
|
||||||
self.batch = batch
|
|
||||||
self.num_inputs = num_inputs
|
|
||||||
self.num_outputs = num_outputs
|
|
||||||
|
|
||||||
def refactor_config(self, config):
|
|
||||||
config.basic.forward_way = self.forward_way
|
|
||||||
config.basic.num_inputs = self.num_inputs
|
|
||||||
config.basic.num_outputs = self.num_outputs
|
|
||||||
config.basic.problem_batch = self.batch
|
|
||||||
|
|
||||||
def evaluate(self, batch_forward_func):
|
|
||||||
pass
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
from .config import Configer
|
|
||||||
Reference in New Issue
Block a user