clean imports and delete "create_XXX_functions"
This commit is contained in:
@@ -6,11 +6,7 @@ from functools import partial
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from jax import jit, vmap
|
from jax import jit, vmap
|
||||||
|
|
||||||
from .genome import act_name2key, agg_name2key
|
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
|
||||||
from .genome.genome import initialize_genomes
|
|
||||||
from .genome.mutate import mutate
|
|
||||||
from .genome.distance import distance
|
|
||||||
from .genome.crossover import crossover
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionFactory:
|
class FunctionFactory:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from .genome import create_initialize_function, expand, expand_single, pop_analysis
|
from .genome import expand, expand_single, pop_analysis, initialize_genomes
|
||||||
from .distance import create_distance_function
|
|
||||||
from .mutate import create_mutate_function
|
|
||||||
from .forward import create_forward_function
|
from .forward import create_forward_function
|
||||||
from .crossover import create_crossover_function
|
|
||||||
from .activations import act_name2key
|
from .activations import act_name2key
|
||||||
from .aggregations import agg_name2key
|
from .aggregations import agg_name2key
|
||||||
|
from .crossover import crossover
|
||||||
|
from .mutate import mutate
|
||||||
|
from .distance import distance
|
||||||
|
|||||||
@@ -8,38 +8,7 @@ from jax import numpy as jnp
|
|||||||
from .utils import flatten_connections, unflatten_connections
|
from .utils import flatten_connections, unflatten_connections
|
||||||
|
|
||||||
|
|
||||||
def create_crossover_function(N, config, batch: bool, debug: bool = False):
|
@jit
|
||||||
if batch:
|
|
||||||
pop_size = config.neat.population.pop_size
|
|
||||||
randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
|
|
||||||
nodes1_lower = jnp.zeros((pop_size, N, 5))
|
|
||||||
connections1_lower = jnp.zeros((pop_size, 2, N, N))
|
|
||||||
nodes2_lower = jnp.zeros((pop_size, N, 5))
|
|
||||||
connections2_lower = jnp.zeros((pop_size, 2, N, N))
|
|
||||||
|
|
||||||
res_func = jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
|
|
||||||
nodes2_lower, connections2_lower).compile()
|
|
||||||
if debug:
|
|
||||||
return lambda *args: res_func(*args)
|
|
||||||
else:
|
|
||||||
return res_func
|
|
||||||
|
|
||||||
else:
|
|
||||||
randkey_lower = jnp.zeros((2,), dtype=jnp.uint32)
|
|
||||||
nodes1_lower = jnp.zeros((N, 5))
|
|
||||||
connections1_lower = jnp.zeros((2, N, N))
|
|
||||||
nodes2_lower = jnp.zeros((N, 5))
|
|
||||||
connections2_lower = jnp.zeros((2, N, N))
|
|
||||||
|
|
||||||
res_func = jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
|
|
||||||
nodes2_lower, connections2_lower).compile()
|
|
||||||
if debug:
|
|
||||||
return lambda *args: res_func(*args)
|
|
||||||
else:
|
|
||||||
return res_func
|
|
||||||
|
|
||||||
|
|
||||||
# @jit
|
|
||||||
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
||||||
-> Tuple[Array, Array]:
|
-> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
@@ -70,7 +39,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
|||||||
return new_nodes, new_cons
|
return new_nodes, new_cons
|
||||||
|
|
||||||
|
|
||||||
# @partial(jit, static_argnames=['gene_type'])
|
@partial(jit, static_argnames=['gene_type'])
|
||||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||||
"""
|
"""
|
||||||
make ar2 align with ar1.
|
make ar2 align with ar1.
|
||||||
@@ -97,7 +66,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
|||||||
return refactor_ar2
|
return refactor_ar2
|
||||||
|
|
||||||
|
|
||||||
# @jit
|
@jit
|
||||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||||
"""
|
"""
|
||||||
crossover two genes
|
crossover two genes
|
||||||
|
|||||||
@@ -1,139 +1,9 @@
|
|||||||
from jax import jit, vmap, Array
|
from jax import jit, vmap, Array
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
import numpy as np
|
|
||||||
from numpy.typing import NDArray
|
|
||||||
|
|
||||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||||
|
|
||||||
|
|
||||||
def create_distance_function(N, config, type: str, debug: bool = False):
|
|
||||||
"""
|
|
||||||
:param N:
|
|
||||||
:param config:
|
|
||||||
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
|
|
||||||
:param debug:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
|
||||||
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
|
||||||
|
|
||||||
def distance_with_args(nodes1, connections1, nodes2, connections2):
|
|
||||||
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
|
||||||
|
|
||||||
if type == 'o2o':
|
|
||||||
nodes1_lower = jnp.zeros((N, 5))
|
|
||||||
connections1_lower = jnp.zeros((2, N, N))
|
|
||||||
nodes2_lower = jnp.zeros((N, 5))
|
|
||||||
connections2_lower = jnp.zeros((2, N, N))
|
|
||||||
|
|
||||||
res_func = jit(distance_with_args).lower(nodes1_lower, connections1_lower,
|
|
||||||
nodes2_lower, connections2_lower).compile()
|
|
||||||
if debug:
|
|
||||||
return lambda *args: res_func(*args) # for debug
|
|
||||||
else:
|
|
||||||
return res_func
|
|
||||||
|
|
||||||
elif type == 'o2m':
|
|
||||||
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
|
||||||
pop_size = config.neat.population.pop_size
|
|
||||||
nodes1_lower = jnp.zeros((N, 5))
|
|
||||||
connections1_lower = jnp.zeros((2, N, N))
|
|
||||||
nodes2_lower = jnp.zeros((pop_size, N, 5))
|
|
||||||
connections2_lower = jnp.zeros((pop_size, 2, N, N))
|
|
||||||
res_func = jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
|
|
||||||
if debug:
|
|
||||||
return lambda *args: res_func(*args) # for debug
|
|
||||||
else:
|
|
||||||
return res_func
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
|
|
||||||
|
|
||||||
|
|
||||||
def distance_numpy(nodes1: NDArray, connection1: NDArray, nodes2: NDArray,
|
|
||||||
connection2: NDArray, disjoint_coe: float = 1., compatibility_coe: float = 0.5):
|
|
||||||
"""
|
|
||||||
use in o2o distance.
|
|
||||||
o2o can't use vmap, numpy should be faster than jax function
|
|
||||||
:param nodes1:
|
|
||||||
:param connection1:
|
|
||||||
:param nodes2:
|
|
||||||
:param connection2:
|
|
||||||
:param disjoint_coe:
|
|
||||||
:param compatibility_coe:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
def analysis(nodes, connections):
|
|
||||||
nodes_dict = {}
|
|
||||||
idx2key = {}
|
|
||||||
for i, node in enumerate(nodes):
|
|
||||||
if np.isnan(node[0]):
|
|
||||||
continue
|
|
||||||
key = int(node[0])
|
|
||||||
nodes_dict[key] = (node[1], node[2], node[3], node[4])
|
|
||||||
idx2key[i] = key
|
|
||||||
|
|
||||||
connections_dict = {}
|
|
||||||
for i in range(connections.shape[1]):
|
|
||||||
for j in range(connections.shape[2]):
|
|
||||||
if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
|
|
||||||
continue
|
|
||||||
key = (idx2key[i], idx2key[j])
|
|
||||||
|
|
||||||
weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None
|
|
||||||
enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None
|
|
||||||
connections_dict[key] = (weight, enabled)
|
|
||||||
|
|
||||||
return nodes_dict, connections_dict
|
|
||||||
|
|
||||||
nodes1, connections1 = analysis(nodes1, connection1)
|
|
||||||
nodes2, connections2 = analysis(nodes2, connection2)
|
|
||||||
|
|
||||||
nd = 0.0
|
|
||||||
if nodes1 or nodes2: # otherwise, both are empty
|
|
||||||
disjoint_nodes = 0
|
|
||||||
for k2 in nodes2:
|
|
||||||
if k2 not in nodes1:
|
|
||||||
disjoint_nodes += 1
|
|
||||||
|
|
||||||
for k1, n1 in nodes1.items():
|
|
||||||
n2 = nodes2.get(k1)
|
|
||||||
if n2 is None:
|
|
||||||
disjoint_nodes += 1
|
|
||||||
else:
|
|
||||||
if np.isnan(n1[0]): # n1[1] is nan means input nodes
|
|
||||||
continue
|
|
||||||
d = abs(n1[0] - n2[0]) + abs(n1[1] - n2[1])
|
|
||||||
d += 1 if n1[2] != n2[2] else 0
|
|
||||||
d += 1 if n1[3] != n2[3] else 0
|
|
||||||
nd += d
|
|
||||||
|
|
||||||
max_nodes = max(len(nodes1), len(nodes2))
|
|
||||||
nd = (compatibility_coe * nd + disjoint_coe * disjoint_nodes) / max_nodes
|
|
||||||
|
|
||||||
cd = 0.0
|
|
||||||
if connections1 or connections2:
|
|
||||||
disjoint_connections = 0
|
|
||||||
for k2 in connections2:
|
|
||||||
if k2 not in connections1:
|
|
||||||
disjoint_connections += 1
|
|
||||||
|
|
||||||
for k1, c1 in connections1.items():
|
|
||||||
c2 = connections2.get(k1)
|
|
||||||
if c2 is None:
|
|
||||||
disjoint_connections += 1
|
|
||||||
else:
|
|
||||||
# Homologous genes compute their own distance value.
|
|
||||||
d = abs(c1[0] - c2[0])
|
|
||||||
d += 1 if c1[1] != c2[1] else 0
|
|
||||||
cd += d
|
|
||||||
max_conn = max(len(connections1), len(connections2))
|
|
||||||
cd = (compatibility_coe * cd + disjoint_coe * disjoint_connections) / max_conn
|
|
||||||
|
|
||||||
return nd + cd
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
|
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
|
||||||
compatibility_coe: float = 0.5) -> Array:
|
compatibility_coe: float = 0.5) -> Array:
|
||||||
|
|||||||
@@ -22,27 +22,11 @@ from jax import numpy as jnp
|
|||||||
from jax import jit
|
from jax import jit
|
||||||
from jax import Array
|
from jax import Array
|
||||||
|
|
||||||
from .activations import act_name2key
|
|
||||||
from .aggregations import agg_name2key
|
|
||||||
from .utils import fetch_first
|
from .utils import fetch_first
|
||||||
|
|
||||||
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
|
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
|
||||||
|
|
||||||
|
|
||||||
def create_initialize_function(config):
|
|
||||||
pop_size = config.neat.population.pop_size
|
|
||||||
N = config.basic.init_maximum_nodes
|
|
||||||
num_inputs = config.basic.num_inputs
|
|
||||||
num_outputs = config.basic.num_outputs
|
|
||||||
default_bias = config.neat.gene.bias.init_mean
|
|
||||||
default_response = config.neat.gene.response.init_mean
|
|
||||||
default_act = act_name2key[config.neat.gene.activation.default]
|
|
||||||
default_agg = agg_name2key[config.neat.gene.aggregation.default]
|
|
||||||
default_weight = config.neat.gene.weight.init_mean
|
|
||||||
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
|
|
||||||
default_act, default_agg, default_weight)
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_genomes(pop_size: int,
|
def initialize_genomes(pop_size: int,
|
||||||
N: int,
|
N: int,
|
||||||
num_inputs: int,
|
num_inputs: int,
|
||||||
|
|||||||
@@ -13,100 +13,7 @@ from .activations import act_name2key
|
|||||||
from .aggregations import agg_name2key
|
from .aggregations import agg_name2key
|
||||||
|
|
||||||
|
|
||||||
def create_mutate_function(N, config, batch: bool, debug: bool = False):
|
@partial(jit, static_argnames=('single_structure_mutate',))
|
||||||
"""
|
|
||||||
create mutate function for different situations
|
|
||||||
:param N:
|
|
||||||
:param config:
|
|
||||||
:param batch: mutate for population or not
|
|
||||||
:param debug:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
num_inputs = config.basic.num_inputs
|
|
||||||
num_outputs = config.basic.num_outputs
|
|
||||||
input_idx = np.arange(num_inputs)
|
|
||||||
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
|
||||||
|
|
||||||
bias = config.neat.gene.bias
|
|
||||||
bias_default = bias.init_mean
|
|
||||||
bias_mean = bias.init_mean
|
|
||||||
bias_std = bias.init_stdev
|
|
||||||
bias_mutate_strength = bias.mutate_power
|
|
||||||
bias_mutate_rate = bias.mutate_rate
|
|
||||||
bias_replace_rate = bias.replace_rate
|
|
||||||
|
|
||||||
response = config.neat.gene.response
|
|
||||||
response_default = response.init_mean
|
|
||||||
response_mean = response.init_mean
|
|
||||||
response_std = response.init_stdev
|
|
||||||
response_mutate_strength = response.mutate_power
|
|
||||||
response_mutate_rate = response.mutate_rate
|
|
||||||
response_replace_rate = response.replace_rate
|
|
||||||
|
|
||||||
weight = config.neat.gene.weight
|
|
||||||
weight_mean = weight.init_mean
|
|
||||||
weight_std = weight.init_stdev
|
|
||||||
weight_mutate_strength = weight.mutate_power
|
|
||||||
weight_mutate_rate = weight.mutate_rate
|
|
||||||
weight_replace_rate = weight.replace_rate
|
|
||||||
|
|
||||||
activation = config.neat.gene.activation
|
|
||||||
act_default = act_name2key[activation.default]
|
|
||||||
act_list = np.array([act_name2key[name] for name in activation.options])
|
|
||||||
act_replace_rate = activation.mutate_rate
|
|
||||||
|
|
||||||
aggregation = config.neat.gene.aggregation
|
|
||||||
agg_default = agg_name2key[aggregation.default]
|
|
||||||
agg_list = np.array([agg_name2key[name] for name in aggregation.options])
|
|
||||||
agg_replace_rate = aggregation.mutate_rate
|
|
||||||
|
|
||||||
enabled = config.neat.gene.enabled
|
|
||||||
enabled_reverse_rate = enabled.mutate_rate
|
|
||||||
|
|
||||||
genome = config.neat.genome
|
|
||||||
add_node_rate = genome.node_add_prob
|
|
||||||
delete_node_rate = genome.node_delete_prob
|
|
||||||
add_connection_rate = genome.conn_add_prob
|
|
||||||
delete_connection_rate = genome.conn_delete_prob
|
|
||||||
single_structure_mutate = genome.single_structural_mutation
|
|
||||||
|
|
||||||
def mutate_with_args(rand_key, nodes, connections, new_node_key):
|
|
||||||
return mutate(rand_key, nodes, connections, new_node_key, input_idx, output_idx,
|
|
||||||
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
|
|
||||||
bias_replace_rate, response_default, response_mean, response_std,
|
|
||||||
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
|
||||||
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
|
|
||||||
weight_replace_rate, act_default, act_list, act_replace_rate,
|
|
||||||
agg_default, agg_list, agg_replace_rate, enabled_reverse_rate,
|
|
||||||
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
|
|
||||||
single_structure_mutate)
|
|
||||||
|
|
||||||
if not batch:
|
|
||||||
rand_key_lower = jnp.zeros((2,), dtype=jnp.uint32)
|
|
||||||
nodes_lower = jnp.zeros((N, 5))
|
|
||||||
connections_lower = jnp.zeros((2, N, N))
|
|
||||||
new_node_key_lower = jnp.zeros((), dtype=jnp.int32)
|
|
||||||
res_func = jit(mutate_with_args).lower(rand_key_lower, nodes_lower,
|
|
||||||
connections_lower, new_node_key_lower).compile()
|
|
||||||
if debug:
|
|
||||||
return lambda *args: res_func(*args)
|
|
||||||
else:
|
|
||||||
return res_func
|
|
||||||
else:
|
|
||||||
pop_size = config.neat.population.pop_size
|
|
||||||
rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
|
|
||||||
nodes_lower = jnp.zeros((pop_size, N, 5))
|
|
||||||
connections_lower = jnp.zeros((pop_size, 2, N, N))
|
|
||||||
new_node_key_lower = jnp.zeros((pop_size,), dtype=jnp.int32)
|
|
||||||
|
|
||||||
batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower,
|
|
||||||
connections_lower, new_node_key_lower).compile()
|
|
||||||
if debug:
|
|
||||||
return lambda *args: batched_mutate_func(*args)
|
|
||||||
else:
|
|
||||||
return batched_mutate_func
|
|
||||||
|
|
||||||
|
|
||||||
def mutate(rand_key: Array,
|
def mutate(rand_key: Array,
|
||||||
nodes: Array,
|
nodes: Array,
|
||||||
connections: Array,
|
connections: Array,
|
||||||
@@ -243,6 +150,7 @@ def mutate(rand_key: Array,
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
def mutate_values(rand_key: Array,
|
def mutate_values(rand_key: Array,
|
||||||
nodes: Array,
|
nodes: Array,
|
||||||
connections: Array,
|
connections: Array,
|
||||||
@@ -323,6 +231,7 @@ def mutate_values(rand_key: Array,
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
|
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
|
||||||
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
|
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
|
||||||
"""
|
"""
|
||||||
@@ -355,6 +264,7 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
|
|||||||
return new_vals
|
return new_vals
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
|
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
|
||||||
"""
|
"""
|
||||||
Mutate integer values (act, agg) of a given array.
|
Mutate integer values (act, agg) of a given array.
|
||||||
@@ -377,6 +287,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
|
|||||||
return new_vals
|
return new_vals
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
|
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
|
||||||
default_bias: float = 0, default_response: float = 1,
|
default_bias: float = 0, default_response: float = 1,
|
||||||
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
|
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
|
||||||
@@ -423,6 +334,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
||||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
@@ -456,6 +368,7 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
||||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
@@ -494,6 +407,7 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
||||||
"""
|
"""
|
||||||
Randomly delete a connection.
|
Randomly delete a connection.
|
||||||
@@ -516,6 +430,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
||||||
def choice_node_key(rand_key: Array, nodes: Array,
|
def choice_node_key(rand_key: Array, nodes: Array,
|
||||||
input_keys: Array, output_keys: Array,
|
input_keys: Array, output_keys: Array,
|
||||||
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
|
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
|
||||||
@@ -544,6 +459,7 @@ def choice_node_key(rand_key: Array, nodes: Array,
|
|||||||
return key, idx
|
return key, idx
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
|
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
|
||||||
"""
|
"""
|
||||||
Randomly choose a connection key from the given connections.
|
Randomly choose a connection key from the given connections.
|
||||||
@@ -571,5 +487,6 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
|
|||||||
return from_key, to_key, from_idx, to_idx
|
return from_key, to_key, from_idx, to_idx
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
def rand(rand_key):
|
def rand(rand_key):
|
||||||
return jax.random.uniform(rand_key, ())
|
return jax.random.uniform(rand_key, ())
|
||||||
|
|||||||
@@ -5,9 +5,7 @@ import jax
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .species import SpeciesController
|
from .species import SpeciesController
|
||||||
from .genome import expand, expand_single
|
from .genome import expand, expand_single, create_forward_function
|
||||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function, \
|
|
||||||
create_distance_function, create_crossover_function
|
|
||||||
from .function_factory import FunctionFactory
|
from .function_factory import FunctionFactory
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user