add debug mode for create_xx_functions for detail time cost analysis

This commit is contained in:
wls2002
2023-05-08 15:42:25 +08:00
parent d4a75b9394
commit e201d03157
8 changed files with 70 additions and 38 deletions

View File

@@ -8,7 +8,7 @@ from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections
def create_crossover_function(N, config, batch: bool):
def create_crossover_function(N, config, batch: bool, debug: bool = False):
if batch:
pop_size = config.neat.population.pop_size
randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
@@ -16,16 +16,27 @@ def create_crossover_function(N, config, batch: bool):
connections1_lower = jnp.zeros((pop_size, 2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
return jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
res_func = jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args)
else:
return res_func
else:
randkey_lower = jnp.zeros((2,), dtype=jnp.uint32)
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((N, 5))
connections2_lower = jnp.zeros((2, N, N))
return jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
res_func = jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args)
else:
return res_func
# @jit

View File

@@ -6,11 +6,12 @@ from numpy.typing import NDArray
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
def create_distance_function(N, config, type: str):
def create_distance_function(N, config, type: str, debug: bool = False):
"""
:param N:
:param config:
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
:param debug:
:return:
"""
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
@@ -20,8 +21,20 @@ def create_distance_function(N, config, type: str):
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
if type == 'o2o':
return lambda nodes1, connections1, nodes2, connections2: \
distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((N, 5))
connections2_lower = jnp.zeros((2, N, N))
res_func = jit(distance_with_args).lower(nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args) # for debug
else:
return res_func
# return lambda nodes1, connections1, nodes2, connections2: \
# distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
elif type == 'o2m':
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
@@ -30,7 +43,12 @@ def create_distance_function(N, config, type: str):
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
res_func = jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
if debug:
return lambda *args: res_func(*args) # for debug
else:
return res_func
else:
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
@@ -48,6 +66,7 @@ def distance_numpy(nodes1: NDArray, connection1: NDArray, nodes2: NDArray,
:param compatibility_coe:
:return:
"""
def analysis(nodes, connections):
nodes_dict = {}
idx2key = {}

View File

@@ -86,6 +86,7 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
return vals[output_idx]
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
@@ -106,6 +107,7 @@ def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Arr
return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
@@ -126,6 +128,7 @@ def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Arra
return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:

View File

@@ -74,6 +74,7 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
return res
@jit
@vmap
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
"""

View File

@@ -13,12 +13,13 @@ from .activations import act_name2key
from .aggregations import agg_name2key
def create_mutate_function(N, config, batch: bool):
def create_mutate_function(N, config, batch: bool, debug: bool = False):
"""
create mutate function for different situations
:param N:
:param config:
:param batch: mutate for population or not
:param debug:
:return:
"""
num_inputs = config.basic.num_inputs
@@ -81,24 +82,31 @@ def create_mutate_function(N, config, batch: bool):
single_structure_mutate)
if not batch:
rand_key_lower = jnp.zeros((2, ), dtype=jnp.uint32)
rand_key_lower = jnp.zeros((2,), dtype=jnp.uint32)
nodes_lower = jnp.zeros((N, 5))
connections_lower = jnp.zeros((2, N, N))
new_node_key_lower = jnp.zeros((), dtype=jnp.int32)
return jit(mutate_with_args).lower(rand_key_lower, nodes_lower, connections_lower, new_node_key_lower).compile()
res_func = jit(mutate_with_args).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
if debug:
return lambda *args: res_func(*args)
else:
return res_func
else:
pop_size = config.neat.population.pop_size
rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
nodes_lower = jnp.zeros((pop_size, N, 5))
connections_lower = jnp.zeros((pop_size, 2, N, N))
new_node_key_lower = jnp.zeros((pop_size, ), dtype=jnp.int32)
new_node_key_lower = jnp.zeros((pop_size,), dtype=jnp.int32)
batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
return batched_mutate_func
if debug:
return lambda *args: batched_mutate_func(*args)
else:
return batched_mutate_func
# @partial(jit, static_argnames=["single_structure_mutate"])
def mutate(rand_key: Array,
nodes: Array,
connections: Array,
@@ -239,7 +247,6 @@ def mutate(rand_key: Array,
return nodes, connections
# @jit
def mutate_values(rand_key: Array,
nodes: Array,
connections: Array,
@@ -320,7 +327,6 @@ def mutate_values(rand_key: Array,
return nodes, connections
# @jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
"""
@@ -353,7 +359,6 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
return new_vals
# @jit
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
"""
Mutate integer values (act, agg) of a given array.
@@ -376,7 +381,6 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
return new_vals
# @jit
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
@@ -423,7 +427,6 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
return nodes, connections
# @jit
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
@@ -457,7 +460,6 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
return nodes, connections
# @jit
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
@@ -496,7 +498,6 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
return nodes, connections
# @jit
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
"""
Randomly delete a connection.
@@ -519,7 +520,6 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
return nodes, connections
# @partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
@@ -548,7 +548,6 @@ def choice_node_key(rand_key: Array, nodes: Array,
return key, idx
# @jit
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
"""
Randomly choose a connection key from the given connections.
@@ -576,6 +575,5 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
return from_key, to_key, from_idx, to_idx
# @jit
def rand(rand_key):
return jax.random.uniform(rand_key, ())

View File

@@ -29,7 +29,7 @@ class Pipeline:
self.initialize_func = create_initialize_function(config)
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.compile_functions()
self.compile_functions(debug=True)
self.generation = 0
self.species_controller.speciate(self.pop_nodes, self.pop_connections,
@@ -141,13 +141,13 @@ class Pipeline:
s.representative = expand_single(*s.representative, self.N)
# update functions
self.compile_functions()
self.compile_functions(debug=True)
def compile_functions(self):
self.mutate_func = create_mutate_function(self.N, self.config, batch=True)
self.crossover_func = create_crossover_function(self.N, self.config, batch=True)
self.o2o_distance = create_distance_function(self.N, self.config, type='o2o')
self.o2m_distance = create_distance_function(self.N, self.config, type='o2m')
def compile_functions(self, debug=False):
self.mutate_func = create_mutate_function(self.N, self.config, batch=True, debug=debug)
self.crossover_func = create_crossover_function(self.N, self.config, batch=True, debug=debug)
self.o2o_distance = create_distance_function(self.N, self.config, type='o2o', debug=debug)
self.o2m_distance = create_distance_function(self.N, self.config, type='o2m', debug=debug)
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)

View File

@@ -105,7 +105,7 @@ class SpeciesController:
# the representatives of new species
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
jax.device_get(o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]))
for r in rid
]
distances = np.array(distances)

View File

@@ -2,7 +2,7 @@
"basic": {
"num_inputs": 2,
"num_outputs": 1,
"init_maximum_nodes": 10,
"init_maximum_nodes": 20,
"expands_coe": 2
},
"neat": {
@@ -30,12 +30,12 @@
},
"activation": {
"default": "sigmoid",
"options": ["sigmoid", "gauss", "relu"],
"options": ["sigmoid"],
"mutate_rate": 0.1
},
"aggregation": {
"default": "sum",
"options": ["sum", "max", "min", "mean"],
"options": ["sum"],
"mutate_rate": 0.1
},
"weight": {
@@ -59,7 +59,7 @@
"node_delete_prob": 0.2
},
"species": {
"compatibility_threshold": 3,
"compatibility_threshold": 2.5,
"species_fitness_func": "max",
"max_stagnation": 20,
"species_elitism": 2,