finish ask part of the algorithm;

use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
This commit is contained in:
wls2002
2023-06-25 00:26:52 +08:00
parent 86820db5a6
commit 0cb2f9473d
24 changed files with 485 additions and 1623 deletions

View File

@@ -4,8 +4,8 @@ import configparser
import numpy as np
from .activations import refactor_act
from .aggregations import refactor_agg
from neat.genome.activations import act_name2func
from neat.genome.aggregations import agg_name2func
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
jit_config_keys = [
@@ -20,12 +20,12 @@ jit_config_keys = [
"node_delete_prob",
"compatibility_threshold",
"bias_init_mean",
"bias_init_stdev",
"bias_init_std",
"bias_mutate_power",
"bias_mutate_rate",
"bias_replace_rate",
"response_init_mean",
"response_init_stdev",
"response_init_std",
"response_mutate_power",
"response_mutate_rate",
"response_replace_rate",
@@ -36,7 +36,7 @@ jit_config_keys = [
"aggregation_options",
"aggregation_replace_rate",
"weight_init_mean",
"weight_init_stdev",
"weight_init_std",
"weight_mutate_power",
"weight_mutate_rate",
"weight_replace_rate",
@@ -90,14 +90,26 @@ class Configer:
cls.__check_redundant_config(default_config, config)
cls.__complete_config(default_config, config)
refactor_act(config)
refactor_agg(config)
input_idx = np.arange(config['num_inputs'])
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
config['input_idx'] = input_idx
config['output_idx'] = output_idx
cls.refactor_activation(config)
cls.refactor_aggregation(config)
config['input_idx'] = np.arange(config['num_inputs'])
config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
return config
@classmethod
def refactor_activation(cls, config):
config['activation_default'] = 0
config['activation_options'] = np.arange(len(config['activation_option_names']))
config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
@classmethod
def refactor_aggregation(cls, config):
config['aggregation_default'] = 0
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]
@classmethod
def create_jit_config(cls, config):
jit_config = {k: config[k] for k in jit_config_keys}