modify act. agg in mutation. they can have option vals

fix a bug in function 'agg'
This commit is contained in:
wls2002
2023-05-07 23:00:04 +08:00
parent 47bb593a53
commit b257505bee
4 changed files with 46 additions and 52 deletions

View File

@@ -88,13 +88,13 @@ agg_name2key = {
def agg(idx, z): def agg(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32) idx = jnp.asarray(idx, dtype=jnp.int32)
def full_zero(): def full_nan():
return 0. return 0.
def not_full_zero(): def not_full_nan():
return jax.lax.switch(idx, AGG_TOTAL_LIST, z) return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
return jax.lax.cond(jnp.all(z == 0.), full_zero, not_full_zero) return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
vectorized_agg = jax.vmap(agg, in_axes=(0, 0)) vectorized_agg = jax.vmap(agg, in_axes=(0, 0))

View File

@@ -50,7 +50,7 @@ def initialize_genomes(pop_size: int,
default_response: float = 1.0, default_response: float = 1.0,
default_act: int = 0, default_act: int = 0,
default_agg: int = 0, default_agg: int = 0,
default_weight: float = 1.0) \ default_weight: float = 0.0) \
-> Tuple[NDArray, NDArray, NDArray, NDArray]: -> Tuple[NDArray, NDArray, NDArray, NDArray]:
""" """
Initialize genomes with default values. Initialize genomes with default values.

View File

