use jit().lower.compile in create functions
This commit is contained in:
@@ -8,29 +8,27 @@ from jax import numpy as jnp
|
|||||||
from .utils import flatten_connections, unflatten_connections
|
from .utils import flatten_connections, unflatten_connections
|
||||||
|
|
||||||
|
|
||||||
def create_crossover_function(batch: bool):
|
def create_crossover_function(N, config, batch: bool):
|
||||||
if batch:
|
if batch:
|
||||||
return batch_crossover
|
pop_size = config.neat.population.pop_size
|
||||||
|
randkey_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
|
||||||
|
nodes1_lower = jnp.zeros((pop_size, N, 5))
|
||||||
|
connections1_lower = jnp.zeros((pop_size, 2, N, N))
|
||||||
|
nodes2_lower = jnp.zeros((pop_size, N, 5))
|
||||||
|
connections2_lower = jnp.zeros((pop_size, 2, N, N))
|
||||||
|
return jit(vmap(crossover)).lower(randkey_lower, nodes1_lower, connections1_lower,
|
||||||
|
nodes2_lower, connections2_lower).compile()
|
||||||
else:
|
else:
|
||||||
return crossover
|
randkey_lower = jnp.zeros((2,), dtype=jnp.uint32)
|
||||||
|
nodes1_lower = jnp.zeros((N, 5))
|
||||||
|
connections1_lower = jnp.zeros((2, N, N))
|
||||||
|
nodes2_lower = jnp.zeros((N, 5))
|
||||||
|
connections2_lower = jnp.zeros((2, N, N))
|
||||||
|
return jit(crossover).lower(randkey_lower, nodes1_lower, connections1_lower,
|
||||||
|
nodes2_lower, connections2_lower).compile()
|
||||||
|
|
||||||
|
|
||||||
@vmap
|
# @jit
|
||||||
def batch_crossover(randkeys: Array, batch_nodes1: Array, batch_connections1: Array, batch_nodes2: Array,
|
|
||||||
batch_connections2: Array) -> Tuple[Array, Array]:
|
|
||||||
"""
|
|
||||||
crossover a batch of genomes
|
|
||||||
:param randkeys: batches of random keys
|
|
||||||
:param batch_nodes1:
|
|
||||||
:param batch_connections1:
|
|
||||||
:param batch_nodes2:
|
|
||||||
:param batch_connections2:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return crossover(randkeys, batch_nodes1, batch_connections1, batch_nodes2, batch_connections2)
|
|
||||||
|
|
||||||
|
|
||||||
@jit
|
|
||||||
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array, connections2: Array) \
|
||||||
-> Tuple[Array, Array]:
|
-> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
@@ -61,7 +59,7 @@ def crossover(randkey: Array, nodes1: Array, connections1: Array, nodes2: Array,
|
|||||||
return new_nodes, new_cons
|
return new_nodes, new_cons
|
||||||
|
|
||||||
|
|
||||||
@partial(jit, static_argnames=['gene_type'])
|
# @partial(jit, static_argnames=['gene_type'])
|
||||||
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
||||||
"""
|
"""
|
||||||
make ar2 align with ar1.
|
make ar2 align with ar1.
|
||||||
@@ -88,7 +86,7 @@ def align_array(seq1: Array, seq2: Array, ar2: Array, gene_type: str) -> Array:
|
|||||||
return refactor_ar2
|
return refactor_ar2
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||||
"""
|
"""
|
||||||
crossover two genes
|
crossover two genes
|
||||||
|
|||||||
@@ -6,25 +6,31 @@ from numpy.typing import NDArray
|
|||||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||||
|
|
||||||
|
|
||||||
def create_distance_function(config, type: str):
|
def create_distance_function(N, config, type: str):
|
||||||
"""
|
"""
|
||||||
|
:param N:
|
||||||
:param config:
|
:param config:
|
||||||
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
|
:param type: {'o2o', 'o2m'}, for one-to-one or one-to-many distance calculation
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
disjoint_coe = config.neat.genome.compatibility_disjoint_coefficient
|
||||||
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
compatibility_coe = config.neat.genome.compatibility_weight_coefficient
|
||||||
|
|
||||||
|
def distance_with_args(nodes1, connections1, nodes2, connections2):
|
||||||
|
return distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||||
|
|
||||||
if type == 'o2o':
|
if type == 'o2o':
|
||||||
return lambda nodes1, connections1, nodes2, connections2: \
|
return lambda nodes1, connections1, nodes2, connections2: \
|
||||||
distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||||
|
|
||||||
# return lambda nodes1, connections1, nodes2, connections2: \
|
|
||||||
# distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
|
||||||
|
|
||||||
elif type == 'o2m':
|
elif type == 'o2m':
|
||||||
func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
|
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||||
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
|
pop_size = config.neat.population.pop_size
|
||||||
func(nodes1, connections1, batch_nodes2, batch_connections2, disjoint_coe, compatibility_coe)
|
nodes1_lower = jnp.zeros((N, 5))
|
||||||
|
connections1_lower = jnp.zeros((2, N, N))
|
||||||
|
nodes2_lower = jnp.zeros((pop_size, N, 5))
|
||||||
|
connections2_lower = jnp.zeros((pop_size, 2, N, N))
|
||||||
|
return jit(vmap_func).lower(nodes1_lower, connections1_lower, nodes2_lower, connections2_lower).compile()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
|
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ from jax import numpy as jnp
|
|||||||
from jax import jit
|
from jax import jit
|
||||||
from jax import Array
|
from jax import Array
|
||||||
|
|
||||||
from algorithms.neat.genome.utils import fetch_first, fetch_last
|
from .activations import act_name2key
|
||||||
|
from .aggregations import agg_name2key
|
||||||
|
from .utils import fetch_first
|
||||||
|
|
||||||
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
|
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
|
||||||
|
|
||||||
@@ -34,10 +36,8 @@ def create_initialize_function(config):
|
|||||||
num_outputs = config.basic.num_outputs
|
num_outputs = config.basic.num_outputs
|
||||||
default_bias = config.neat.gene.bias.init_mean
|
default_bias = config.neat.gene.bias.init_mean
|
||||||
default_response = config.neat.gene.response.init_mean
|
default_response = config.neat.gene.response.init_mean
|
||||||
# default_act = config.neat.gene.activation.default
|
default_act = act_name2key[config.neat.gene.activation.default]
|
||||||
# default_agg = config.neat.gene.aggregation.default
|
default_agg = agg_name2key[config.neat.gene.aggregation.default]
|
||||||
default_act = 0
|
|
||||||
default_agg = 0
|
|
||||||
default_weight = config.neat.gene.weight.init_mean
|
default_weight = config.neat.gene.weight.init_mean
|
||||||
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
|
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
|
||||||
default_act, default_agg, default_weight)
|
default_act, default_agg, default_weight)
|
||||||
|
|||||||
@@ -13,15 +13,19 @@ from .activations import act_name2key
|
|||||||
from .aggregations import agg_name2key
|
from .aggregations import agg_name2key
|
||||||
|
|
||||||
|
|
||||||
def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
def create_mutate_function(N, config, batch: bool):
|
||||||
"""
|
"""
|
||||||
create mutate function for different situations
|
create mutate function for different situations
|
||||||
:param output_keys:
|
:param N:
|
||||||
:param input_keys:
|
|
||||||
:param config:
|
:param config:
|
||||||
:param batch: mutate for population or not
|
:param batch: mutate for population or not
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
num_inputs = config.basic.num_inputs
|
||||||
|
num_outputs = config.basic.num_outputs
|
||||||
|
input_idx = np.arange(num_inputs)
|
||||||
|
output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
||||||
|
|
||||||
bias = config.neat.gene.bias
|
bias = config.neat.gene.bias
|
||||||
bias_default = bias.init_mean
|
bias_default = bias.init_mean
|
||||||
bias_mean = bias.init_mean
|
bias_mean = bias.init_mean
|
||||||
@@ -65,8 +69,8 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
|||||||
delete_connection_rate = genome.conn_delete_prob
|
delete_connection_rate = genome.conn_delete_prob
|
||||||
single_structure_mutate = genome.single_structural_mutation
|
single_structure_mutate = genome.single_structural_mutation
|
||||||
|
|
||||||
def mutate_func(rand_key, nodes, connections, new_node_key):
|
def mutate_with_args(rand_key, nodes, connections, new_node_key):
|
||||||
return mutate(rand_key, nodes, connections, new_node_key, input_keys, output_keys,
|
return mutate(rand_key, nodes, connections, new_node_key, input_idx, output_idx,
|
||||||
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
|
bias_default, bias_mean, bias_std, bias_mutate_strength, bias_mutate_rate,
|
||||||
bias_replace_rate, response_default, response_mean, response_std,
|
bias_replace_rate, response_default, response_mean, response_std,
|
||||||
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
||||||
@@ -77,19 +81,30 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
|
|||||||
single_structure_mutate)
|
single_structure_mutate)
|
||||||
|
|
||||||
if not batch:
|
if not batch:
|
||||||
return mutate_func
|
rand_key_lower = jnp.zeros((2, ), dtype=jnp.uint32)
|
||||||
|
nodes_lower = jnp.zeros((N, 5))
|
||||||
|
connections_lower = jnp.zeros((2, N, N))
|
||||||
|
new_node_key_lower = jnp.zeros((), dtype=jnp.int32)
|
||||||
|
return jit(mutate_with_args).lower(rand_key_lower, nodes_lower, connections_lower, new_node_key_lower).compile()
|
||||||
else:
|
else:
|
||||||
batched_mutate_func = vmap(mutate_func, in_axes=(0, 0, 0, 0))
|
pop_size = config.neat.population.pop_size
|
||||||
|
rand_key_lower = jnp.zeros((pop_size, 2), dtype=jnp.uint32)
|
||||||
|
nodes_lower = jnp.zeros((pop_size, N, 5))
|
||||||
|
connections_lower = jnp.zeros((pop_size, 2, N, N))
|
||||||
|
new_node_key_lower = jnp.zeros((pop_size, ), dtype=jnp.int32)
|
||||||
|
batched_mutate_func = jit(vmap(mutate_with_args)).lower(rand_key_lower, nodes_lower,
|
||||||
|
connections_lower, new_node_key_lower).compile()
|
||||||
|
|
||||||
return batched_mutate_func
|
return batched_mutate_func
|
||||||
|
|
||||||
|
|
||||||
@partial(jit, static_argnames=["single_structure_mutate"])
|
# @partial(jit, static_argnames=["single_structure_mutate"])
|
||||||
def mutate(rand_key: Array,
|
def mutate(rand_key: Array,
|
||||||
nodes: Array,
|
nodes: Array,
|
||||||
connections: Array,
|
connections: Array,
|
||||||
new_node_key: int,
|
new_node_key: int,
|
||||||
input_keys: Array,
|
input_idx: Array,
|
||||||
output_keys: Array,
|
output_idx: Array,
|
||||||
bias_default: float = 0,
|
bias_default: float = 0,
|
||||||
bias_mean: float = 0,
|
bias_mean: float = 0,
|
||||||
bias_std: float = 1,
|
bias_std: float = 1,
|
||||||
@@ -120,8 +135,8 @@ def mutate(rand_key: Array,
|
|||||||
delete_connection_rate: float = 0.4,
|
delete_connection_rate: float = 0.4,
|
||||||
single_structure_mutate: bool = True):
|
single_structure_mutate: bool = True):
|
||||||
"""
|
"""
|
||||||
:param output_keys:
|
:param output_idx:
|
||||||
:param input_keys:
|
:param input_idx:
|
||||||
:param agg_default:
|
:param agg_default:
|
||||||
:param act_default:
|
:param act_default:
|
||||||
:param response_default:
|
:param response_default:
|
||||||
@@ -166,10 +181,10 @@ def mutate(rand_key: Array,
|
|||||||
return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default)
|
return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default)
|
||||||
|
|
||||||
def m_delete_node(rk, n, c):
|
def m_delete_node(rk, n, c):
|
||||||
return mutate_delete_node(rk, n, c, input_keys, output_keys)
|
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||||
|
|
||||||
def m_add_connection(rk, n, c):
|
def m_add_connection(rk, n, c):
|
||||||
return mutate_add_connection(rk, n, c, input_keys, output_keys)
|
return mutate_add_connection(rk, n, c, input_idx, output_idx)
|
||||||
|
|
||||||
def m_delete_connection(rk, n, c):
|
def m_delete_connection(rk, n, c):
|
||||||
return mutate_delete_connection(rk, n, c)
|
return mutate_delete_connection(rk, n, c)
|
||||||
@@ -224,7 +239,7 @@ def mutate(rand_key: Array,
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def mutate_values(rand_key: Array,
|
def mutate_values(rand_key: Array,
|
||||||
nodes: Array,
|
nodes: Array,
|
||||||
connections: Array,
|
connections: Array,
|
||||||
@@ -305,7 +320,7 @@ def mutate_values(rand_key: Array,
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
|
def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: float,
|
||||||
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
|
mutate_strength: float, mutate_rate: float, replace_rate: float) -> Array:
|
||||||
"""
|
"""
|
||||||
@@ -338,7 +353,7 @@ def mutate_float_values(rand_key: Array, old_vals: Array, mean: float, std: floa
|
|||||||
return new_vals
|
return new_vals
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
|
def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace_rate: float) -> Array:
|
||||||
"""
|
"""
|
||||||
Mutate integer values (act, agg) of a given array.
|
Mutate integer values (act, agg) of a given array.
|
||||||
@@ -361,7 +376,7 @@ def mutate_int_values(rand_key: Array, old_vals: Array, val_list: Array, replace
|
|||||||
return new_vals
|
return new_vals
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
|
def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connections: Array,
|
||||||
default_bias: float = 0, default_response: float = 1,
|
default_bias: float = 0, default_response: float = 1,
|
||||||
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
|
default_act: int = 0, default_agg: int = 0) -> Tuple[Array, Array]:
|
||||||
@@ -408,7 +423,7 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
||||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
@@ -442,7 +457,7 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
||||||
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
input_keys: Array, output_keys: Array) -> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
@@ -481,7 +496,7 @@ def mutate_add_connection(rand_key: Array, nodes: Array, connections: Array,
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
||||||
"""
|
"""
|
||||||
Randomly delete a connection.
|
Randomly delete a connection.
|
||||||
@@ -504,7 +519,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
|||||||
return nodes, connections
|
return nodes, connections
|
||||||
|
|
||||||
|
|
||||||
@partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
# @partial(jit, static_argnames=('allow_input_keys', 'allow_output_keys'))
|
||||||
def choice_node_key(rand_key: Array, nodes: Array,
|
def choice_node_key(rand_key: Array, nodes: Array,
|
||||||
input_keys: Array, output_keys: Array,
|
input_keys: Array, output_keys: Array,
|
||||||
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
|
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
|
||||||
@@ -533,7 +548,7 @@ def choice_node_key(rand_key: Array, nodes: Array,
|
|||||||
return key, idx
|
return key, idx
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
|
def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> Tuple[Array, Array, Array, Array]:
|
||||||
"""
|
"""
|
||||||
Randomly choose a connection key from the given connections.
|
Randomly choose a connection key from the given connections.
|
||||||
@@ -561,6 +576,6 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
|
|||||||
return from_key, to_key, from_idx, to_idx
|
return from_key, to_key, from_idx, to_idx
|
||||||
|
|
||||||
|
|
||||||
@jit
|
# @jit
|
||||||
def rand(rand_key):
|
def rand(rand_key):
|
||||||
return jax.random.uniform(rand_key, ())
|
return jax.random.uniform(rand_key, ())
|
||||||
|
|||||||
@@ -28,10 +28,8 @@ class Pipeline:
|
|||||||
self.species_controller = SpeciesController(config)
|
self.species_controller = SpeciesController(config)
|
||||||
self.initialize_func = create_initialize_function(config)
|
self.initialize_func = create_initialize_function(config)
|
||||||
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx = self.initialize_func()
|
||||||
self.mutate_func = create_mutate_function(config, self.input_idx, self.output_idx, batch=True)
|
|
||||||
self.crossover_func = create_crossover_function(batch=True)
|
self.compile_functions()
|
||||||
self.o2o_distance = create_distance_function(self.config, type='o2o')
|
|
||||||
self.o2m_distance = create_distance_function(self.config, type='o2m')
|
|
||||||
|
|
||||||
self.generation = 0
|
self.generation = 0
|
||||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections,
|
self.species_controller.speciate(self.pop_nodes, self.pop_connections,
|
||||||
@@ -142,6 +140,15 @@ class Pipeline:
|
|||||||
for s in self.species_controller.species.values():
|
for s in self.species_controller.species.values():
|
||||||
s.representative = expand_single(*s.representative, self.N)
|
s.representative = expand_single(*s.representative, self.N)
|
||||||
|
|
||||||
|
# update functions
|
||||||
|
self.compile_functions()
|
||||||
|
|
||||||
|
def compile_functions(self):
|
||||||
|
self.mutate_func = create_mutate_function(self.N, self.config, batch=True)
|
||||||
|
self.crossover_func = create_crossover_function(self.N, self.config, batch=True)
|
||||||
|
self.o2o_distance = create_distance_function(self.N, self.config, type='o2o')
|
||||||
|
self.o2m_distance = create_distance_function(self.N, self.config, type='o2m')
|
||||||
|
|
||||||
def default_analysis(self, fitnesses):
|
def default_analysis(self, fitnesses):
|
||||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||||
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
|
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import jax
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
|
||||||
class Species(object):
|
class Species(object):
|
||||||
|
|
||||||
def __init__(self, key, generation):
|
def __init__(self, key, generation):
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import jax.numpy as jnp
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from jax import random
|
from jax import random
|
||||||
from jax import vmap, jit
|
from jax import vmap, jit
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from examples.time_utils import using_cprofile
|
from examples.time_utils import using_cprofile
|
||||||
|
|
||||||
@@ -16,28 +17,43 @@ def func(x, y):
|
|||||||
return x * y
|
return x * y
|
||||||
|
|
||||||
|
|
||||||
|
def func2(x, y, s):
|
||||||
|
"""
|
||||||
|
:param x: (100, )
|
||||||
|
:param y: (100,
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if s == '123':
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def func3(x, y):
|
||||||
|
return func2(x, y, '123')
|
||||||
|
|
||||||
|
|
||||||
# @using_cprofile
|
# @using_cprofile
|
||||||
def main():
|
def main():
|
||||||
key = jax.random.PRNGKey(42)
|
key = jax.random.PRNGKey(42)
|
||||||
|
|
||||||
x1, y1 = jax.random.normal(key, shape=(100,)), jax.random.normal(key, shape=(100,))
|
x1, y1 = jax.random.normal(key, shape=(1000,)), jax.random.normal(key, shape=(1000,))
|
||||||
|
|
||||||
jit_func = jit(func)
|
jit_lower_func = jit(func).lower(1, 2).compile()
|
||||||
|
|
||||||
z = jit_func(x1, y1)
|
|
||||||
print(z)
|
|
||||||
|
|
||||||
jit_lower_func = jit(func).lower(x1, y1).compile()
|
|
||||||
print(type(jit_lower_func))
|
print(type(jit_lower_func))
|
||||||
import pickle
|
print(jit_lower_func.memory_analysis())
|
||||||
|
|
||||||
with open('jit_function.pkl', 'wb') as f:
|
jit_compiled_func2 = jit(func2, static_argnames=['s']).lower(x1, y1, '123').compile()
|
||||||
pickle.dump(jit_lower_func, f)
|
print(jit_compiled_func2(x1, y1))
|
||||||
|
|
||||||
new_jit_lower_func = pickle.load(open('jit_function.pkl', 'rb'))
|
# print(jit_compiled_func2(x1, y1))
|
||||||
|
|
||||||
print(jit_lower_func(x1, y1))
|
f = func3.lower(x1, y1).compile()
|
||||||
print(new_jit_lower_func(x1, y1))
|
|
||||||
|
print(f(x1, y1))
|
||||||
|
|
||||||
|
# print(jit_lower_func(x1, y1))
|
||||||
|
|
||||||
# x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,))
|
# x2, y2 = jax.random.normal(key, shape=(200,)), jax.random.normal(key, shape=(200,))
|
||||||
# print(jit_lower_func(x2, y2))
|
# print(jit_lower_func(x2, y2))
|
||||||
|
|||||||
@@ -23,8 +23,8 @@ def evaluate(forward_func: Callable) -> List[float]:
|
|||||||
return fitnesses.tolist() # returns a list
|
return fitnesses.tolist() # returns a list
|
||||||
|
|
||||||
|
|
||||||
@using_cprofile
|
# @using_cprofile
|
||||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||||
def main():
|
def main():
|
||||||
config = Configer.load_config()
|
config = Configer.load_config()
|
||||||
pipeline = Pipeline(config, seed=11323)
|
pipeline = Pipeline(config, seed=11323)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
"basic": {
|
"basic": {
|
||||||
"num_inputs": 2,
|
"num_inputs": 2,
|
||||||
"num_outputs": 1,
|
"num_outputs": 1,
|
||||||
"init_maximum_nodes": 25,
|
"init_maximum_nodes": 10,
|
||||||
"expands_coe": 2
|
"expands_coe": 2
|
||||||
},
|
},
|
||||||
"neat": {
|
"neat": {
|
||||||
|
|||||||
Reference in New Issue
Block a user