finish ask part of the algorithm;
use jax.lax.while_loop in graph algorithms and forward function; fix "enabled not care" bug in forward
This commit is contained in:
@@ -4,8 +4,8 @@ import configparser
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .activations import refactor_act
|
||||
from .aggregations import refactor_agg
|
||||
from neat.genome.activations import act_name2func
|
||||
from neat.genome.aggregations import agg_name2func
|
||||
|
||||
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||
jit_config_keys = [
|
||||
@@ -20,12 +20,12 @@ jit_config_keys = [
|
||||
"node_delete_prob",
|
||||
"compatibility_threshold",
|
||||
"bias_init_mean",
|
||||
"bias_init_stdev",
|
||||
"bias_init_std",
|
||||
"bias_mutate_power",
|
||||
"bias_mutate_rate",
|
||||
"bias_replace_rate",
|
||||
"response_init_mean",
|
||||
"response_init_stdev",
|
||||
"response_init_std",
|
||||
"response_mutate_power",
|
||||
"response_mutate_rate",
|
||||
"response_replace_rate",
|
||||
@@ -36,7 +36,7 @@ jit_config_keys = [
|
||||
"aggregation_options",
|
||||
"aggregation_replace_rate",
|
||||
"weight_init_mean",
|
||||
"weight_init_stdev",
|
||||
"weight_init_std",
|
||||
"weight_mutate_power",
|
||||
"weight_mutate_rate",
|
||||
"weight_replace_rate",
|
||||
@@ -90,14 +90,26 @@ class Configer:
|
||||
cls.__check_redundant_config(default_config, config)
|
||||
cls.__complete_config(default_config, config)
|
||||
|
||||
refactor_act(config)
|
||||
refactor_agg(config)
|
||||
input_idx = np.arange(config['num_inputs'])
|
||||
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
|
||||
config['input_idx'] = input_idx
|
||||
config['output_idx'] = output_idx
|
||||
cls.refactor_activation(config)
|
||||
cls.refactor_aggregation(config)
|
||||
|
||||
config['input_idx'] = np.arange(config['num_inputs'])
|
||||
config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def refactor_activation(cls, config):
|
||||
config['activation_default'] = 0
|
||||
config['activation_options'] = np.arange(len(config['activation_option_names']))
|
||||
config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
|
||||
|
||||
@classmethod
|
||||
def refactor_aggregation(cls, config):
|
||||
config['aggregation_default'] = 0
|
||||
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
|
||||
config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]
|
||||
|
||||
@classmethod
|
||||
def create_jit_config(cls, config):
|
||||
jit_config = {k: config[k] for k in jit_config_keys}
|
||||
|
||||
Reference in New Issue
Block a user