use jit().lower.compile in create functions

This commit is contained in:
wls2002
2023-05-08 02:35:04 +08:00
parent 497d89fc69
commit d4a75b9394
9 changed files with 120 additions and 77 deletions

View File

@@ -8,29 +8,27 @@ from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections from .utils import flatten_connections, unflatten_connections
def create_crossover_function(batch: bool): def create_crossover_function(N, config, batch: bool):
if batch: if batch:
return batch_crossover pop_size = config.neat.population.pop_size
randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
nodes1_lower = jnp.zeros((pop_size, N, 5))
connections1_lower = jnp.zeros((pop_size, 2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
return jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
else: else:
return crossover randkey_lower = jnp.zeros((2,), dtype=jnp.uint32)
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((N, 5))
connections2_lower = jnp.zeros((2, N, N))
return jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
@vmap # @jit
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
batch_connections2: Array) -> Tuple[Array, Array]:
"""
crossover a batch of genomes
:param randkeys: batches of random keys
:param batch_nodes1:
:param batch_connections1:
:param batch_nodes2:
:param batch_connections2:
:return:
"""
return crossover(randkeys, batch_nodes1, batch_connections1, batch_nodes2, batch_connections2)
@jit
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
-> Tuple[Array, Array]: -> Tuple[Array, Array]:
""" """
@@ -61,7 +59,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
return new_nodes, new_cons return new_nodes, new_cons
@partial(jit, static_argnames=['gene_type']) # @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array: def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
""" """
make ar2 align with ar1. make ar2 align with ar1.
@@ -88,7 +86,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
return refactor_ar2 return refactor_ar2
@jit # @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array: def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
""" """
crossover two genes crossover two genes

View File

@@ -6,25 +6,31 @@ from numpy.typing import NDArray
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
def create_distance_function(config, type: str): def create_distance_function(N, config, type: str):
""" """
:param N:
:param config: :param config:
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation :param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
:return: :return:
""" """
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
compatibility_coe = config.neat.genome.compatibility_weight_coefficient compatibility_coe = config.neat.genome.compatibility_weight_coefficient
def distance_with_args(nodes1, connections1, nodes2, connections2):
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
if type == 'o2o': if type == 'o2o':
return lambda nodes1, connections1, nodes2, connections2: \ return lambda nodes1, connections1, nodes2, connections2: \
distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
# return lambda nodes1, connections1, nodes2, connections2: \
# distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
elif type == 'o2m': elif type == 'o2m':
func = vmap(distance, in_axes=(None, None, 0, 0, None, None)) vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \ pop_size = config.neat.population.pop_size
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe) nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
else: else:
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]') raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')

View File

@@ -22,7 +22,9 @@ from jax import numpy as jnp
from jax import jit from jax import jit
from jax import Array from jax import Array
from algorithms.neat.genome.utils import fetch_first, fetch_last from .activations import act_name2key
from .aggregations import agg_name2key
from .utils import fetch_first
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan]) EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
@@ -34,10 +36,8 @@ def create_initialize_function(config):
num_outputs = config.basic.num_outputs num_outputs = config.basic.num_outputs
default_bias = config.neat.gene.bias.init_mean default_bias = config.neat.gene.bias.init_mean
default_response = config.neat.gene.response.init_mean default_response = config.neat.gene.response.init_mean
# default_act = config.neat.gene.activation.default default_act = act_name2key[config.neat.gene.activation.default]
# default_agg = config.neat.gene.aggregation.default default_agg = agg_name2key[config.neat.gene.aggregation.default]
default_act = 0
default_agg = 0
default_weight = config.neat.gene.weight.init_mean default_weight = config.neat.gene.weight.init_mean
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response, return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
default_act, default_agg, default_weight) default_act, default_agg, default_weight)

View File

