use jit().lower.compile in create functions

This commit is contained in:
wls2002
2023-05-08 02:35:04 +08:00
parent 497d89fc69
commit d4a75b9394
9 changed files with 120 additions and 77 deletions

View File

@@ -8,29 +8,27 @@ from jax import numpy as jnp
from .utils import flatten_connections, unflatten_connections
def create_crossover_function(batch: bool):
def create_crossover_function(N, config, batch: bool):
if batch:
return batch_crossover
pop_size = config.neat.population.pop_size
randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
nodes1_lower = jnp.zeros((pop_size, N, 5))
connections1_lower = jnp.zeros((pop_size, 2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
return jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
else:
return crossover
randkey_lower = jnp.zeros((2,), dtype=jnp.uint32)
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((N, 5))
connections2_lower = jnp.zeros((2, N, N))
return jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
nodes2_lower, connections2_lower).compile()
@vmap
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
batch_connections2: Array) -> Tuple[Array, Array]:
"""
crossover a batch of genomes
:param randkeys: batches of random keys
:param batch_nodes1:
:param batch_connections1:
:param batch_nodes2:
:param batch_connections2:
:return:
"""
return crossover(randkeys, batch_nodes1, batch_connections1, batch_nodes2, batch_connections2)
@jit
# @jit
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
-> Tuple[Array, Array]:
"""
@@ -61,7 +59,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
return new_nodes, new_cons
@partial(jit, static_argnames=['gene_type'])
# @partial(jit, static_argnames=['gene_type'])
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
"""
make ar2 align with ar1.
@@ -88,7 +86,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
return refactor_ar2
@jit
# @jit
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
"""
crossover two genes

View File

@@ -6,25 +6,31 @@ from numpy.typing import NDArray
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
def create_distance_function(config, type: str):
def create_distance_function(N, config, type: str):
"""
:param N:
:param config:
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
:return:
"""
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
def distance_with_args(nodes1, connections1, nodes2, connections2):
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
if type == 'o2o':
return lambda nodes1, connections1, nodes2, connections2: \
distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
# return lambda nodes1, connections1, nodes2, connections2: \
# distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
elif type == 'o2m':
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
pop_size = config.neat.population.pop_size
nodes1_lower = jnp.zeros((N, 5))
connections1_lower = jnp.zeros((2, N, N))
nodes2_lower = jnp.zeros((pop_size, N, 5))
connections2_lower = jnp.zeros((pop_size, 2, N, N))
return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
else:
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')

View File

@@ -22,7 +22,9 @@ from jax import numpy as jnp
from jax import jit
from jax import Array
from algorithms.neat.genome.utils import fetch_first, fetch_last
from .activations import act_name2key
from .aggregations import agg_name2key
from .utils import fetch_first
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
@@ -34,10 +36,8 @@ def create_initialize_function(config):
num_outputs = config.basic.num_outputs
default_bias = config.neat.gene.bias.init_mean
default_response = config.neat.gene.response.init_mean
# default_act = config.neat.gene.activation.default
# default_agg = config.neat.gene.aggregation.default
default_act = 0
default_agg = 0
default_act = act_name2key[config.neat.gene.activation.default]
default_agg = agg_name2key[config.neat.gene.aggregation.default]
default_weight = config.neat.gene.weight.init_mean
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
default_act, default_agg, default_weight)

View File

@@ -13,15 +13,19 @@ from .activations import act_name2key
from .aggregations import agg_name2key
def create_mutate_function(config, input_keys, output_keys, batch: bool):
def create_mutate_function(N, config, batch: bool):
"""
create mutate function for different situations
:param output_keys:
:param input_keys:
:param N:
:param config:
:param batch: mutate for population or not
:return:
"""
num_inputs = config.basic.num_inputs
num_outputs = config.basic.num_outputs
input_idx = np.arange(num_inputs)
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
bias = config.neat.gene.bias
bias_default = bias.init_mean
bias_mean = bias.init_mean
@@ -65,8 +69,8 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
delete_connection_rate = genome.conn_delete_prob
single_structure_mutate = genome.single_structural_mutation
def mutate_func(rand_key, nodes, connections, new_node_key):
return mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
def mutate_with_args(rand_key, nodes, connections, new_node_key):
return mutate(rand_key, nodes, connections, new_node_key, input_idx, output_idx,
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
bias_replace_rate, response_default, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,
@@ -77,19 +81,30 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
single_structure_mutate)
if not batch:
return mutate_func
rand_key_lower = jnp.zeros((2, ), dtype=jnp.uint32)
nodes_lower = jnp.zeros((N, 5))
connections_lower = jnp.zeros((2, N, N))
new_node_key_lower = jnp.zeros((), dtype=jnp.int32)
return jit(mutate_with_args).lower(rand_key_lower, nodes_lower, connections_lower, new_node_key_lower).compile()
else:
batched_mutate_func = vmap(mutate_func, in_axes=(0, 0, 0, 0))
pop_size = config.neat.population.pop_size
rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
nodes_lower = jnp.zeros((pop_size, N, 5))
connections_lower = jnp.zeros((pop_size, 2, N, N))
new_node_key_lower = jnp.zeros((pop_size, ), dtype=jnp.int32)
batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower,
connections_lower, new_node_key_lower).compile()
return batched_mutate_func
@partial(jit, static_argnames=["single_structure_mutate"])
# @partial(jit, static_argnames=["single_structure_mutate"])
def mutate(rand_key: Array,
nodes: Array,
connections: Array,
new_node_key: int,
input_keys: Array,
output_keys: Array,
input_idx: Array,
output_idx: Array,
bias_default: float = 0,
bias_mean: float = 0,
bias_std: float = 1,
@@ -120,8 +135,8 @@ def mutate(rand_key: Array,
delete_connection_rate: float = 0.4,
single_structure_mutate: bool = True):
"""
:param output_keys:
:param input_keys:
:param output_idx:
:param input_idx:
:param agg_default:
:param act_default:
:param response_default:
@@ -166,10 +181,10 @@ def mutate(rand_key: Array,
return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default)
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_keys, output_keys)
return mutate_delete_node(rk, n, c, input_idx, output_idx)
def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_keys, output_keys)
return mutate_add_connection(rk, n, c, input_idx, output_idx)
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
@@ -224,7 +239,7 @@ def mutate(rand_key: Array,
return nodes, connections
@jit
# @jit
def mutate_values(rand_key: Array,
nodes: Array,
connections: Array,
@@ -305,7 +320,7 @@ def mutate_values(rand_key: Array,
return nodes, connections
@jit
# @jit
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
"""
@@ -338,7 +353,7 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
return new_vals
@jit
# @jit
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
"""
Mutate integer values (act, agg) of a given array.
@@ -361,7 +376,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
return new_vals
@jit
# @jit
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
default_bias: float = 0, default_response: float = 1,
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
@@ -408,7 +423,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
return nodes, connections
@jit
# @jit
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
@@ -442,7 +457,7 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
return nodes, connections
@jit
# @jit
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
"""
@@ -481,7 +496,7 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
return nodes, connections
@jit
# @jit
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
"""
Randomly delete a connection.
@@ -504,7 +519,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
return nodes, connections
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
# @partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
@@ -533,7 +548,7 @@ def choice_node_key(rand_key: Array, nodes: Array,
return key, idx
@jit
# @jit
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
"""
Randomly choose a connection key from the given connections.
@@ -561,6 +576,6 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
return from_key, to_key, from_idx, to_idx
@jit
# @jit
def rand(rand_key):
return jax.random.uniform(rand_key, ())

View File

@@ -28,10 +28,8 @@ class Pipeline:
self.species_controller = SpeciesController(config)
self.initialize_func = create_initialize_function(config)
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
self.crossover_func = create_crossover_function(batch=True)
self.o2o_distance = create_distance_function(self.config, type='o2o')
self.o2m_distance = create_distance_function(self.config, type='o2m')
self.compile_functions()
self.generation = 0
self.species_controller.speciate(self.pop_nodes, self.pop_connections,
@@ -142,6 +140,15 @@ class Pipeline:
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N)
# update functions
self.compile_functions()
def compile_functions(self):
self.mutate_func = create_mutate_function(self.N, self.config, batch=True)
self.crossover_func = create_crossover_function(self.N, self.config, batch=True)
self.o2o_distance = create_distance_function(self.N, self.config, type='o2o')
self.o2m_distance = create_distance_function(self.N, self.config, type='o2m')
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
species_sizes = [len(s.members) for s in self.species_controller.species.values()]

View File

@@ -5,6 +5,7 @@ import jax
import numpy as np
from numpy.typing import NDArray
class Species(object):
def __init__(self, key, generation):

View File

@@ -3,6 +3,7 @@ import jax.numpy as jnp
import numpy as np
from jax import random
from jax import vmap, jit
from functools import partial
from examples.time_utils import using_cprofile
@@ -16,28 +17,43 @@ def func(x, y):
return x * y
def func2(x, y, s):
"""
:param x: (100, )
:param y: (100,
:return:
"""
if s == '123':
return 0
else:
return x * y
@jit
def func3(x, y):
return func2(x, y, '123')
# @using_cprofile
def main():
key = jax.random.PRNGKey(42)
x1, y1 = jax.random.normal(key, shape=(100,)), jax.random.normal(key, shape=(100,))
x1, y1 = jax.random.normal(key, shape=(1000,)), jax.random.normal(key, shape=(1000,))
jit_func = jit(func)
z = jit_func(x1, y1)
print(z)
jit_lower_func = jit(func).lower(x1, y1).compile()
jit_lower_func = jit(func).lower(1, 2).compile()
print(type(jit_lower_func))
import pickle
print(jit_lower_func.memory_analysis())
with open('jit_function.pkl', 'wb') as f:
pickle.dump(jit_lower_func, f)
jit_compiled_func2 = jit(func2, static_argnames=['s']).lower(x1, y1, '123').compile()
print(jit_compiled_func2(x1, y1))
new_jit_lower_func = pickle.load(open('jit_function.pkl', 'rb'))
# print(jit_compiled_func2(x1, y1))
print(jit_lower_func(x1, y1))
print(new_jit_lower_func(x1, y1))
f = func3.lower(x1, y1).compile()
print(f(x1, y1))
# print(jit_lower_func(x1, y1))
# x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,))
# print(jit_lower_func(x2, y2))

View File

@@ -23,8 +23,8 @@ def evaluate(forward_func: Callable) -> List[float]:
return fitnesses.tolist() # returns a list
@using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
# @using_cprofile
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main():
config = Configer.load_config()
pipeline = Pipeline(config, seed=11323)

View File

@@ -2,7 +2,7 @@
"basic": {
"num_inputs": 2,
"num_outputs": 1,
"init_maximum_nodes": 25,
"init_maximum_nodes": 10,
"expands_coe": 2
},
"neat": {