diff --git a/algorithms/neat/genome/crossover.py b/algorithms/neat/genome/crossover.py index 7dca803..5573714 100644 --- a/algorithms/neat/genome/crossover.py +++ b/algorithms/neat/genome/crossover.py @@ -8,29 +8,27 @@ from jax import numpy as jnp from .utils import flatten_connections, unflatten_connections -def create_crossover_function(batch: bool): +def create_crossover_function(N, config, batch: bool): if batch: - return batch_crossover + pop_size = config.neat.population.pop_size + randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32) + nodes1_lower = jnp.zeros((pop_size, N, 5)) + connections1_lower = jnp.zeros((pop_size, 2, N, N)) + nodes2_lower = jnp.zeros((pop_size, N, 5)) + connections2_lower = jnp.zeros((pop_size, 2, N, N)) + return jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower, + nodes2_lower, connections2_lower).compile() else: - return crossover + randkey_lower = jnp.zeros((2,), dtype=jnp.uint32) + nodes1_lower = jnp.zeros((N, 5)) + connections1_lower = jnp.zeros((2, N, N)) + nodes2_lower = jnp.zeros((N, 5)) + connections2_lower = jnp.zeros((2, N, N)) + return jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower, + nodes2_lower, connections2_lower).compile() -@vmap -def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array, - batch_connections2: Array) -> Tuple[Array, Array]: - """ - crossover a batch of genomes - :param randkeys: batches of random keys - :param batch_nodes1: - :param batch_connections1: - :param batch_nodes2: - :param batch_connections2: - :return: - """ - return crossover(randkeys, batch_nodes1, batch_connections1, batch_nodes2, batch_connections2) - - -@jit +# @jit def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \ -> Tuple[Array, Array]: """ @@ -61,7 +59,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, return new_nodes, new_cons -@partial(jit, static_argnames=['gene_type']) +# @partial(jit, static_argnames=['gene_type']) def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: """ make ar2 align with ar1. @@ -88,7 +86,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: return refactor_ar2 -@jit +# @jit def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: """ crossover two genes diff --git a/algorithms/neat/genome/distance.py b/algorithms/neat/genome/distance.py index 74f1df6..f2ed988 100644 --- a/algorithms/neat/genome/distance.py +++ b/algorithms/neat/genome/distance.py @@ -6,25 +6,31 @@ from numpy.typing import NDArray from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON -def create_distance_function(config, type: str): +def create_distance_function(N, config, type: str): """ + :param N: :param config: :param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation :return: """ disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient compatibility_coe = config.neat.genome.compatibility_weight_coefficient + + def distance_with_args(nodes1, connections1, nodes2, connections2): + return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) + if type == 'o2o': return lambda nodes1, connections1, nodes2, connections2: \ distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) - # return lambda nodes1, connections1, nodes2, connections2: \ - # distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) - elif type == 'o2m': - func = vmap(distance, in_axes=(None, None, 0, 0, None, None)) - return lambda nodes1, connections1, batch_nodes2, batch_connections2: \ - func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe) + vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0)) + pop_size = config.neat.population.pop_size + nodes1_lower = jnp.zeros((N, 5)) + connections1_lower = jnp.zeros((2, N, N)) + nodes2_lower = jnp.zeros((pop_size, N, 5)) + connections2_lower = jnp.zeros((pop_size, 2, N, N)) + return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile() else: raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]') diff --git a/algorithms/neat/genome/genome.py b/algorithms/neat/genome/genome.py index c6f83ab..53dfcf1 100644 --- a/algorithms/neat/genome/genome.py +++ b/algorithms/neat/genome/genome.py @@ -22,7 +22,9 @@ from jax import numpy as jnp from jax import jit from jax import Array -from algorithms.neat.genome.utils import fetch_first, fetch_last +from .activations import act_name2key +from .aggregations import agg_name2key +from .utils import fetch_first EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan]) @@ -34,10 +36,8 @@ def create_initialize_function(config): num_outputs = config.basic.num_outputs default_bias = config.neat.gene.bias.init_mean default_response = config.neat.gene.response.init_mean - # default_act = config.neat.gene.activation.default - # default_agg = config.neat.gene.aggregation.default - default_act = 0 - default_agg = 0 + default_act = act_name2key[config.neat.gene.activation.default] + default_agg = agg_name2key[config.neat.gene.aggregation.default] default_weight = config.neat.gene.weight.init_mean return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response, default_act, default_agg, default_weight) diff --git a/algorithms/neat/genome/mutate.py b/algorithms/neat/genome/mutate.py index 9e180ea..709ee32 100644 --- a/algorithms/neat/genome/mutate.py +++ b/algorithms/neat/genome/mutate.py @@ -13,15 +13,19 @@ from .activations import act_name2key from .aggregations import agg_name2key -def create_mutate_function(config, input_keys, output_keys, batch: bool): +def create_mutate_function(N, config, batch: bool): """ create mutate function for different situations - :param output_keys: - :param input_keys: + :param N: :param config: :param batch: mutate for population or not :return: """ + num_inputs = config.basic.num_inputs + num_outputs = config.basic.num_outputs + input_idx = np.arange(num_inputs) + output_idx = np.arange(num_inputs, num_inputs + num_outputs) + bias = config.neat.gene.bias bias_default = bias.init_mean bias_mean = bias.init_mean @@ -65,8 +69,8 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool): delete_connection_rate = genome.conn_delete_prob single_structure_mutate = genome.single_structural_mutation - def mutate_func(rand_key, nodes, connections, new_node_key): - return mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys, + def mutate_with_args(rand_key, nodes, connections, new_node_key): + return mutate(rand_key, nodes, connections, new_node_key, input_idx, output_idx, bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, bias_replace_rate, response_default, response_mean, response_std, response_mutate_strength, response_mutate_rate, response_replace_rate, @@ -77,19 +81,30 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool): single_structure_mutate) if not batch: - return mutate_func + rand_key_lower = jnp.zeros((2, ), dtype=jnp.uint32) + nodes_lower = jnp.zeros((N, 5)) + connections_lower = jnp.zeros((2, N, N)) + new_node_key_lower = jnp.zeros((), dtype=jnp.int32) + return jit(mutate_with_args).lower(rand_key_lower, nodes_lower, connections_lower, new_node_key_lower).compile() else: - batched_mutate_func = vmap(mutate_func, in_axes=(0, 0, 0, 0)) + pop_size = config.neat.population.pop_size + rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32) + nodes_lower = jnp.zeros((pop_size, N, 5)) + connections_lower = jnp.zeros((pop_size, 2, N, N)) + new_node_key_lower = jnp.zeros((pop_size, ), dtype=jnp.int32) + batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower, + connections_lower, new_node_key_lower).compile() + return batched_mutate_func -@partial(jit, static_argnames=["single_structure_mutate"]) +# @partial(jit, static_argnames=["single_structure_mutate"]) def mutate(rand_key: Array, nodes: Array, connections: Array, new_node_key: int, - input_keys: Array, - output_keys: Array, + input_idx: Array, + output_idx: Array, bias_default: float = 0, bias_mean: float = 0, bias_std: float = 1, @@ -120,8 +135,8 @@ def mutate(rand_key: Array, delete_connection_rate: float = 0.4, single_structure_mutate: bool = True): """ - :param output_keys: - :param input_keys: + :param output_idx: + :param input_idx: :param agg_default: :param act_default: :param response_default: @@ -166,10 +181,10 @@ def mutate(rand_key: Array, return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default) def m_delete_node(rk, n, c): - return mutate_delete_node(rk, n, c, input_keys, output_keys) + return mutate_delete_node(rk, n, c, input_idx, output_idx) def m_add_connection(rk, n, c): - return mutate_add_connection(rk, n, c, input_keys, output_keys) + return mutate_add_connection(rk, n, c, input_idx, output_idx) def m_delete_connection(rk, n, c): return mutate_delete_connection(rk, n, c) @@ -224,7 +239,7 @@ def mutate(rand_key: Array, return nodes, connections -@jit +# @jit def mutate_values(rand_key: Array, nodes: Array, connections: Array, @@ -305,7 +320,7 @@ def mutate_values(rand_key: Array, return nodes, connections -@jit +# @jit def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float, mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array: """ @@ -338,7 +353,7 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa return new_vals -@jit +# @jit def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array: """ Mutate integer values (act, agg) of a given array. @@ -361,7 +376,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace return new_vals -@jit +# @jit def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array, default_bias: float = 0, default_response: float = 1, default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]: @@ -408,7 +423,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection return nodes, connections -@jit +# @jit def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: """ @@ -442,7 +457,7 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, return nodes, connections -@jit +# @jit def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: """ @@ -481,7 +496,7 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, return nodes, connections -@jit +# @jit def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): """ Randomly delete a connection. @@ -504,7 +519,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): return nodes, connections -@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys')) +# @partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys')) def choice_node_key(rand_key: Array, nodes: Array, input_keys: Array, output_keys: Array, allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: @@ -533,7 +548,7 @@ def choice_node_key(rand_key: Array, nodes: Array, return key, idx -@jit +# @jit def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]: """ Randomly choose a connection key from the given connections. @@ -561,6 +576,6 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T return from_key, to_key, from_idx, to_idx -@jit +# @jit def rand(rand_key): return jax.random.uniform(rand_key, ()) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index b39be0a..6a3b516 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -28,10 +28,8 @@ class Pipeline: self.species_controller = SpeciesController(config) self.initialize_func = create_initialize_function(config) self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() - self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True) - self.crossover_func = create_crossover_function(batch=True) - self.o2o_distance = create_distance_function(self.config, type='o2o') - self.o2m_distance = create_distance_function(self.config, type='o2m') + + self.compile_functions() self.generation = 0 self.species_controller.speciate(self.pop_nodes, self.pop_connections, @@ -142,6 +140,15 @@ class Pipeline: for s in self.species_controller.species.values(): s.representative = expand_single(*s.representative, self.N) + # update functions + self.compile_functions() + + def compile_functions(self): + self.mutate_func = create_mutate_function(self.N, self.config, batch=True) + self.crossover_func = create_crossover_function(self.N, self.config, batch=True) + self.o2o_distance = create_distance_function(self.N, self.config, type='o2o') + self.o2m_distance = create_distance_function(self.N, self.config, type='o2m') + def default_analysis(self, fitnesses): max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) species_sizes = [len(s.members) for s in self.species_controller.species.values()] diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index 620de50..47cfed9 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -5,6 +5,7 @@ import jax import numpy as np from numpy.typing import NDArray + class Species(object): def __init__(self, key, generation): diff --git a/examples/jax_playground.py b/examples/jax_playground.py index 2554ab9..2e4487f 100644 --- a/examples/jax_playground.py +++ b/examples/jax_playground.py @@ -3,6 +3,7 @@ import jax.numpy as jnp import numpy as np from jax import random from jax import vmap, jit +from functools import partial from examples.time_utils import using_cprofile @@ -16,28 +17,43 @@ def func(x, y): return x * y +def func2(x, y, s): + """ + :param x: (100, ) + :param y: (100, + :return: + """ + if s == '123': + return 0 + else: + return x * y + + +@jit +def func3(x, y): + return func2(x, y, '123') + + # @using_cprofile def main(): key = jax.random.PRNGKey(42) - x1, y1 = jax.random.normal(key, shape=(100,)), jax.random.normal(key, shape=(100,)) + x1, y1 = jax.random.normal(key, shape=(1000,)), jax.random.normal(key, shape=(1000,)) - jit_func = jit(func) - - z = jit_func(x1, y1) - print(z) - - jit_lower_func = jit(func).lower(x1, y1).compile() + jit_lower_func = jit(func).lower(1, 2).compile() print(type(jit_lower_func)) - import pickle + print(jit_lower_func.memory_analysis()) - with open('jit_function.pkl', 'wb') as f: - pickle.dump(jit_lower_func, f) + jit_compiled_func2 = jit(func2, static_argnames=['s']).lower(x1, y1, '123').compile() + print(jit_compiled_func2(x1, y1)) - new_jit_lower_func = pickle.load(open('jit_function.pkl', 'rb')) + # print(jit_compiled_func2(x1, y1)) - print(jit_lower_func(x1, y1)) - print(new_jit_lower_func(x1, y1)) + f = func3.lower(x1, y1).compile() + + print(f(x1, y1)) + + # print(jit_lower_func(x1, y1)) # x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,)) # print(jit_lower_func(x2, y2)) diff --git a/examples/xor.py b/examples/xor.py index e8ac80c..895bd5f 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -23,8 +23,8 @@ def evaluate(forward_func: Callable) -> List[float]: return fitnesses.tolist() # returns a list -@using_cprofile -# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") +# @using_cprofile +@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() pipeline = Pipeline(config, seed=11323) diff --git a/utils/default_config.json b/utils/default_config.json index 1881b52..8ee9902 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -2,7 +2,7 @@ "basic": { "num_inputs": 2, "num_outputs": 1, - "init_maximum_nodes": 25, + "init_maximum_nodes": 10, "expands_coe": 2 }, "neat": {