finish ask part of the algorithm;

use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
This commit is contained in:
wls2002
2023-06-25 00:26:52 +08:00
parent 86820db5a6
commit 0cb2f9473d
24 changed files with 485 additions and 1623 deletions

View File

@@ -1,32 +0,0 @@
from neat.genome.activations import *
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
act_name2key = {
'sigmoid': 0,
'tanh': 1,
'sin': 2,
'gauss': 3,
'relu': 4,
'elu': 5,
'lelu': 6,
'selu': 7,
'softplus': 8,
'identity': 9,
'clamped': 10,
'inv': 11,
'log': 12,
'exp': 13,
'abs': 14,
'hat': 15,
'square': 16,
'cube': 17,
}
def refactor_act(config):
config['activation_default'] = act_name2key[config['activation_default']]
config['activation_options'] = [
act_name2key[act_name] for act_name in config['activation_options']
]

View File

@@ -1,20 +0,0 @@
from neat.genome.aggregations import *
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
agg_name2key = {
'sum': 0,
'product': 1,
'max': 2,
'min': 3,
'maxabs': 4,
'median': 5,
'mean': 6,
}
def refactor_agg(config):
config['aggregation_default'] = agg_name2key[config['aggregation_default']]
config['aggregation_options'] = [
agg_name2key[act_name] for act_name in config['aggregation_options']
]

View File

@@ -4,8 +4,8 @@ import configparser
import numpy as np
from .activations import refactor_act
from .aggregations import refactor_agg
from neat.genome.activations import act_name2func
from neat.genome.aggregations import agg_name2func
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
jit_config_keys = [
@@ -20,12 +20,12 @@ jit_config_keys = [
"node_delete_prob",
"compatibility_threshold",
"bias_init_mean",
"bias_init_stdev",
"bias_init_std",
"bias_mutate_power",
"bias_mutate_rate",
"bias_replace_rate",
"response_init_mean",
"response_init_stdev",
"response_init_std",
"response_mutate_power",
"response_mutate_rate",
"response_replace_rate",
@@ -36,7 +36,7 @@ jit_config_keys = [
"aggregation_options",
"aggregation_replace_rate",
"weight_init_mean",
"weight_init_stdev",
"weight_init_std",
"weight_mutate_power",
"weight_mutate_rate",
"weight_replace_rate",
@@ -90,14 +90,26 @@ class Configer:
cls.__check_redundant_config(default_config, config)
cls.__complete_config(default_config, config)
refactor_act(config)
refactor_agg(config)
input_idx = np.arange(config['num_inputs'])
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
config['input_idx'] = input_idx
config['output_idx'] = output_idx
cls.refactor_activation(config)
cls.refactor_aggregation(config)
config['input_idx'] = np.arange(config['num_inputs'])
config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
return config
@classmethod
def refactor_activation(cls, config):
config['activation_default'] = 0
config['activation_options'] = np.arange(len(config['activation_option_names']))
config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
@classmethod
def refactor_aggregation(cls, config):
config['aggregation_default'] = 0
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]
@classmethod
def create_jit_config(cls, config):
jit_config = {k: config[k] for k in jit_config_keys}

View File

@@ -5,7 +5,8 @@ init_maximum_nodes = 20
init_maximum_connections = 20
init_maximum_species = 10
expands_coe = 2.0
forward_way = "pop_batch"
forward_way = "pop"
batch_size = 4
[population]
fitness_threshold = 100000
@@ -46,12 +47,12 @@ response_replace_rate = 0.0
[gene-activation]
activation_default = "sigmoid"
activation_options = ["sigmoid"]
activation_option_names = ["sigmoid"]
activation_replace_rate = 0.0
[gene-aggregation]
aggregation_default = "sum"
aggregation_options = ["sum"]
aggregation_option_names = ["sum"]
aggregation_replace_rate = 0.0
[gene-weight]