finish ask part of the algorithm;
use jax.lax.while_loop in graph algorithms and forward function; fix "enabled not care" bug in forward
This commit is contained in:
@@ -1,32 +0,0 @@
|
||||
from neat.genome.activations import *
|
||||
|
||||
ACT_TOTAL_LIST = [sigmoid_act, tanh_act, sin_act, gauss_act, relu_act, elu_act, lelu_act, selu_act, softplus_act,
|
||||
identity_act, clamped_act, inv_act, log_act, exp_act, abs_act, hat_act, square_act, cube_act]
|
||||
|
||||
act_name2key = {
|
||||
'sigmoid': 0,
|
||||
'tanh': 1,
|
||||
'sin': 2,
|
||||
'gauss': 3,
|
||||
'relu': 4,
|
||||
'elu': 5,
|
||||
'lelu': 6,
|
||||
'selu': 7,
|
||||
'softplus': 8,
|
||||
'identity': 9,
|
||||
'clamped': 10,
|
||||
'inv': 11,
|
||||
'log': 12,
|
||||
'exp': 13,
|
||||
'abs': 14,
|
||||
'hat': 15,
|
||||
'square': 16,
|
||||
'cube': 17,
|
||||
}
|
||||
|
||||
|
||||
def refactor_act(config):
|
||||
config['activation_default'] = act_name2key[config['activation_default']]
|
||||
config['activation_options'] = [
|
||||
act_name2key[act_name] for act_name in config['activation_options']
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
from neat.genome.aggregations import *
|
||||
|
||||
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
|
||||
|
||||
agg_name2key = {
|
||||
'sum': 0,
|
||||
'product': 1,
|
||||
'max': 2,
|
||||
'min': 3,
|
||||
'maxabs': 4,
|
||||
'median': 5,
|
||||
'mean': 6,
|
||||
}
|
||||
|
||||
|
||||
def refactor_agg(config):
|
||||
config['aggregation_default'] = agg_name2key[config['aggregation_default']]
|
||||
config['aggregation_options'] = [
|
||||
agg_name2key[act_name] for act_name in config['aggregation_options']
|
||||
]
|
||||
@@ -4,8 +4,8 @@ import configparser
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .activations import refactor_act
|
||||
from .aggregations import refactor_agg
|
||||
from neat.genome.activations import act_name2func
|
||||
from neat.genome.aggregations import agg_name2func
|
||||
|
||||
# Configuration used in jit-able functions. The change of values will not cause the re-compilation of JAX.
|
||||
jit_config_keys = [
|
||||
@@ -20,12 +20,12 @@ jit_config_keys = [
|
||||
"node_delete_prob",
|
||||
"compatibility_threshold",
|
||||
"bias_init_mean",
|
||||
"bias_init_stdev",
|
||||
"bias_init_std",
|
||||
"bias_mutate_power",
|
||||
"bias_mutate_rate",
|
||||
"bias_replace_rate",
|
||||
"response_init_mean",
|
||||
"response_init_stdev",
|
||||
"response_init_std",
|
||||
"response_mutate_power",
|
||||
"response_mutate_rate",
|
||||
"response_replace_rate",
|
||||
@@ -36,7 +36,7 @@ jit_config_keys = [
|
||||
"aggregation_options",
|
||||
"aggregation_replace_rate",
|
||||
"weight_init_mean",
|
||||
"weight_init_stdev",
|
||||
"weight_init_std",
|
||||
"weight_mutate_power",
|
||||
"weight_mutate_rate",
|
||||
"weight_replace_rate",
|
||||
@@ -90,14 +90,26 @@ class Configer:
|
||||
cls.__check_redundant_config(default_config, config)
|
||||
cls.__complete_config(default_config, config)
|
||||
|
||||
refactor_act(config)
|
||||
refactor_agg(config)
|
||||
input_idx = np.arange(config['num_inputs'])
|
||||
output_idx = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
|
||||
config['input_idx'] = input_idx
|
||||
config['output_idx'] = output_idx
|
||||
cls.refactor_activation(config)
|
||||
cls.refactor_aggregation(config)
|
||||
|
||||
config['input_idx'] = np.arange(config['num_inputs'])
|
||||
config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def refactor_activation(cls, config):
|
||||
config['activation_default'] = 0
|
||||
config['activation_options'] = np.arange(len(config['activation_option_names']))
|
||||
config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
|
||||
|
||||
@classmethod
|
||||
def refactor_aggregation(cls, config):
|
||||
config['aggregation_default'] = 0
|
||||
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
|
||||
config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]
|
||||
|
||||
@classmethod
|
||||
def create_jit_config(cls, config):
|
||||
jit_config = {k: config[k] for k in jit_config_keys}
|
||||
|
||||
@@ -5,7 +5,8 @@ init_maximum_nodes = 20
|
||||
init_maximum_connections = 20
|
||||
init_maximum_species = 10
|
||||
expands_coe = 2.0
|
||||
forward_way = "pop_batch"
|
||||
forward_way = "pop"
|
||||
batch_size = 4
|
||||
|
||||
[population]
|
||||
fitness_threshold = 100000
|
||||
@@ -46,12 +47,12 @@ response_replace_rate = 0.0
|
||||
|
||||
[gene-activation]
|
||||
activation_default = "sigmoid"
|
||||
activation_options = ["sigmoid"]
|
||||
activation_option_names = ["sigmoid"]
|
||||
activation_replace_rate = 0.0
|
||||
|
||||
[gene-aggregation]
|
||||
aggregation_default = "sum"
|
||||
aggregation_options = ["sum"]
|
||||
aggregation_option_names = ["sum"]
|
||||
aggregation_replace_rate = 0.0
|
||||
|
||||
[gene-weight]
|
||||
|
||||
Reference in New Issue
Block a user