clean imports and delete "create_XXX_functions"

This commit is contained in:
wls2002
2023-05-09 01:58:00 +08:00
parent f63a0c447b
commit 1f2327bbd6
7 changed files with 20 additions and 286 deletions

View File

@@ -13,100 +13,7 @@ from .activations import act_name2key
from .aggregations import agg_name2key
def create_mutate_function(N, config, batch: bool, debug: bool = False):
"""
create mutate function for different situations
:param N:
:param config:
:param batch: mutate for population or not
:param debug:
:return:
"""
num_inputs = config.basic.num_inputs
num_outputs = config.basic.num_outputs
input_idx = np.arange(num_inputs)
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
bias = config.neat.gene.bias
bias_default = bias.init_mean
bias_mean = bias.init_mean
bias_std = bias.init_stdev
bias_mutate_strength = bias.mutate_power
bias_mutate_rate = bias.mutate_rate
bias_replace_rate = bias.replace_rate
response = config.neat.gene.response
response_default = response.init_mean
response_mean = response.init_mean
response_std = response.init_stdev
response_mutate_strength = response.mutate_power
response_mutate_rate = response.mutate_rate
response_replace_rate = response.replace_rate
weight = config.neat.gene.weight
weight_mean = weight.init_mean
weight_std = weight.init_stdev
weight_mutate_strength = weight.mutate_power
weight_mutate_rate = weight.mutate_rate
weight_replace_rate = weight.replace_rate
activation = config.neat.gene.activation
act_default = act_name2key[activation.default]
act_list = np.array([act_name2key[name] for name in activation.options])
act_replace_rate = activation.mutate_rate
aggregation = config.neat.gene.aggregation
agg_default = agg_name2key[aggregation.default]
agg_list = np.array([agg_name2key[name] for name in aggregation.options])
agg_replace_rate = aggregation.mutate_rate
enabled = config.neat.gene.enabled
enabled_reverse_rate = enabled.mutate_rate
genome = config.neat.genome
add_node_rate = genome.node_add_prob
delete_node_rate = genome.node_delete_prob
add_connection_rate = genome.conn_add_prob
delete_connection_rate = genome.conn_delete_prob
single_structure_mutate = genome.single_structural_mutation
def mutate_with_args(rand_key, nodes, connections, new_node_key):
return mutate(rand_key, nodes, connections, new_node_key, input_idx, output_idx,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
weight_mean, weight_std, weight_mutate_strength, weight_mutate_rate,
weight_replace_rate, act_default, act_list, act_replace_rate,
agg_default, agg_list, agg_replace_rate, enabled_reverse_rate,
add_node_rate, delete_node_rate, add_connection_rate, delete_connection_rate,
single_structure_mutate)
if not batch:
rand_key_lower = jnp.zeros((2,), dtype=jnp.uint32)
nodes_lower = jnp.zeros((N, 5))
connections_lower = jnp.zeros((2, N, N))
new_node_key_lower = jnp.zeros((), dtype=jnp.int32)
res_func = jit(mutate_with_args).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
if debug:
return lambda *args: res_func(*args)
else:
return res_func
else:
pop_size = config.neat.population.pop_size
rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
nodes_lower = jnp.zeros((pop_size, N, 5))
connections_lower = jnp.zeros((pop_size, 2, N, N))
new_node_key_lower = jnp.zeros((pop_size,), dtype=jnp.int32)
batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
if debug:
return lambda *args: batched_mutate_func(*args)
else:
return batched_mutate_func
@partial(jit, static_argnames=('single_structure_mutate',))
def mutate(rand_key: Array,
nodes: Array,
connections: Array,
@@ -243,6 +150,7 @@ def mutate(rand_key: Array,
return nodes, connections
@jit
def mutate_values(rand_key: Array,
nodes: Array,
connections: Array,
@@ -323,6 +231,7 @@ def mutate_values(rand_key: Array,
return nodes, connections
@jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
"""
@@ -355,6 +264,7 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
return new_vals
@jit
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
"""
Mutate integer values (act, agg) of a given array.
@@ -377,6 +287,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
return new_vals
@jit
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
@@ -423,6 +334,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
return nodes, connections
@jit
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
@@ -456,6 +368,7 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
return nodes, connections
@jit
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
@@ -494,6 +407,7 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
return nodes, connections
@jit
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
"""
Randomly delete a connection.
@@ -516,6 +430,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
return nodes, connections
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
@@ -544,6 +459,7 @@ def choice_node_key(rand_key: Array, nodes: Array,
return key, idx
@jit
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
"""
Randomly choose a connection key from the given connections.
@@ -571,5 +487,6 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
return from_key, to_key, from_idx, to_idx
@jit
def rand(rand_key):
return jax.random.uniform(rand_key, ())