@@ -2,12 +2,15 @@ from typing import Tuple
from functools import partial from functools import partial
import jax import jax
import numpy as np
from jax import numpy as jnp from jax import numpy as jnp
from jax import jit, vmap, Array from jax import jit, vmap, Array
from .utils import fetch_random, fetch_first, I_INT from .utils import fetch_random, fetch_first, I_INT
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx
from .graph import check_cycles from .graph import check_cycles
from .activations import act_name2key
from .aggregations import agg_name2key
def create_mutate_function(config, input_keys, output_keys, batch: bool): def create_mutate_function(config, input_keys, output_keys, batch: bool):
@@ -43,15 +46,13 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
weight_replace_rate = weight.replace_rate weight_replace_rate = weight.replace_rate
activation = config.neat.gene.activation activation = config.neat.gene.activation
# act_default = activation.default act_default = act_name2key[activation.default]
act_default = 0 act_list = np.array([act_name2key[name] for name in activation.options])
act_range = len(activation.options)
act_replace_rate = activation.mutate_rate act_replace_rate = activation.mutate_rate
aggregation = config.neat.gene.aggregation aggregation = config.neat.gene.aggregation
# agg_default = aggregation.default agg_default = agg_name2key[aggregation.default]
agg_default = 0 agg_list = np.array([agg_name2key[name] for name in aggregation.options])
agg_range = len(aggregation.options)
agg_replace_rate = aggregation.mutate_rate agg_replace_rate = aggregation.mutate_rate
enabled = config.neat.gene.enabled enabled = config.neat.gene.enabled
@@ -64,29 +65,22 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
delete_connection_rate = genome.conn_delete_prob delete_connection_rate = genome.conn_delete_prob
single_structure_mutate = genome.single_structural_mutation single_structure_mutate = genome.single_structural_mutation
def mutate_func(rand_key, nodes, connections, new_node_key):
return mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_list, act_replace_rate,
agg_default, agg_list, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
if not batch: if not batch:
return lambda rand_key, nodes, connections, new_node_key: \ return mutate_func
mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_range, act_replace_rate,
agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
else: else:
batched_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, *(None,) * 31)) batched_mutate_func = vmap(mutate_func, in_axes=(0, 0, 0, 0))
return lambda rand_keys, pop_nodes, pop_connections, new_node_keys: \ return batched_mutate_func
batched_mutate(rand_keys, pop_nodes, pop_connections, new_node_keys, input_keys, output_keys,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_range, act_replace_rate,
agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
@partial(jit, static_argnames=["single_structure_mutate"]) @partial(jit, static_argnames=["single_structure_mutate"])
@@ -114,10 +108,10 @@ def mutate(rand_key: Array,
weight_mutate_rate: float = 0.7, weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1, weight_replace_rate: float = 0.1,
act_default: int = 0, act_default: int = 0,
act_range: int = 5, act_list: Array = None,
act_replace_rate: float = 0.1, act_replace_rate: float = 0.1,
agg_default: int = 0, agg_default: int = 0,
agg_range: int = 5, agg_list: Array = None,
agg_replace_rate: float = 0.1, agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1, enabled_reverse_rate: float = 0.1,
add_node_rate: float = 0.2, add_node_rate: float = 0.2,
@@ -151,9 +145,9 @@ def mutate(rand_key: Array,
:param weight_mutate_strength: :param weight_mutate_strength:
:param weight_mutate_rate: :param weight_mutate_rate:
:param weight_replace_rate: :param weight_replace_rate:
:param act_range: :param act_list:
:param act_replace_rate: :param act_replace_rate:
:param agg_range: :param agg_list:
:param agg_replace_rate: :param agg_replace_rate:
:param enabled_reverse_rate: :param enabled_reverse_rate:
:param add_node_rate: :param add_node_rate:
@@ -224,7 +218,7 @@ def mutate(rand_key: Array,
bias_mutate_rate, bias_replace_rate, response_mean, response_std, bias_mutate_rate, bias_replace_rate, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate, response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mean, weight_std, weight_mutate_strength,
weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range, weight_mutate_rate, weight_replace_rate, act_list, act_replace_rate, agg_list,
agg_replace_rate, enabled_reverse_rate) agg_replace_rate, enabled_reverse_rate)
return nodes, connections return nodes, connections
@@ -249,9 +243,9 @@ def mutate_values(rand_key: Array,
weight_mutate_strength: float = 0.5, weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7, weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1, weight_replace_rate: float = 0.1,
act_range: int = 5, act_list: Array = None,
act_replace_rate: float = 0.1, act_replace_rate: float = 0.1,
agg_range: int = 5, agg_list: Array = None,
agg_replace_rate: float = 0.1, agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]: enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]:
""" """
@@ -276,9 +270,9 @@ def mutate_values(rand_key: Array,
weight_mutate_strength: Strength of the weight mutation. weight_mutate_strength: Strength of the weight mutation.
weight_mutate_rate: Rate of the weight mutation. weight_mutate_rate: Rate of the weight mutation.
weight_replace_rate: Rate of the weight replacement. weight_replace_rate: Rate of the weight replacement.
act_range: Range of the activation function values. act_list: List of the activation function values.
act_replace_rate: Rate of the activation function replacement. act_replace_rate: Rate of the activation function replacement.
agg_range: Range of the aggregation function values. agg_list: List of the aggregation function values.
agg_replace_rate: Rate of the aggregation function replacement. agg_replace_rate: Rate of the aggregation function replacement.
enabled_reverse_rate: Rate of reversing enabled state of connections. enabled_reverse_rate: Rate of reversing enabled state of connections.
@@ -293,8 +287,8 @@ def mutate_values(rand_key: Array,
response_mutate_strength, response_mutate_rate, response_replace_rate) response_mutate_strength, response_mutate_rate, response_replace_rate)
weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std, weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std,
weight_mutate_strength, weight_mutate_rate, weight_replace_rate) weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
act_new = mutate_int_values(k4, nodes[:, 3], act_range, act_replace_rate) act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
agg_new = mutate_int_values(k5, nodes[:, 4], agg_range, agg_replace_rate) agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
# refactor enabled # refactor enabled
r = jax.random.uniform(rand_key, connections[1, :, :].shape) r = jax.random.uniform(rand_key, connections[1, :, :].shape)
@@ -345,21 +339,21 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
@jit @jit
def mutate_int_values(rand_key: Array, old_vals: Array, range: int, replace_rate: float) -> Array: def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
""" """
Mutate integer values (act, agg) of a given array. Mutate integer values (act, agg) of a given array.
Args: Args:
rand_key: A random key for generating random values. rand_key: A random key for generating random values.
old_vals: A 1D array of integer values to be mutated. old_vals: A 1D array of integer values to be mutated.
range: Range of the integer values. val_list: List of the integer values.
replace_rate: Rate of the replacement. replace_rate: Rate of the replacement.
Returns: Returns:
A mutated 1D array of integer values. A mutated 1D array of integer values.
""" """
k1, k2, rand_key = jax.random.split(rand_key, num=3) k1, k2, rand_key = jax.random.split(rand_key, num=3)
replace_val = jax.random.randint(k1, old_vals.shape, 0, range) replace_val = jax.random.choice(k1, val_list, old_vals.shape)
r = jax.random.uniform(k2, old_vals.shape) r = jax.random.uniform(k2, old_vals.shape)
new_vals = old_vals new_vals = old_vals
new_vals = jnp.where(r < replace_rate, replace_val, new_vals) new_vals = jnp.where(r < replace_rate, replace_val, new_vals)

View File

@@ -9,8 +9,8 @@
"population": { "population": {
"fitness_criterion": "max", "fitness_criterion": "max",
"fitness_threshold": 76, "fitness_threshold": 76,
"generation_limit": 100, "generation_limit": 1000,
"pop_size": 100, "pop_size": 200,
"reset_on_extinction": "False" "reset_on_extinction": "False"
}, },
"gene": { "gene": {
@@ -30,16 +30,16 @@
}, },
"activation": { "activation": {
"default": "sigmoid", "default": "sigmoid",
"options": ["sigmoid"], "options": ["sigmoid", "gauss", "relu"],
"mutate_rate": 0.01 "mutate_rate": 0.1
}, },
"aggregation": { "aggregation": {
"default": "sum", "default": "sum",
"options": ["sum"], "options": ["sum", "max", "min", "mean"],
"mutate_rate": 0.01 "mutate_rate": 0.1
}, },
"weight": { "weight": {
"init_mean": 1.0, "init_mean": 0.0,
"init_stdev": 1.0, "init_stdev": 1.0,
"mutate_power": 0.5, "mutate_power": 0.5,
"mutate_rate": 0.8, "mutate_rate": 0.8,