clean imports and delete "create_XXX_functions"

This commit is contained in:
wls2002
2023-05-09 01:58:00 +08:00
parent f63a0c447b
commit 1f2327bbd6
7 changed files with 20 additions and 286 deletions

View File

@@ -6,11 +6,7 @@ from functools import partial
import numpy as np import numpy as np
from jax import jit, vmap from jax import jit, vmap
from .genome import act_name2key, agg_name2key from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
from .genome.genome import initialize_genomes
from .genome.mutate import mutate
from .genome.distance import distance
from .genome.crossover import crossover
class FunctionFactory: class FunctionFactory:

View File

@@ -1,7 +1,7 @@
from .genome import create_initialize_function, expand, expand_single, pop_analysis from .genome import expand, expand_single, pop_analysis, initialize_genomes
from .distance import create_distance_function
from .mutate import create_mutate_function
from .forward import create_forward_function from .forward import create_forward_function
from .crossover import create_crossover_function
from .activations import act_name2key from .activations import act_name2key
from .aggregations import agg_name2key from .aggregations import agg_name2key
from .crossover import crossover
from .mutate import mutate
from .distance import distance

View File

@@ -8,38 +8,7 @@ from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections from .utils import flatten_connections, unflatten_connections
def create_crossover_function(N, config, batch: bool, debug: bool = False): @jit
if batch:
pop_size = config.neat.population.pop_size
randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
nodes1_lower = jnp.zeros((pop_size, N, 5))
connections1_lower = jnp.zeros((pop_size, 2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
res_func = jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args)
else:
return res_func
else:
randkey_lower = jnp.zeros((2,), dtype=jnp.uint32)
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((N, 5))
connections2_lower = jnp.zeros((2, N, N))
res_func = jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args)
else:
return res_func
# @jit
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
-> Tuple[Array, Array]: -> Tuple[Array, Array]:
""" """
@@ -70,7 +39,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
return new_nodes, new_cons return new_nodes, new_cons
# @partial(jit, static_argnames=['gene_type']) @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
""" """
make ar2 align with ar1. make ar2 align with ar1.
@@ -97,7 +66,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
return refactor_ar2 return refactor_ar2
# @jit @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
""" """
crossover two genes crossover two genes

View File

@@ -1,139 +1,9 @@
from jax import jit, vmap, Array from jax import jit, vmap, Array
from jax import numpy as jnp from jax import numpy as jnp
import numpy as np
from numpy.typing import NDArray
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
def create_distance_function(N, config, type: str, debug: bool = False):
"""
:param N:
:param config:
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
:param debug:
:return:
"""
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
def distance_with_args(nodes1, connections1, nodes2, connections2):
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
if type == 'o2o':
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((N, 5))
connections2_lower = jnp.zeros((2, N, N))
res_func = jit(distance_with_args).lower(nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args) # for debug
else:
return res_func
elif type == 'o2m':
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
pop_size = config.neat.population.pop_size
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
res_func = jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args) # for debug
else:
return res_func
else:
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
def distance_numpy(nodes1: NDArray, connection1: NDArray, nodes2: NDArray,
connection2: NDArray, disjoint_coe: float = 1., compatibility_coe: float = 0.5):
"""
use in o2o distance.
o2o can't use vmap, numpy should be faster than jax function
:param nodes1:
:param connection1:
:param nodes2:
:param connection2:
:param disjoint_coe:
:param compatibility_coe:
:return:
"""
def analysis(nodes, connections):
nodes_dict = {}
idx2key = {}
for i, node in enumerate(nodes):
if np.isnan(node[0]):
continue
key = int(node[0])
nodes_dict[key] = (node[1], node[2], node[3], node[4])
idx2key[i] = key
connections_dict = {}
for i in range(connections.shape[1]):
for j in range(connections.shape[2]):
if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
continue
key = (idx2key[i], idx2key[j])
weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None
enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None
connections_dict[key] = (weight, enabled)
return nodes_dict, connections_dict
nodes1, connections1 = analysis(nodes1, connection1)
nodes2, connections2 = analysis(nodes2, connection2)
nd = 0.0
if nodes1 or nodes2: # otherwise, both are empty
disjoint_nodes = 0
for k2 in nodes2:
if k2 not in nodes1:
disjoint_nodes += 1
for k1, n1 in nodes1.items():
n2 = nodes2.get(k1)
if n2 is None:
disjoint_nodes += 1
else:
if np.isnan(n1[0]): # n1[1] is nan means input nodes
continue
d = abs(n1[0] - n2[0]) + abs(n1[1] - n2[1])
d += 1 if n1[2] != n2[2] else 0
d += 1 if n1[3] != n2[3] else 0
nd += d
max_nodes = max(len(nodes1), len(nodes2))
nd = (compatibility_coe * nd + disjoint_coe * disjoint_nodes) / max_nodes
cd = 0.0
if connections1 or connections2:
disjoint_connections = 0
for k2 in connections2:
if k2 not in connections1:
disjoint_connections += 1
for k1, c1 in connections1.items():
c2 = connections2.get(k1)
if c2 is None:
disjoint_connections += 1
else:
# Homologous genes compute their own distance value.
d = abs(c1[0] - c2[0])
d += 1 if c1[1] != c2[1] else 0
cd += d
max_conn = max(len(connections1), len(connections2))
cd = (compatibility_coe * cd + disjoint_coe * disjoint_connections) / max_conn
return nd + cd
@jit @jit
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1., def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
compatibility_coe: float = 0.5) -> Array: compatibility_coe: float = 0.5) -> Array:

