use jit().lower.compile in create functions
This commit is contained in:
@@ -22,7 +22,9 @@ from jax import numpy as jnp
|
||||
from jax import jit
|
||||
from jax import Array
|
||||
|
||||
from algorithms.neat.genome.utils import fetch_first, fetch_last
|
||||
from .activations import act_name2key
|
||||
from .aggregations import agg_name2key
|
||||
from .utils import fetch_first
|
||||
|
||||
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
|
||||
|
||||
@@ -34,10 +36,8 @@ def create_initialize_function(config):
|
||||
num_outputs = config.basic.num_outputs
|
||||
default_bias = config.neat.gene.bias.init_mean
|
||||
default_response = config.neat.gene.response.init_mean
|
||||
# default_act = config.neat.gene.activation.default
|
||||
# default_agg = config.neat.gene.aggregation.default
|
||||
default_act = 0
|
||||
default_agg = 0
|
||||
default_act = act_name2key[config.neat.gene.activation.default]
|
||||
default_agg = agg_name2key[config.neat.gene.aggregation.default]
|
||||
default_weight = config.neat.gene.weight.init_mean
|
||||
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
|
||||
default_act, default_agg, default_weight)
|
||||
|
||||
Reference in New Issue
Block a user