|
|
|
|
@@ -2,12 +2,15 @@ from typing import Tuple
|
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
|
|
import jax
|
|
|
|
|
import numpy as np
|
|
|
|
|
from jax import numpy as jnp
|
|
|
|
|
from jax import jit, vmap, Array
|
|
|
|
|
|
|
|
|
|
from .utils import fetch_random, fetch_first, I_INT
|
|
|
|
|
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx
|
|
|
|
|
from .graph import check_cycles
|
|
|
|
|
from .activations import act_name2key
|
|
|
|
|
from .aggregations import agg_name2key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
|
|
|
|
@@ -43,15 +46,13 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
|
|
|
|
weight_replace_rate = weight.replace_rate
|
|
|
|
|
|
|
|
|
|
activation = config.neat.gene.activation
|
|
|
|
|
# act_default = activation.default
|
|
|
|
|
act_default = 0
|
|
|
|
|
act_range = len(activation.options)
|
|
|
|
|
act_default = act_name2key[activation.default]
|
|
|
|
|
act_list = np.array([act_name2key[name] for name in activation.options])
|
|
|
|
|
act_replace_rate = activation.mutate_rate
|
|
|
|
|
|
|
|
|
|
aggregation = config.neat.gene.aggregation
|
|
|
|
|
# agg_default = aggregation.default
|
|
|
|
|
agg_default = 0
|
|
|
|
|
agg_range = len(aggregation.options)
|
|
|
|
|
agg_default = agg_name2key[aggregation.default]
|
|
|
|
|
agg_list = np.array([agg_name2key[name] for name in aggregation.options])
|
|
|
|
|
agg_replace_rate = aggregation.mutate_rate
|
|
|
|
|
|
|
|
|
|
enabled = config.neat.gene.enabled
|
|
|
|
|
@@ -64,29 +65,22 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
|
|
|
|
delete_connection_rate = genome.conn_delete_prob
|
|
|
|
|
single_structure_mutate = genome.single_structural_mutation
|
|
|
|
|
|
|
|
|
|
def mutate_func(rand_key, nodes, connections, new_node_key):
|
|
|
|
|
return mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
|
|
|
|
|
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
|
|
|
|
|
bias_replace_rate, response_default, response_mean, response_std,
|
|
|
|
|
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
|
|
|
|
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
|
|
|
|
|
weight_replace_rate, act_default, act_list, act_replace_rate,
|
|
|
|
|
agg_default, agg_list, agg_replace_rate, enabled_reverse_rate,
|
|
|
|
|
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
|
|
|
|
|
single_structure_mutate)
|
|
|
|
|
|
|
|
|
|
if not batch:
|
|
|
|
|
return lambda rand_key, nodes, connections, new_node_key: \
|
|
|
|
|
mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
|
|
|
|
|
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
|
|
|
|
|
bias_replace_rate, response_default, response_mean, response_std,
|
|
|
|
|
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
|
|
|
|
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
|
|
|
|
|
weight_replace_rate, act_default, act_range, act_replace_rate,
|
|
|
|
|
agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
|
|
|
|
|
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
|
|
|
|
|
single_structure_mutate)
|
|
|
|
|
return mutate_func
|
|
|
|
|
else:
|
|
|
|
|
batched_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, *(None,) * 31))
|
|
|
|
|
return lambda rand_keys, pop_nodes, pop_connections, new_node_keys: \
|
|
|
|
|
batched_mutate(rand_keys, pop_nodes, pop_connections, new_node_keys, input_keys, output_keys,
|
|
|
|
|
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
|
|
|
|
|
bias_replace_rate, response_default, response_mean, response_std,
|
|
|
|
|
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
|
|
|
|
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
|
|
|
|
|
weight_replace_rate, act_default, act_range, act_replace_rate,
|
|
|
|
|
agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
|
|
|
|
|
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
|
|
|
|
|
single_structure_mutate)
|
|
|
|
|
batched_mutate_func = vmap(mutate_func, in_axes=(0, 0, 0, 0))
|
|
|
|
|
return batched_mutate_func
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jit, static_argnames=["single_structure_mutate"])
|
|
|
|
|
@@ -114,10 +108,10 @@ def mutate(rand_key: Array,
|
|
|
|
|
weight_mutate_rate: float = 0.7,
|
|
|
|
|
weight_replace_rate: float = 0.1,
|
|
|
|
|
act_default: int = 0,
|
|
|
|
|
act_range: int = 5,
|
|
|
|
|
act_list: Array = None,
|
|
|
|
|
act_replace_rate: float = 0.1,
|
|
|
|
|
agg_default: int = 0,
|
|
|
|
|
agg_range: int = 5,
|
|
|
|
|
agg_list: Array = None,
|
|
|
|
|
agg_replace_rate: float = 0.1,
|
|
|
|
|
enabled_reverse_rate: float = 0.1,
|
|
|
|
|
add_node_rate: float = 0.2,
|
|
|
|
|
@@ -151,9 +145,9 @@ def mutate(rand_key: Array,
|
|
|
|
|
:param weight_mutate_strength:
|
|
|
|
|
:param weight_mutate_rate:
|
|
|
|
|
:param weight_replace_rate:
|
|
|
|
|
:param act_range:
|
|
|
|
|
:param act_list:
|
|
|
|
|
:param act_replace_rate:
|
|
|
|
|
:param agg_range:
|
|
|
|
|
:param agg_list:
|
|
|
|
|
:param agg_replace_rate:
|
|
|
|
|
:param enabled_reverse_rate:
|
|
|
|
|
:param add_node_rate:
|
|
|
|
|
@@ -224,7 +218,7 @@ def mutate(rand_key: Array,
|
|
|
|
|
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
|
|
|
|
|
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
|
|
|
|
weight_mean, weight_std, weight_mutate_strength,
|
|
|
|
|
weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range,
|
|
|
|
|
weight_mutate_rate, weight_replace_rate, act_list, act_replace_rate, agg_list,
|
|
|
|
|
agg_replace_rate, enabled_reverse_rate)
|
|
|
|
|
|
|
|
|
|
return nodes, connections
|
|
|
|
|
@@ -249,9 +243,9 @@ def mutate_values(rand_key: Array,
|
|
|
|
|
weight_mutate_strength: float = 0.5,
|
|
|
|
|
weight_mutate_rate: float = 0.7,
|
|
|
|
|
weight_replace_rate: float = 0.1,
|
|
|
|
|
act_range: int = 5,
|
|
|
|
|
act_list: Array = None,
|
|
|
|
|
act_replace_rate: float = 0.1,
|
|
|
|
|
agg_range: int = 5,
|
|
|
|
|
agg_list: Array = None,
|
|
|
|
|
agg_replace_rate: float = 0.1,
|
|
|
|
|
enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]:
|
|
|
|
|
"""
|
|
|
|
|
@@ -276,9 +270,9 @@ def mutate_values(rand_key: Array,
|
|
|
|
|
weight_mutate_strength: Strength of the weight mutation.
|
|
|
|
|
weight_mutate_rate: Rate of the weight mutation.
|
|
|
|
|
weight_replace_rate: Rate of the weight replacement.
|
|
|
|
|
act_range: Range of the activation function values.
|
|
|
|
|
act_list: List of the activation function values.
|
|
|
|
|
act_replace_rate: Rate of the activation function replacement.
|
|
|
|
|
agg_range: Range of the aggregation function values.
|
|
|
|
|
agg_list: List of the aggregation function values.
|
|
|
|
|
agg_replace_rate: Rate of the aggregation function replacement.
|
|
|
|
|
enabled_reverse_rate: Rate of reversing enabled state of connections.
|
|
|
|
|
|
|
|
|
|
@@ -293,8 +287,8 @@ def mutate_values(rand_key: Array,
|
|
|
|
|
response_mutate_strength, response_mutate_rate, response_replace_rate)
|
|
|
|
|
weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std,
|
|
|
|
|
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
|
|
|
|
|
act_new = mutate_int_values(k4, nodes[:, 3], act_range, act_replace_rate)
|
|
|
|
|
agg_new = mutate_int_values(k5, nodes[:, 4], agg_range, agg_replace_rate)
|
|
|
|
|
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
|
|
|
|
|
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
|
|
|
|
|
|
|
|
|
|
# refactor enabled
|
|
|
|
|
r = jax.random.uniform(rand_key, connections[1, :, :].shape)
|
|
|
|
|
@@ -345,21 +339,21 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@jit
|
|
|
|
|
def mutate_int_values(rand_key: Array, old_vals: Array, range: int, replace_rate: float) -> Array:
|
|
|
|
|
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
|
|
|
|
|
"""
|
|
|
|
|
Mutate integer values (act, agg) of a given array.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
rand_key: A random key for generating random values.
|
|
|
|
|
old_vals: A 1D array of integer values to be mutated.
|
|
|
|
|
range: Range of the integer values.
|
|
|
|
|
val_list: List of the integer values.
|
|
|
|
|
replace_rate: Rate of the replacement.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
A mutated 1D array of integer values.
|
|
|
|
|
"""
|
|
|
|
|
k1, k2, rand_key = jax.random.split(rand_key, num=3)
|
|
|
|
|
replace_val = jax.random.randint(k1, old_vals.shape, 0, range)
|
|
|
|
|
replace_val = jax.random.choice(k1, val_list, old_vals.shape)
|
|
|
|
|
r = jax.random.uniform(k2, old_vals.shape)
|
|
|
|
|
new_vals = old_vals
|
|
|
|
|
new_vals = jnp.where(r < replace_rate, replace_val, new_vals)
|
|
|
|
|
|