View File

@@ -22,27 +22,11 @@ from jax import numpy as jnp
from jax import jit from jax import jit
from jax import Array from jax import Array
from .activations import act_name2key
from .aggregations import agg_name2key
from .utils import fetch_first from .utils import fetch_first
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan]) EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
def create_initialize_function(config):
pop_size = config.neat.population.pop_size
N = config.basic.init_maximum_nodes
num_inputs = config.basic.num_inputs
num_outputs = config.basic.num_outputs
default_bias = config.neat.gene.bias.init_mean
default_response = config.neat.gene.response.init_mean
default_act = act_name2key[config.neat.gene.activation.default]
default_agg = agg_name2key[config.neat.gene.aggregation.default]
default_weight = config.neat.gene.weight.init_mean
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
default_act, default_agg, default_weight)
def initialize_genomes(pop_size: int, def initialize_genomes(pop_size: int,
N: int, N: int,
num_inputs: int, num_inputs: int,

View File

@@ -13,100 +13,7 @@ from .activations import act_name2key
from .aggregations import agg_name2key from .aggregations import agg_name2key
def create_mutate_function(N, config, batch: bool, debug: bool = False): @partial(jit, static_argnames=('single_structure_mutate',))
"""
create mutate function for different situations
:param N:
:param config:
:param batch: mutate for population or not
:param debug:
:return:
"""
num_inputs = config.basic.num_inputs
num_outputs = config.basic.num_outputs
input_idx = np.arange(num_inputs)
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
bias = config.neat.gene.bias
bias_default = bias.init_mean
bias_mean = bias.init_mean
bias_std = bias.init_stdev
bias_mutate_strength = bias.mutate_power
bias_mutate_rate = bias.mutate_rate
bias_replace_rate = bias.replace_rate
response = config.neat.gene.response
response_default = response.init_mean
response_mean = response.init_mean
response_std = response.init_stdev
response_mutate_strength = response.mutate_power
response_mutate_rate = response.mutate_rate
response_replace_rate = response.replace_rate
weight = config.neat.gene.weight
weight_mean = weight.init_mean
weight_std = weight.init_stdev
weight_mutate_strength = weight.mutate_power
weight_mutate_rate = weight.mutate_rate
weight_replace_rate = weight.replace_rate
activation = config.neat.gene.activation
act_default = act_name2key[activation.default]
act_list = np.array([act_name2key[name] for name in activation.options])
act_replace_rate = activation.mutate_rate
aggregation = config.neat.gene.aggregation
agg_default = agg_name2key[aggregation.default]
agg_list = np.array([agg_name2key[name] for name in aggregation.options])
agg_replace_rate = aggregation.mutate_rate
enabled = config.neat.gene.enabled
enabled_reverse_rate = enabled.mutate_rate
genome = config.neat.genome
add_node_rate = genome.node_add_prob
delete_node_rate = genome.node_delete_prob
add_connection_rate = genome.conn_add_prob
delete_connection_rate = genome.conn_delete_prob
single_structure_mutate = genome.single_structural_mutation
def mutate_with_args(rand_key, nodes, connections, new_node_key):
return mutate(rand_key, nodes, connections, new_node_key, input_idx, output_idx,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_list, act_replace_rate,
agg_default, agg_list, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
if not batch:
rand_key_lower = jnp.zeros((2,), dtype=jnp.uint32)
nodes_lower = jnp.zeros((N, 5))
connections_lower = jnp.zeros((2, N, N))
new_node_key_lower = jnp.zeros((), dtype=jnp.int32)
res_func = jit(mutate_with_args).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
if debug:
return lambda *args: res_func(*args)
else:
return res_func
else:
pop_size = config.neat.population.pop_size
rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
nodes_lower = jnp.zeros((pop_size, N, 5))
connections_lower = jnp.zeros((pop_size, 2, N, N))
new_node_key_lower = jnp.zeros((pop_size,), dtype=jnp.int32)
batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
if debug:
return lambda *args: batched_mutate_func(*args)
else:
return batched_mutate_func
def mutate(rand_key: Array, def mutate(rand_key: Array,
nodes: Array, nodes: Array,
connections: Array, connections: Array,
@@ -243,6 +150,7 @@ def mutate(rand_key: Array,
return nodes, connections return nodes, connections
@jit
def mutate_values(rand_key: Array, def mutate_values(rand_key: Array,
nodes: Array, nodes: Array,
connections: Array, connections: Array,
@@ -323,6 +231,7 @@ def mutate_values(rand_key: Array,
return nodes, connections return nodes, connections
@jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float, def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array: mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
""" """
@@ -355,6 +264,7 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
return new_vals return new_vals
@jit
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array: def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
""" """
Mutate integer values (act, agg) of a given array. Mutate integer values (act, agg) of a given array.
@@ -377,6 +287,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
return new_vals return new_vals
@jit
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array, def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
default_bias: float = 0, default_response: float = 1, default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]: default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
@@ -423,6 +334,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
return nodes, connections return nodes, connections
@jit
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
""" """
@@ -456,6 +368,7 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
return nodes, connections return nodes, connections
@jit
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
""" """
@@ -494,6 +407,7 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
return nodes, connections return nodes, connections
@jit
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
""" """
Randomly delete a connection. Randomly delete a connection.
@@ -516,6 +430,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
return nodes, connections return nodes, connections
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array, def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array, input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
@@ -544,6 +459,7 @@ def choice_node_key(rand_key: Array, nodes: Array,
return key, idx return key, idx
@jit
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]: def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
""" """
Randomly choose a connection key from the given connections. Randomly choose a connection key from the given connections.
@@ -571,5 +487,6 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
return from_key, to_key, from_idx, to_idx return from_key, to_key, from_idx, to_idx
@jit
def rand(rand_key): def rand(rand_key):
return jax.random.uniform(rand_key, ()) return jax.random.uniform(rand_key, ())

View File

@@ -5,9 +5,7 @@ import jax
import numpy as np import numpy as np
from .species import SpeciesController from .species import SpeciesController
from .genome import expand, expand_single from .genome import expand, expand_single, create_forward_function
from .genome import create_initialize_function, create_mutate_function, create_forward_function, \
create_distance_function, create_crossover_function
from .function_factory import FunctionFactory from .function_factory import FunctionFactory