modify act. agg in mutation. they can have option vals

fix a bug in function 'agg'
This commit is contained in:
wls2002
2023-05-07 23:00:04 +08:00
parent 47bb593a53
commit b257505bee
4 changed files with 46 additions and 52 deletions

View File

@@ -88,13 +88,13 @@ agg_name2key = {
def agg(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
def full_zero():
def full_nan():
return 0.
def not_full_zero():
def not_full_nan():
return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
return jax.lax.cond(jnp.all(z == 0.), full_zero, not_full_zero)
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
vectorized_agg = jax.vmap(agg, in_axes=(0, 0))

View File

@@ -50,7 +50,7 @@ def initialize_genomes(pop_size: int,
default_response: float = 1.0,
default_act: int = 0,
default_agg: int = 0,
default_weight: float = 1.0) \
default_weight: float = 0.0) \
-> Tuple[NDArray, NDArray, NDArray, NDArray]:
"""
Initialize genomes with default values.

View File

@@ -2,12 +2,15 @@ from typing import Tuple
from functools import partial
import jax
import numpy as np
from jax import numpy as jnp
from jax import jit, vmap, Array
from .utils import fetch_random, fetch_first, I_INT
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx
from .graph import check_cycles
from .activations import act_name2key
from .aggregations import agg_name2key
def create_mutate_function(config, input_keys, output_keys, batch: bool):
@@ -43,15 +46,13 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
weight_replace_rate = weight.replace_rate
activation = config.neat.gene.activation
# act_default = activation.default
act_default = 0
act_range = len(activation.options)
act_default = act_name2key[activation.default]
act_list = np.array([act_name2key[name] for name in activation.options])
act_replace_rate = activation.mutate_rate
aggregation = config.neat.gene.aggregation
# agg_default = aggregation.default
agg_default = 0
agg_range = len(aggregation.options)
agg_default = agg_name2key[aggregation.default]
agg_list = np.array([agg_name2key[name] for name in aggregation.options])
agg_replace_rate = aggregation.mutate_rate
enabled = config.neat.gene.enabled
@@ -64,29 +65,22 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
delete_connection_rate = genome.conn_delete_prob
single_structure_mutate = genome.single_structural_mutation
def mutate_func(rand_key, nodes, connections, new_node_key):
return mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_list, act_replace_rate,
agg_default, agg_list, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
if not batch:
return lambda rand_key, nodes, connections, new_node_key: \
mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_range, act_replace_rate,
agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
return mutate_func
else:
batched_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, *(None,) * 31))
return lambda rand_keys, pop_nodes, pop_connections, new_node_keys: \
batched_mutate(rand_keys, pop_nodes, pop_connections, new_node_keys, input_keys, output_keys,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_range, act_replace_rate,
agg_default, agg_range, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
batched_mutate_func = vmap(mutate_func, in_axes=(0, 0, 0, 0))
return batched_mutate_func
@partial(jit, static_argnames=["single_structure_mutate"])
@@ -114,10 +108,10 @@ def mutate(rand_key: Array,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_default: int = 0,
act_range: int = 5,
act_list: Array = None,
act_replace_rate: float = 0.1,
agg_default: int = 0,
agg_range: int = 5,
agg_list: Array = None,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1,
add_node_rate: float = 0.2,
@@ -151,9 +145,9 @@ def mutate(rand_key: Array,
:param weight_mutate_strength:
:param weight_mutate_rate:
:param weight_replace_rate:
:param act_range:
:param act_list:
:param act_replace_rate:
:param agg_range:
:param agg_list:
:param agg_replace_rate:
:param enabled_reverse_rate:
:param add_node_rate:
@@ -224,7 +218,7 @@ def mutate(rand_key: Array,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength,
weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range,
weight_mutate_rate, weight_replace_rate, act_list, act_replace_rate, agg_list,
agg_replace_rate, enabled_reverse_rate)
return nodes, connections
@@ -249,9 +243,9 @@ def mutate_values(rand_key: Array,
weight_mutate_strength: float = 0.5,
weight_mutate_rate: float = 0.7,
weight_replace_rate: float = 0.1,
act_range: int = 5,
act_list: Array = None,
act_replace_rate: float = 0.1,
agg_range: int = 5,
agg_list: Array = None,
agg_replace_rate: float = 0.1,
enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]:
"""
@@ -276,9 +270,9 @@ def mutate_values(rand_key: Array,
weight_mutate_strength: Strength of the weight mutation.
weight_mutate_rate: Rate of the weight mutation.
weight_replace_rate: Rate of the weight replacement.
act_range: Range of the activation function values.
act_list: List of the activation function values.
act_replace_rate: Rate of the activation function replacement.
agg_range: Range of the aggregation function values.
agg_list: List of the aggregation function values.
agg_replace_rate: Rate of the aggregation function replacement.
enabled_reverse_rate: Rate of reversing enabled state of connections.
@@ -293,8 +287,8 @@ def mutate_values(rand_key: Array,
response_mutate_strength, response_mutate_rate, response_replace_rate)
weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std,
weight_mutate_strength, weight_mutate_rate, weight_replace_rate)
act_new = mutate_int_values(k4, nodes[:, 3], act_range, act_replace_rate)
agg_new = mutate_int_values(k5, nodes[:, 4], agg_range, agg_replace_rate)
act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate)
agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate)
# refactor enabled
r = jax.random.uniform(rand_key, connections[1, :, :].shape)
@@ -345,21 +339,21 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
@jit
def mutate_int_values(rand_key: Array, old_vals: Array, range: int, replace_rate: float) -> Array:
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
"""
Mutate integer values (act, agg) of a given array.
Args:
rand_key: A random key for generating random values.
old_vals: A 1D array of integer values to be mutated.
range: Range of the integer values.
val_list: List of the integer values.
replace_rate: Rate of the replacement.
Returns:
A mutated 1D array of integer values.
"""
k1, k2, rand_key = jax.random.split(rand_key, num=3)
replace_val = jax.random.randint(k1, old_vals.shape, 0, range)
replace_val = jax.random.choice(k1, val_list, old_vals.shape)
r = jax.random.uniform(k2, old_vals.shape)
new_vals = old_vals
new_vals = jnp.where(r < replace_rate, replace_val, new_vals)

View File

@@ -9,8 +9,8 @@
"population": {
"fitness_criterion": "max",
"fitness_threshold": 76,
"generation_limit": 100,
"pop_size": 100,
"generation_limit": 1000,
"pop_size": 200,
"reset_on_extinction": "False"
},
"gene": {
@@ -30,16 +30,16 @@
},
"activation": {
"default": "sigmoid",
"options": ["sigmoid"],
"mutate_rate": 0.01
"options": ["sigmoid", "gauss", "relu"],
"mutate_rate": 0.1
},
"aggregation": {
"default": "sum",
"options": ["sum"],
"mutate_rate": 0.01
"options": ["sum", "max", "min", "mean"],
"mutate_rate": 0.1
},
"weight": {
"init_mean": 1.0,
"init_mean": 0.0,
"init_stdev": 1.0,
"mutate_power": 0.5,
"mutate_rate": 0.8,