diff --git a/algorithms/neat/genome/aggregations.py b/algorithms/neat/genome/aggregations.py index 6cf172e..85c5f02 100644 --- a/algorithms/neat/genome/aggregations.py +++ b/algorithms/neat/genome/aggregations.py @@ -88,13 +88,13 @@ agg_name2key = { def agg(idx, z): idx = jnp.asarray(idx, dtype=jnp.int32) - def full_zero(): + def full_nan(): return 0. - def not_full_zero(): + def not_full_nan(): return jax.lax.switch(idx, AGG_TOTAL_LIST, z) - return jax.lax.cond(jnp.all(z == 0.), full_zero, not_full_zero) + return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan) vectorized_agg = jax.vmap(agg, in_axes=(0, 0)) diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py index 8479b07..c6f83ab 100644 --- a/algorithms/neat/genome/genome.py +++ b/algorithms/neat/genome/genome.py @@ -50,7 +50,7 @@ def initialize_genomes(pop_size: int, default_response: float = 1.0, default_act: int = 0, default_agg: int = 0, - default_weight: float = 1.0) \ + default_weight: float = 0.0) \ -> Tuple[NDArray, NDArray, NDArray, NDArray]: """ Initialize genomes with default values. diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index d92df28..9e180ea 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -2,12 +2,15 @@ from typing import Tuple from functools import partial import jax +import numpy as np from jax import numpy as jnp from jax import jit, vmap, Array from .utils import fetch_random, fetch_first, I_INT from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx from .graph import check_cycles +from .activations import act_name2key +from .aggregations import agg_name2key def create_mutate_function(config, input_keys, output_keys, batch: bool): @@ -43,15 +46,13 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool): weight_replace_rate = weight.replace_rate activation = config.neat.gene.activation - # act_default = activation.default - act_default = 0 - act_range = len(activation.options) + act_default = act_name2key[activation.default] + act_list = np.array([act_name2key[name] for name in activation.options]) act_replace_rate = activation.mutate_rate aggregation = config.neat.gene.aggregation - # agg_default = aggregation.default - agg_default = 0 - agg_range = len(aggregation.options) + agg_default = agg_name2key[aggregation.default] + agg_list = np.array([agg_name2key[name] for name in aggregation.options]) agg_replace_rate = aggregation.mutate_rate enabled = config.neat.gene.enabled @@ -64,29 +65,22 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool): delete_connection_rate = genome.conn_delete_prob single_structure_mutate = genome.single_structural_mutation + def mutate_func(rand_key, nodes, connections, new_node_key): + return mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys, + bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, + bias_replace_rate, response_default, response_mean, response_std, + response_mutate_strength, response_mutate_rate, response_replace_rate, + weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate, + weight_replace_rate, act_default, act_list, act_replace_rate, + agg_default, agg_list, agg_replace_rate, enabled_reverse_rate, + add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate, + single_structure_mutate) + if not batch: - return lambda rand_key, nodes, connections, new_node_key: \ - mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys, - bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, - bias_replace_rate, response_default, response_mean, response_std, - response_mutate_strength, response_mutate_rate, response_replace_rate, - weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate, - weight_replace_rate, act_default, act_range, act_replace_rate, - agg_default, agg_range, agg_replace_rate, enabled_reverse_rate, - add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate, - single_structure_mutate) + return mutate_func else: - batched_mutate = vmap(mutate, in_axes=(0, 0, 0, 0, *(None,) * 31)) - return lambda rand_keys, pop_nodes, pop_connections, new_node_keys: \ - batched_mutate(rand_keys, pop_nodes, pop_connections, new_node_keys, input_keys, output_keys, - bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, - bias_replace_rate, response_default, response_mean, response_std, - response_mutate_strength, response_mutate_rate, response_replace_rate, - weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate, - weight_replace_rate, act_default, act_range, act_replace_rate, - agg_default, agg_range, agg_replace_rate, enabled_reverse_rate, - add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate, - single_structure_mutate) + batched_mutate_func = vmap(mutate_func, in_axes=(0, 0, 0, 0)) + return batched_mutate_func @partial(jit, static_argnames=["single_structure_mutate"]) @@ -114,10 +108,10 @@ def mutate(rand_key: Array, weight_mutate_rate: float = 0.7, weight_replace_rate: float = 0.1, act_default: int = 0, - act_range: int = 5, + act_list: Array = None, act_replace_rate: float = 0.1, agg_default: int = 0, - agg_range: int = 5, + agg_list: Array = None, agg_replace_rate: float = 0.1, enabled_reverse_rate: float = 0.1, add_node_rate: float = 0.2, @@ -151,9 +145,9 @@ def mutate(rand_key: Array, :param weight_mutate_strength: :param weight_mutate_rate: :param weight_replace_rate: - :param act_range: + :param act_list: :param act_replace_rate: - :param agg_range: + :param agg_list: :param agg_replace_rate: :param enabled_reverse_rate: :param add_node_rate: @@ -224,7 +218,7 @@ def mutate(rand_key: Array, bias_mutate_rate, bias_replace_rate, response_mean, response_std, response_mutate_strength, response_mutate_rate, response_replace_rate, weight_mean, weight_std, weight_mutate_strength, - weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range, + weight_mutate_rate, weight_replace_rate, act_list, act_replace_rate, agg_list, agg_replace_rate, enabled_reverse_rate) return nodes, connections @@ -249,9 +243,9 @@ def mutate_values(rand_key: Array, weight_mutate_strength: float = 0.5, weight_mutate_rate: float = 0.7, weight_replace_rate: float = 0.1, - act_range: int = 5, + act_list: Array = None, act_replace_rate: float = 0.1, - agg_range: int = 5, + agg_list: Array = None, agg_replace_rate: float = 0.1, enabled_reverse_rate: float = 0.1) -> Tuple[Array, Array]: """ @@ -276,9 +270,9 @@ def mutate_values(rand_key: Array, weight_mutate_strength: Strength of the weight mutation. weight_mutate_rate: Rate of the weight mutation. weight_replace_rate: Rate of the weight replacement. - act_range: Range of the activation function values. + act_list: List of the activation function values. act_replace_rate: Rate of the activation function replacement. - agg_range: Range of the aggregation function values. + agg_list: List of the aggregation function values. agg_replace_rate: Rate of the aggregation function replacement. enabled_reverse_rate: Rate of reversing enabled state of connections. @@ -293,8 +287,8 @@ def mutate_values(rand_key: Array, response_mutate_strength, response_mutate_rate, response_replace_rate) weight_new = mutate_float_values(k3, connections[0, :, :], weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate, weight_replace_rate) - act_new = mutate_int_values(k4, nodes[:, 3], act_range, act_replace_rate) - agg_new = mutate_int_values(k5, nodes[:, 4], agg_range, agg_replace_rate) + act_new = mutate_int_values(k4, nodes[:, 3], act_list, act_replace_rate) + agg_new = mutate_int_values(k5, nodes[:, 4], agg_list, agg_replace_rate) # refactor enabled r = jax.random.uniform(rand_key, connections[1, :, :].shape) @@ -345,21 +339,21 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa @jit -def mutate_int_values(rand_key: Array, old_vals: Array, range: int, replace_rate: float) -> Array: +def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array: """ Mutate integer values (act, agg) of a given array. Args: rand_key: A random key for generating random values. old_vals: A 1D array of integer values to be mutated. - range: Range of the integer values. + val_list: List of the integer values. replace_rate: Rate of the replacement. Returns: A mutated 1D array of integer values. """ k1, k2, rand_key = jax.random.split(rand_key, num=3) - replace_val = jax.random.randint(k1, old_vals.shape, 0, range) + replace_val = jax.random.choice(k1, val_list, old_vals.shape) r = jax.random.uniform(k2, old_vals.shape) new_vals = old_vals new_vals = jnp.where(r < replace_rate, replace_val, new_vals) diff --git a/utils/default_config.json b/utils/default_config.json index 7b14361..661c030 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -9,8 +9,8 @@ "population": { "fitness_criterion": "max", "fitness_threshold": 76, - "generation_limit": 100, - "pop_size": 100, + "generation_limit": 1000, + "pop_size": 200, "reset_on_extinction": "False" }, "gene": { @@ -30,16 +30,16 @@ }, "activation": { "default": "sigmoid", - "options": ["sigmoid"], - "mutate_rate": 0.01 + "options": ["sigmoid", "gauss", "relu"], + "mutate_rate": 0.1 }, "aggregation": { "default": "sum", - "options": ["sum"], - "mutate_rate": 0.01 + "options": ["sum", "max", "min", "mean"], + "mutate_rate": 0.1 }, "weight": { - "init_mean": 1.0, + "init_mean": 0.0, "init_stdev": 1.0, "mutate_power": 0.5, "mutate_rate": 0.8,