use jit().lower.compile in create functions
This commit is contained in:
@@ -8,29 +8,27 @@ from jax import numpy as jnp
|
||||
from .utils import flatten_connections, unflatten_connections
|
||||
|
||||
|
||||
def create_crossover_function(batch: bool):
|
||||
def create_crossover_function(N, config, batch: bool):
|
||||
if batch:
|
||||
return batch_crossover
|
||||
pop_size = config.neat.population.pop_size
|
||||
randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
|
||||
nodes1_lower = jnp.zeros((pop_size, N, 5))
|
||||
connections1_lower = jnp.zeros((pop_size, 2, N, N))
|
||||
nodes2_lower = jnp.zeros((pop_size, N, 5))
|
||||
connections2_lower = jnp.zeros((pop_size, 2, N, N))
|
||||
return jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
|
||||
nodes2_lower, connections2_lower).compile()
|
||||
else:
|
||||
return crossover
|
||||
randkey_lower = jnp.zeros((2,), dtype=jnp.uint32)
|
||||
nodes1_lower = jnp.zeros((N, 5))
|
||||
connections1_lower = jnp.zeros((2, N, N))
|
||||
nodes2_lower = jnp.zeros((N, 5))
|
||||
connections2_lower = jnp.zeros((2, N, N))
|
||||
return jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
|
||||
nodes2_lower, connections2_lower).compile()
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
|
||||
batch_connections2: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
crossover a batch of genomes
|
||||
:param randkeys: batches of random keys
|
||||
:param batch_nodes1:
|
||||
:param batch_connections1:
|
||||
:param batch_nodes2:
|
||||
:param batch_connections2:
|
||||
:return:
|
||||
"""
|
||||
return crossover(randkeys, batch_nodes1, batch_connections1, batch_nodes2, batch_connections2)
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
||||
-> Tuple[Array, Array]:
|
||||
"""
|
||||
@@ -61,7 +59,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['gene_type'])
|
||||
# @partial(jit, static_argnames=['gene_type'])
|
||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
"""
|
||||
make ar2 align with ar1.
|
||||
@@ -88,7 +86,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||
return refactor_ar2
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
"""
|
||||
crossover two genes
|
||||
|
||||
@@ -6,25 +6,31 @@ from numpy.typing import NDArray
|
||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
def create_distance_function(config, type: str):
|
||||
def create_distance_function(N, config, type: str):
|
||||
"""
|
||||
:param N:
|
||||
:param config:
|
||||
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
|
||||
:return:
|
||||
"""
|
||||
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||
|
||||
def distance_with_args(nodes1, connections1, nodes2, connections2):
|
||||
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
|
||||
if type == 'o2o':
|
||||
return lambda nodes1, connections1, nodes2, connections2: \
|
||||
distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
|
||||
# return lambda nodes1, connections1, nodes2, connections2: \
|
||||
# distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
|
||||
elif type == 'o2m':
|
||||
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
|
||||
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
|
||||
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
|
||||
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||
pop_size = config.neat.population.pop_size
|
||||
nodes1_lower = jnp.zeros((N, 5))
|
||||
connections1_lower = jnp.zeros((2, N, N))
|
||||
nodes2_lower = jnp.zeros((pop_size, N, 5))
|
||||
connections2_lower = jnp.zeros((pop_size, 2, N, N))
|
||||
return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
|
||||
else:
|
||||
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
|
||||
|
||||
|
||||
@@ -22,7 +22,9 @@ from jax import numpy as jnp
|
||||
from jax import jit
|
||||
from jax import Array
|
||||
|
||||
from algorithms.neat.genome.utils import fetch_first, fetch_last
|
||||
from .activations import act_name2key
|
||||
from .aggregations import agg_name2key
|
||||
from .utils import fetch_first
|
||||
|
||||
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
|
||||
|
||||
@@ -34,10 +36,8 @@ def create_initialize_function(config):
|
||||
num_outputs = config.basic.num_outputs
|
||||
default_bias = config.neat.gene.bias.init_mean
|
||||
default_response = config.neat.gene.response.init_mean
|
||||
# default_act = config.neat.gene.activation.default
|
||||
# default_agg = config.neat.gene.aggregation.default
|
||||
default_act = 0
|
||||
default_agg = 0
|
||||
default_act = act_name2key[config.neat.gene.activation.default]
|
||||
default_agg = agg_name2key[config.neat.gene.aggregation.default]
|
||||
default_weight = config.neat.gene.weight.init_mean
|
||||
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
|
||||
default_act, default_agg, default_weight)
|
||||
|
||||
@@ -13,15 +13,19 @@ from .activations import act_name2key
|
||||
from .aggregations import agg_name2key
|
||||
|
||||
|
||||
def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
||||
def create_mutate_function(N, config, batch: bool):
|
||||
"""
|
||||
create mutate function for different situations
|
||||
:param output_keys:
|
||||
:param input_keys:
|
||||
:param N:
|
||||
:param config:
|
||||
:param batch: mutate for population or not
|
||||
:return:
|
||||
"""
|
||||
num_inputs = config.basic.num_inputs
|
||||
num_outputs = config.basic.num_outputs
|
||||
input_idx = np.arange(num_inputs)
|
||||
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
||||
|
||||
bias = config.neat.gene.bias
|
||||
bias_default = bias.init_mean
|
||||
bias_mean = bias.init_mean
|
||||
@@ -65,8 +69,8 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
||||
delete_connection_rate = genome.conn_delete_prob
|
||||
single_structure_mutate = genome.single_structural_mutation
|
||||
|
||||
def mutate_func(rand_key, nodes, connections, new_node_key):
|
||||
return mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
|
||||
def mutate_with_args(rand_key, nodes, connections, new_node_key):
|
||||
return mutate(rand_key, nodes, connections, new_node_key, input_idx, output_idx,
|
||||
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
|
||||
bias_replace_rate, response_default, response_mean, response_std,
|
||||
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
||||
@@ -77,19 +81,30 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
||||
single_structure_mutate)
|
||||
|
||||
if not batch:
|
||||
return mutate_func
|
||||
rand_key_lower = jnp.zeros((2, ), dtype=jnp.uint32)
|
||||
nodes_lower = jnp.zeros((N, 5))
|
||||
connections_lower = jnp.zeros((2, N, N))
|
||||
new_node_key_lower = jnp.zeros((), dtype=jnp.int32)
|
||||
return jit(mutate_with_args).lower(rand_key_lower, nodes_lower, connections_lower, new_node_key_lower).compile()
|
||||
else:
|
||||
batched_mutate_func = vmap(mutate_func, in_axes=(0, 0, 0, 0))
|
||||
pop_size = config.neat.population.pop_size
|
||||
rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
|
||||
nodes_lower = jnp.zeros((pop_size, N, 5))
|
||||
connections_lower = jnp.zeros((pop_size, 2, N, N))
|
||||
new_node_key_lower = jnp.zeros((pop_size, ), dtype=jnp.int32)
|
||||
batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower,
|
||||
connections_lower, new_node_key_lower).compile()
|
||||
|
||||
return batched_mutate_func
|
||||
|
||||
|
||||
@partial(jit, static_argnames=["single_structure_mutate"])
|
||||
# @partial(jit, static_argnames=["single_structure_mutate"])
|
||||
def mutate(rand_key: Array,
|
||||
nodes: Array,
|
||||
connections: Array,
|
||||
new_node_key: int,
|
||||
input_keys: Array,
|
||||
output_keys: Array,
|
||||
input_idx: Array,
|
||||
output_idx: Array,
|
||||
bias_default: float = 0,
|
||||
bias_mean: float = 0,
|
||||
bias_std: float = 1,
|
||||
@@ -120,8 +135,8 @@ def mutate(rand_key: Array,
|
||||
delete_connection_rate: float = 0.4,
|
||||
single_structure_mutate: bool = True):
|
||||
"""
|
||||
:param output_keys:
|
||||
:param input_keys:
|
||||
:param output_idx:
|
||||
:param input_idx:
|
||||
:param agg_default:
|
||||
:param act_default:
|
||||
:param response_default:
|
||||
@@ -166,10 +181,10 @@ def mutate(rand_key: Array,
|
||||
return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default)
|
||||
|
||||
def m_delete_node(rk, n, c):
|
||||
return mutate_delete_node(rk, n, c, input_keys, output_keys)
|
||||
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_add_connection(rk, n, c):
|
||||
return mutate_add_connection(rk, n, c, input_keys, output_keys)
|
||||
return mutate_add_connection(rk, n, c, input_idx, output_idx)
|
||||
|
||||
def m_delete_connection(rk, n, c):
|
||||
return mutate_delete_connection(rk, n, c)
|
||||
@@ -224,7 +239,7 @@ def mutate(rand_key: Array,
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_values(rand_key: Array,
|
||||
nodes: Array,
|
||||
connections: Array,
|
||||
@@ -305,7 +320,7 @@ def mutate_values(rand_key: Array,
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
|
||||
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
|
||||
"""
|
||||
@@ -338,7 +353,7 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
|
||||
return new_vals
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
|
||||
"""
|
||||
Mutate integer values (act, agg) of a given array.
|
||||
@@ -361,7 +376,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
|
||||
return new_vals
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
|
||||
default_bias: float = 0, default_response: float = 1,
|
||||
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
|
||||
@@ -408,7 +423,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
@@ -442,7 +457,7 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
@@ -481,7 +496,7 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
||||
"""
|
||||
Randomly delete a connection.
|
||||
@@ -504,7 +519,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
||||
# @partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
||||
def choice_node_key(rand_key: Array, nodes: Array,
|
||||
input_keys: Array, output_keys: Array,
|
||||
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
|
||||
@@ -533,7 +548,7 @@ def choice_node_key(rand_key: Array, nodes: Array,
|
||||
return key, idx
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
|
||||
"""
|
||||
Randomly choose a connection key from the given connections.
|
||||
@@ -561,6 +576,6 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
|
||||
return from_key, to_key, from_idx, to_idx
|
||||
|
||||
|
||||
@jit
|
||||
# @jit
|
||||
def rand(rand_key):
|
||||
return jax.random.uniform(rand_key, ())
|
||||
|
||||
@@ -28,10 +28,8 @@ class Pipeline:
|
||||
self.species_controller = SpeciesController(config)
|
||||
self.initialize_func = create_initialize_function(config)
|
||||
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
||||
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
|
||||
self.crossover_func = create_crossover_function(batch=True)
|
||||
self.o2o_distance = create_distance_function(self.config, type='o2o')
|
||||
self.o2m_distance = create_distance_function(self.config, type='o2m')
|
||||
|
||||
self.compile_functions()
|
||||
|
||||
self.generation = 0
|
||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections,
|
||||
@@ -142,6 +140,15 @@ class Pipeline:
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.N)
|
||||
|
||||
# update functions
|
||||
self.compile_functions()
|
||||
|
||||
def compile_functions(self):
|
||||
self.mutate_func = create_mutate_function(self.N, self.config, batch=True)
|
||||
self.crossover_func = create_crossover_function(self.N, self.config, batch=True)
|
||||
self.o2o_distance = create_distance_function(self.N, self.config, type='o2o')
|
||||
self.o2m_distance = create_distance_function(self.N, self.config, type='o2m')
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
|
||||
|
||||
@@ -5,6 +5,7 @@ import jax
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
class Species(object):
|
||||
|
||||
def __init__(self, key, generation):
|
||||
|
||||
@@ -3,6 +3,7 @@ import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax import random
|
||||
from jax import vmap, jit
|
||||
from functools import partial
|
||||
|
||||
from examples.time_utils import using_cprofile
|
||||
|
||||
@@ -16,28 +17,43 @@ def func(x, y):
|
||||
return x * y
|
||||
|
||||
|
||||
def func2(x, y, s):
|
||||
"""
|
||||
:param x: (100, )
|
||||
:param y: (100,
|
||||
:return:
|
||||
"""
|
||||
if s == '123':
|
||||
return 0
|
||||
else:
|
||||
return x * y
|
||||
|
||||
|
||||
@jit
|
||||
def func3(x, y):
|
||||
return func2(x, y, '123')
|
||||
|
||||
|
||||
# @using_cprofile
|
||||
def main():
|
||||
key = jax.random.PRNGKey(42)
|
||||
|
||||
x1, y1 = jax.random.normal(key, shape=(100,)), jax.random.normal(key, shape=(100,))
|
||||
x1, y1 = jax.random.normal(key, shape=(1000,)), jax.random.normal(key, shape=(1000,))
|
||||
|
||||
jit_func = jit(func)
|
||||
|
||||
z = jit_func(x1, y1)
|
||||
print(z)
|
||||
|
||||
jit_lower_func = jit(func).lower(x1, y1).compile()
|
||||
jit_lower_func = jit(func).lower(1, 2).compile()
|
||||
print(type(jit_lower_func))
|
||||
import pickle
|
||||
print(jit_lower_func.memory_analysis())
|
||||
|
||||
with open('jit_function.pkl', 'wb') as f:
|
||||
pickle.dump(jit_lower_func, f)
|
||||
jit_compiled_func2 = jit(func2, static_argnames=['s']).lower(x1, y1, '123').compile()
|
||||
print(jit_compiled_func2(x1, y1))
|
||||
|
||||
new_jit_lower_func = pickle.load(open('jit_function.pkl', 'rb'))
|
||||
# print(jit_compiled_func2(x1, y1))
|
||||
|
||||
print(jit_lower_func(x1, y1))
|
||||
print(new_jit_lower_func(x1, y1))
|
||||
f = func3.lower(x1, y1).compile()
|
||||
|
||||
print(f(x1, y1))
|
||||
|
||||
# print(jit_lower_func(x1, y1))
|
||||
|
||||
# x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,))
|
||||
# print(jit_lower_func(x2, y2))
|
||||
|
||||
@@ -23,8 +23,8 @@ def evaluate(forward_func: Callable) -> List[float]:
|
||||
return fitnesses.tolist() # returns a list
|
||||
|
||||
|
||||
@using_cprofile
|
||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
# @using_cprofile
|
||||
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
def main():
|
||||
config = Configer.load_config()
|
||||
pipeline = Pipeline(config, seed=11323)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"basic": {
|
||||
"num_inputs": 2,
|
||||
"num_outputs": 1,
|
||||
"init_maximum_nodes": 25,
|
||||
"init_maximum_nodes": 10,
|
||||
"expands_coe": 2
|
||||
},
|
||||
"neat": {
|
||||
|
||||
Reference in New Issue
Block a user