use jit().lower.compile in create functions

This commit is contained in:
wls2002
2023-05-08 02:35:04 +08:00
parent 497d89fc69
commit d4a75b9394
9 changed files with 120 additions and 77 deletions

View File

@@ -22,7 +22,9 @@ from jax import numpy as jnp
from jax import jit
from jax import Array
from algorithms.neat.genome.utils import fetch_first, fetch_last
from .activations import act_name2key
from .aggregations import agg_name2key
from .utils import fetch_first
EMPTY_NODE = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
@@ -34,10 +36,8 @@ def create_initialize_function(config):
num_outputs = config.basic.num_outputs
default_bias = config.neat.gene.bias.init_mean
default_response = config.neat.gene.response.init_mean
# default_act = config.neat.gene.activation.default
# default_agg = config.neat.gene.aggregation.default
default_act = 0
default_agg = 0
default_act = act_name2key[config.neat.gene.activation.default]
default_agg = agg_name2key[config.neat.gene.aggregation.default]
default_weight = config.neat.gene.weight.init_mean
return partial(initialize_genomes, pop_size, N, num_inputs, num_outputs, default_bias, default_response,
default_act, default_agg, default_weight)