@@ -13,15 +13,19 @@ from .activations import act_name2key
from .aggregations import agg_name2key from .aggregations import agg_name2key
def create_mutate_function(config, input_keys, output_keys, batch: bool): def create_mutate_function(N, config, batch: bool):
""" """
create mutate function for different situations create mutate function for different situations
:param output_keys: :param N:
:param input_keys:
:param config: :param config:
:param batch: mutate for population or not :param batch: mutate for population or not
:return: :return:
""" """
num_inputs = config.basic.num_inputs
num_outputs = config.basic.num_outputs
input_idx = np.arange(num_inputs)
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
bias = config.neat.gene.bias bias = config.neat.gene.bias
bias_default = bias.init_mean bias_default = bias.init_mean
bias_mean = bias.init_mean bias_mean = bias.init_mean
@@ -65,8 +69,8 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
delete_connection_rate = genome.conn_delete_prob delete_connection_rate = genome.conn_delete_prob
single_structure_mutate = genome.single_structural_mutation single_structure_mutate = genome.single_structural_mutation
def mutate_func(rand_key, nodes, connections, new_node_key): def mutate_with_args(rand_key, nodes, connections, new_node_key):
return mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys, return mutate(rand_key, nodes, connections, new_node_key, input_idx, output_idx,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate, bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std, bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate, response_mutate_strength, response_mutate_rate, response_replace_rate,
@@ -77,19 +81,30 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
single_structure_mutate) single_structure_mutate)
if not batch: if not batch:
return mutate_func rand_key_lower = jnp.zeros((2, ), dtype=jnp.uint32)
nodes_lower = jnp.zeros((N, 5))
connections_lower = jnp.zeros((2, N, N))
new_node_key_lower = jnp.zeros((), dtype=jnp.int32)
return jit(mutate_with_args).lower(rand_key_lower, nodes_lower, connections_lower, new_node_key_lower).compile()
else: else:
batched_mutate_func = vmap(mutate_func, in_axes=(0, 0, 0, 0)) pop_size = config.neat.population.pop_size
rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
nodes_lower = jnp.zeros((pop_size, N, 5))
connections_lower = jnp.zeros((pop_size, 2, N, N))
new_node_key_lower = jnp.zeros((pop_size, ), dtype=jnp.int32)
batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
return batched_mutate_func return batched_mutate_func
@partial(jit, static_argnames=["single_structure_mutate"]) # @partial(jit, static_argnames=["single_structure_mutate"])
def mutate(rand_key: Array, def mutate(rand_key: Array,
nodes: Array, nodes: Array,
connections: Array, connections: Array,
new_node_key: int, new_node_key: int,
input_keys: Array, input_idx: Array,
output_keys: Array, output_idx: Array,
bias_default: float = 0, bias_default: float = 0,
bias_mean: float = 0, bias_mean: float = 0,
bias_std: float = 1, bias_std: float = 1,
@@ -120,8 +135,8 @@ def mutate(rand_key: Array,
delete_connection_rate: float = 0.4, delete_connection_rate: float = 0.4,
single_structure_mutate: bool = True): single_structure_mutate: bool = True):
""" """
:param output_keys: :param output_idx:
:param input_keys: :param input_idx:
:param agg_default: :param agg_default:
:param act_default: :param act_default:
:param response_default: :param response_default:
@@ -166,10 +181,10 @@ def mutate(rand_key: Array,
return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default) return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default)
def m_delete_node(rk, n, c): def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_keys, output_keys) return mutate_delete_node(rk, n, c, input_idx, output_idx)
def m_add_connection(rk, n, c): def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_keys, output_keys) return mutate_add_connection(rk, n, c, input_idx, output_idx)
def m_delete_connection(rk, n, c): def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c) return mutate_delete_connection(rk, n, c)
@@ -224,7 +239,7 @@ def mutate(rand_key: Array,
return nodes, connections return nodes, connections
@jit # @jit
def mutate_values(rand_key: Array, def mutate_values(rand_key: Array,
nodes: Array, nodes: Array,
connections: Array, connections: Array,
@@ -305,7 +320,7 @@ def mutate_values(rand_key: Array,
return nodes, connections return nodes, connections
@jit # @jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float, def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array: mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
""" """
@@ -338,7 +353,7 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
return new_vals return new_vals
@jit # @jit
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array: def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
""" """
Mutate integer values (act, agg) of a given array. Mutate integer values (act, agg) of a given array.
@@ -361,7 +376,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
return new_vals return new_vals
@jit # @jit
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array, def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
default_bias: float = 0, default_response: float = 1, default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]: default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
@@ -408,7 +423,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
return nodes, connections return nodes, connections
@jit # @jit
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array, def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
""" """
@@ -442,7 +457,7 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
return nodes, connections return nodes, connections
@jit # @jit
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array, def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]: input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
""" """
@@ -481,7 +496,7 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
return nodes, connections return nodes, connections
@jit # @jit
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array): def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
""" """
Randomly delete a connection. Randomly delete a connection.
@@ -504,7 +519,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
return nodes, connections return nodes, connections
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys')) # @partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array, def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array, input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]: allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
@@ -533,7 +548,7 @@ def choice_node_key(rand_key: Array, nodes: Array,
return key, idx return key, idx
@jit # @jit
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]: def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
""" """
Randomly choose a connection key from the given connections. Randomly choose a connection key from the given connections.
@@ -561,6 +576,6 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
return from_key, to_key, from_idx, to_idx return from_key, to_key, from_idx, to_idx
@jit # @jit
def rand(rand_key): def rand(rand_key):
return jax.random.uniform(rand_key, ()) return jax.random.uniform(rand_key, ())

View File

@@ -28,10 +28,8 @@ class Pipeline:
self.species_controller = SpeciesController(config) self.species_controller = SpeciesController(config)
self.initialize_func = create_initialize_function(config) self.initialize_func = create_initialize_function(config)
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func() self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
self.crossover_func = create_crossover_function(batch=True) self.compile_functions()
self.o2o_distance = create_distance_function(self.config, type='o2o')
self.o2m_distance = create_distance_function(self.config, type='o2m')
self.generation = 0 self.generation = 0
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.species_controller.speciate(self.pop_nodes, self.pop_connections,
@@ -142,6 +140,15 @@ class Pipeline:
for s in self.species_controller.species.values(): for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N) s.representative = expand_single(*s.representative, self.N)
# update functions
self.compile_functions()
def compile_functions(self):
self.mutate_func = create_mutate_function(self.N, self.config, batch=True)
self.crossover_func = create_crossover_function(self.N, self.config, batch=True)
self.o2o_distance = create_distance_function(self.N, self.config, type='o2o')
self.o2m_distance = create_distance_function(self.N, self.config, type='o2m')
def default_analysis(self, fitnesses): def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
species_sizes = [len(s.members) for s in self.species_controller.species.values()] species_sizes = [len(s.members) for s in self.species_controller.species.values()]

View File

@@ -5,6 +5,7 @@ import jax
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
class Species(object): class Species(object):
def __init__(self, key, generation): def __init__(self, key, generation):

View File

@@ -3,6 +3,7 @@ import jax.numpy as jnp
import numpy as np import numpy as np
from jax import random from jax import random
from jax import vmap, jit from jax import vmap, jit
from functools import partial
from examples.time_utils import using_cprofile from examples.time_utils import using_cprofile
@@ -16,28 +17,43 @@ def func(x, y):
return x * y return x * y
def func2(x, y, s):
"""
:param x: (100, )
:param y: (100,
:return:
"""
if s == '123':
return 0
else:
return x * y
@jit
def func3(x, y):
return func2(x, y, '123')
# @using_cprofile # @using_cprofile
def main(): def main():
key = jax.random.PRNGKey(42) key = jax.random.PRNGKey(42)
x1, y1 = jax.random.normal(key, shape=(100,)), jax.random.normal(key, shape=(100,)) x1, y1 = jax.random.normal(key, shape=(1000,)), jax.random.normal(key, shape=(1000,))
jit_func = jit(func) jit_lower_func = jit(func).lower(1, 2).compile()
z = jit_func(x1, y1)
print(z)
jit_lower_func = jit(func).lower(x1, y1).compile()
print(type(jit_lower_func)) print(type(jit_lower_func))
import pickle print(jit_lower_func.memory_analysis())
with open('jit_function.pkl', 'wb') as f: jit_compiled_func2 = jit(func2, static_argnames=['s']).lower(x1, y1, '123').compile()
pickle.dump(jit_lower_func, f) print(jit_compiled_func2(x1, y1))
new_jit_lower_func = pickle.load(open('jit_function.pkl', 'rb')) # print(jit_compiled_func2(x1, y1))
print(jit_lower_func(x1, y1)) f = func3.lower(x1, y1).compile()
print(new_jit_lower_func(x1, y1))
print(f(x1, y1))
# print(jit_lower_func(x1, y1))
# x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,)) # x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,))
# print(jit_lower_func(x2, y2)) # print(jit_lower_func(x2, y2))

View File

@@ -23,8 +23,8 @@ def evaluate(forward_func: Callable) -> List[float]:
return fitnesses.tolist() # returns a list return fitnesses.tolist() # returns a list
@using_cprofile # @using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main(): def main():
config = Configer.load_config() config = Configer.load_config()
pipeline = Pipeline(config, seed=11323) pipeline = Pipeline(config, seed=11323)

View File

@@ -2,7 +2,7 @@
"basic": { "basic": {
"num_inputs": 2, "num_inputs": 2,
"num_outputs": 1, "num_outputs": 1,
"init_maximum_nodes": 25, "init_maximum_nodes": 10,
"expands_coe": 2 "expands_coe": 2
}, },
"neat": { "neat": {