use jit().lower.compile in create functions
This commit is contained in:
@@ -13,15 +13,19 @@ from .activations import act_name2key
|
||||
from .aggregations import agg_name2key
|
||||
|
||||
|
||||
def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
||||
def create_mutate_function(N, config, batch: bool):
|
||||
"""
|
||||
create mutate function for different situations
|
||||
:param output_keys:
|
||||
:param input_keys:
|
||||
:param N:
|
||||
:param config:
|
||||
:param batch: mutate for population or not
|
||||
:return:
|
||||
"""
|
||||
num_inputs = config.basic.num_inputs
|
||||
num_outputs = config.basic.num_outputs
|
||||
input_idx = np.arange(num_inputs)
|
||||
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
||||
|
||||
bias = config.neat.gene.bias
|
||||
bias_default = bias.init_mean
|
||||
bias_mean = bias.init_mean
|
||||
@@ -65,8 +69,8 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
||||
delete_connection_rate = genome.conn_delete_prob
|
||||
single_structure_mutate = genome.single_structural_mutation
|
||||
|
||||
def mutate_func(rand_key, nodes, connections, new_node_key):
|
||||
return mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
|
||||
def mutate_with_args(rand_key, nodes, connections, new_node_key):
|
||||
return mutate(rand_key, nodes, connections, new_node_key, input_idx, output_idx,
|
||||
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
|
||||
bias_replace_rate, response_default, response_mean, response_std,
|
||||
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
||||
@@ -77,19 +81,30 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
||||
single_structure_mutate)
|
||||
|
||||
if not batch:
|
||||
return mutate_func
|
||||
rand_key_lower = jnp.zeros((2, ), dtype=jnp.uint32)
|
||||
nodes_lower = jnp.zeros((N, 5))
|
||||
connections_lower = jnp.zeros((2, N, N))
|
||||
new_node_key_lower = jnp.zeros((), dtype=jnp.int32)
|
||||
return jit(mutate_with_args).lower(rand_key_lower, nodes_lower, connections_lower, new_node_key_lower).compile()
|
||||
else:
|
||||
batched_mutate_func = vmap(mutate_func, in_axes=(0, 0, 0, 0))
|
||||
pop_size = config.neat.population.pop_size
|
||||
rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
|
||||
nodes_lower = jnp.zeros((pop_size, N, 5))
|
||||
connections_lower = jnp.zeros((pop_size, 2, N, N))
|
||||
new_node_key_lower = jnp.zeros((pop_size, ), dtype=jnp.int32)
|
||||
batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower,
|
||||
connections_lower, new_node_key_lower).compile()
|
||||
|
||||
return batched_mutate_func
|
||||
|
||||
|
||||
@partial(jit, static_argnames=["single_structure_mutate"])
|
||||
# @partial(jit, static_argnames=["single_structure_mutate"])
|
||||
def mutate(rand_key: Array,
|
||||
nodes: Array,
|
||||
connections: Array,
|
||||
new_node_key: int,
|
||||
input_keys: Array,
|
||||
output_keys: Array,
|
||||
input_idx: Array,
|
||||
output_idx: Array,
|
||||
bias_default: float = 0,
|
||||
bias_mean: float = 0,
|
||||
bias_std: float = 1,
|
||||
@@ -120,8 +135,8 @@ def mutate(rand_key: Array,
|
||||
delete_connection_rate: float = 0.4,
|
||||
single_structure_mutate: bool = True):
|
||||
"""
|
||||
:param output_keys:
|
||||
:param input_keys:
|
||||
:param output_idx:
|
||||
:param input_idx:
|
||||
:param agg_default:
|
||||
:param act_default:
|
||||
:param response_default:
|
||||
@@ -166,10 +181,10 @@ def mutate(rand_key: Array,
|
||||
return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default)
|
||||
|
||||
def m_delete_node(rk, n, c):
|
||||
return mutate_delete_node(rk, n, c, input_keys, output_keys)
|
||||
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_add_connection(rk, n, c):
|
||||
return mutate_add_connection(rk, n, c, input_keys, output_keys)
|
||||
return mutate_add_connection(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_delete_connection(rk, n, c):
|
||||
return mutate_delete_connection(rk, n, c)
|
||||
@@ -224,7 +239,7 @@ def mutate(rand_key: Array,
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_values(rand_key: Array,
|
||||
nodes: Array,
|
||||
connections: Array,
|
||||
@@ -305,7 +320,7 @@ def mutate_values(rand_key: Array,
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
|
||||
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
|
||||
"""
|
||||
@@ -338,7 +353,7 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
|
||||
return new_vals
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
|
||||
"""
|
||||
Mutate integer values (act, agg) of a given array.
|
||||
@@ -361,7 +376,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
|
||||
return new_vals
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
|
||||
default_bias: float = 0, default_response: float = 1,
|
||||
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
|
||||
@@ -408,7 +423,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
@@ -442,7 +457,7 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
@@ -481,7 +496,7 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
||||
"""
|
||||
Randomly delete a connection.
|
||||
@@ -504,7 +519,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
||||
# @partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
||||
def choice_node_key(rand_key: Array, nodes: Array,
|
||||
input_keys: Array, output_keys: Array,
|
||||
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
|
||||
@@ -533,7 +548,7 @@ def choice_node_key(rand_key: Array, nodes: Array,
|
||||
return key, idx
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
|
||||
"""
|
||||
Randomly choose a connection key from the given connections.
|
||||
@@ -561,6 +576,6 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
|
||||
return from_key, to_key, from_idx, to_idx
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def rand(rand_key):
|
||||
return jax.random.uniform(rand_key, ())
|
||||
|
||||
Reference in New Issue
Block a user