diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py index 5573714..d5eaf2c 100644 --- a/algorithms/neat/genome/crossover.py +++ b/algorithms/neat/genome/crossover.py @@ -8,7 +8,7 @@ from jax import numpy as jnp from .utils import flatten_connections, unflatten_connections -def create_crossover_function(N, config, batch: bool): +def create_crossover_function(N, config, batch: bool, debug: bool = False): if batch: pop_size = config.neat.population.pop_size randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32) @@ -16,16 +16,27 @@ def create_crossover_function(N, config, batch: bool): connections1_lower = jnp.zeros((pop_size, 2, N, N)) nodes2_lower = jnp.zeros((pop_size, N, 5)) connections2_lower = jnp.zeros((pop_size, 2, N, N)) - return jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower, - nodes2_lower, connections2_lower).compile() + + res_func = jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower, + nodes2_lower, connections2_lower).compile() + if debug: + return lambda *args: res_func(*args) + else: + return res_func + else: randkey_lower = jnp.zeros((2,), dtype=jnp.uint32) nodes1_lower = jnp.zeros((N, 5)) connections1_lower = jnp.zeros((2, N, N)) nodes2_lower = jnp.zeros((N, 5)) connections2_lower = jnp.zeros((2, N, N)) - return jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower, - nodes2_lower, connections2_lower).compile() + + res_func = jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower, + nodes2_lower, connections2_lower).compile() + if debug: + return lambda *args: res_func(*args) + else: + return res_func # @jit diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py index f2ed988..58c24df 100644 --- a/algorithms/neat/genome/distance.py +++ b/algorithms/neat/genome/distance.py @@ -6,11 +6,12 @@ from numpy.typing import NDArray from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON -def create_distance_function(N, config, type: str): +def create_distance_function(N, config, type: str, debug: bool = False): """ :param N: :param config: :param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation + :param debug: :return: """ disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient @@ -20,8 +21,20 @@ def create_distance_function(N, config, type: str): return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) if type == 'o2o': - return lambda nodes1, connections1, nodes2, connections2: \ - distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) + nodes1_lower = jnp.zeros((N, 5)) + connections1_lower = jnp.zeros((2, N, N)) + nodes2_lower = jnp.zeros((N, 5)) + connections2_lower = jnp.zeros((2, N, N)) + + res_func = jit(distance_with_args).lower(nodes1_lower, connections1_lower, + nodes2_lower, connections2_lower).compile() + if debug: + return lambda *args: res_func(*args) # for debug + else: + return res_func + + # return lambda nodes1, connections1, nodes2, connections2: \ + # distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) elif type == 'o2m': vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0)) @@ -30,7 +43,12 @@ def create_distance_function(N, config, type: str): connections1_lower = jnp.zeros((2, N, N)) nodes2_lower = jnp.zeros((pop_size, N, 5)) connections2_lower = jnp.zeros((pop_size, 2, N, N)) - return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile() + res_func = jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile() + if debug: + return lambda *args: res_func(*args) # for debug + else: + return res_func + else: raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]') @@ -48,6 +66,7 @@ def distance_numpy(nodes1: NDArray, connection1: NDArray, nodes2: NDArray, :param compatibility_coe: :return: """ + def analysis(nodes, connections): nodes_dict = {} idx2key = {} diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py index 013d556..dc85fd4 100644 --- a/algorithms/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -86,6 +86,7 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, return vals[output_idx] +@partial(jit, static_argnames=['N']) @partial(vmap, in_axes=(0, None, None, None, None, None, None)) def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, cal_seqs: Array, nodes: Array, connections: Array) -> Array: @@ -106,6 +107,7 @@ def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Arr return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections) +@partial(jit, static_argnames=['N']) @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: @@ -126,6 +128,7 @@ def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Arra return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) +@partial(jit, static_argnames=['N']) @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: diff --git a/algorithms/neat/genome/graph.py b/algorithms/neat/genome/graph.py index 85a4410..bfe38d8 100644 --- a/algorithms/neat/genome/graph.py +++ b/algorithms/neat/genome/graph.py @@ -74,6 +74,7 @@ def topological_sort(nodes: Array, connections: Array) -> Array: return res +@jit @vmap def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array: """ diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index 709ee32..79d462a 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -13,12 +13,13 @@ from .activations import act_name2key from .aggregations import agg_name2key -def create_mutate_function(N, config, batch: bool): +def create_mutate_function(N, config, batch: bool, debug: bool = False): """ create mutate function for different situations :param N: :param config: :param batch: mutate for population or not + :param debug: :return: """ num_inputs = config.basic.num_inputs @@ -81,24 +82,31 @@ def create_mutate_function(N, config, batch: bool): single_structure_mutate) if not batch: - rand_key_lower = jnp.zeros((2, ), dtype=jnp.uint32) + rand_key_lower = jnp.zeros((2,), dtype=jnp.uint32) nodes_lower = jnp.zeros((N, 5)) connections_lower = jnp.zeros((2, N, N)) new_node_key_lower = jnp.zeros((), dtype=jnp.int32) - return jit(mutate_with_args).lower(rand_key_lower, nodes_lower, connections_lower, new_node_key_lower).compile() + res_func = jit(mutate_with_args).lower(rand_key_lower, nodes_lower, + connections_lower, new_node_key_lower).compile() + if debug: + return lambda *args: res_func(*args) + else: + return res_func else: pop_size = config.neat.population.pop_size rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32) nodes_lower = jnp.zeros((pop_size, N, 5)) connections_lower = jnp.zeros((pop_size, 2, N, N)) - new_node_key_lower = jnp.zeros((pop_size, ), dtype=jnp.int32) + new_node_key_lower = jnp.zeros((pop_size,), dtype=jnp.int32) + batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower, connections_lower, new_node_key_lower).compile() - - return batched_mutate_func + if debug: + return lambda *args: batched_mutate_func(*args) + else: + return batched_mutate_func -# @partial(jit, static_argnames=["single_structure_mutate"]) def mutate(rand_key: Array, nodes: Array, connections: Array, @@ -239,7 +247,6 @@ def mutate(rand_key: Array, return nodes, connections -# @jit def mutate_values(rand_key: Array, nodes: Array, connections: Array, @@ -320,7 +327,6 @@ def mutate_values(rand_key: Array, return nodes, connections -# @jit def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float, mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array: """ @@ -353,7 +359,6 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa return new_vals -# @jit def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array: """ Mutate integer values (act, agg) of a given array. @@ -376,7 +381,6 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace return new_vals -# @jit def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array, default_bias: float = 0, default_response: float = 1, default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]: @@ -423,7 +427,6 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection return nodes, connections -# @jit def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: """ @@ -457,7 +460,6 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, return nodes, connections -# @jit def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: """ @@ -496,7 +498,6 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, return nodes, connections -# @jit def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): """ Randomly delete a connection. @@ -519,7 +520,6 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): return nodes, connections -# @partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys')) def choice_node_key(rand_key: Array, nodes: Array, input_keys: Array, output_keys: Array, allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: @@ -548,7 +548,6 @@ def choice_node_key(rand_key: Array, nodes: Array, return key, idx -# @jit def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]: """ Randomly choose a connection key from the given connections. @@ -576,6 +575,5 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T return from_key, to_key, from_idx, to_idx -# @jit def rand(rand_key): return jax.random.uniform(rand_key, ()) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 6a3b516..d8074f8 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -29,7 +29,7 @@ class Pipeline: self.initialize_func = create_initialize_function(config) self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() - self.compile_functions() + self.compile_functions(debug=True) self.generation = 0 self.species_controller.speciate(self.pop_nodes, self.pop_connections, @@ -141,13 +141,13 @@ class Pipeline: s.representative = expand_single(*s.representative, self.N) # update functions - self.compile_functions() + self.compile_functions(debug=True) - def compile_functions(self): - self.mutate_func = create_mutate_function(self.N, self.config, batch=True) - self.crossover_func = create_crossover_function(self.N, self.config, batch=True) - self.o2o_distance = create_distance_function(self.N, self.config, type='o2o') - self.o2m_distance = create_distance_function(self.N, self.config, type='o2m') + def compile_functions(self, debug=False): + self.mutate_func = create_mutate_function(self.N, self.config, batch=True, debug=debug) + self.crossover_func = create_crossover_function(self.N, self.config, batch=True, debug=debug) + self.o2o_distance = create_distance_function(self.N, self.config, type='o2o', debug=debug) + self.o2m_distance = create_distance_function(self.N, self.config, type='o2m', debug=debug) def default_analysis(self, fitnesses): max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index 47cfed9..aa8afd6 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -105,7 +105,7 @@ class SpeciesController: # the representatives of new species sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) distances = [ - o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) + jax.device_get(o2o_distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])) for r in rid ] distances = np.array(distances) diff --git a/utils/default_config.json b/utils/default_config.json index 8ee9902..9f81bfb 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -2,7 +2,7 @@ "basic": { "num_inputs": 2, "num_outputs": 1, - "init_maximum_nodes": 10, + "init_maximum_nodes": 20, "expands_coe": 2 }, "neat": { @@ -30,12 +30,12 @@ }, "activation": { "default": "sigmoid", - "options": ["sigmoid", "gauss", "relu"], + "options": ["sigmoid"], "mutate_rate": 0.1 }, "aggregation": { "default": "sum", - "options": ["sum", "max", "min", "mean"], + "options": ["sum"], "mutate_rate": 0.1 }, "weight": { @@ -59,7 +59,7 @@ "node_delete_prob": 0.2 }, "species": { - "compatibility_threshold": 3, + "compatibility_threshold": 2.5, "species_fitness_func": "max", "max_stagnation": 20, "species_elitism": 2,