create function_factory.py, use to manage functions
This commit is contained in:
@@ -3,3 +3,5 @@ from .distance import create_distance_function
|
||||
from .mutate import create_mutate_function
|
||||
from .forward import create_forward_function
|
||||
from .crossover import create_crossover_function
|
||||
from .activations import act_name2key
|
||||
from .aggregations import agg_name2key
|
||||
|
||||
@@ -33,9 +33,6 @@ def create_distance_function(N, config, type: str, debug: bool = False):
|
||||
else:
|
||||
return res_func
|
||||
|
||||
# return lambda nodes1, connections1, nodes2, connections2: \
|
||||
# distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
|
||||
|
||||
elif type == 'o2m':
|
||||
vmap_func = vmap(distance_with_args, in_axes=(None, None, 0, 0))
|
||||
pop_size = config.neat.population.pop_size
|
||||
|
||||
@@ -45,7 +45,8 @@ def create_initialize_function(config):
|
||||
|
||||
def initialize_genomes(pop_size: int,
|
||||
N: int,
|
||||
num_inputs: int, num_outputs: int,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
default_bias: float = 0.0,
|
||||
default_response: float = 1.0,
|
||||
default_act: int = 0,
|
||||
|
||||
@@ -113,13 +113,11 @@ def mutate(rand_key: Array,
|
||||
new_node_key: int,
|
||||
input_idx: Array,
|
||||
output_idx: Array,
|
||||
bias_default: float = 0,
|
||||
bias_mean: float = 0,
|
||||
bias_std: float = 1,
|
||||
bias_mutate_strength: float = 0.5,
|
||||
bias_mutate_rate: float = 0.7,
|
||||
bias_replace_rate: float = 0.1,
|
||||
response_default: float = 1,
|
||||
response_mean: float = 1.,
|
||||
response_std: float = 0.,
|
||||
response_mutate_strength: float = 0.,
|
||||
@@ -147,8 +145,6 @@ def mutate(rand_key: Array,
|
||||
:param input_idx:
|
||||
:param agg_default:
|
||||
:param act_default:
|
||||
:param response_default:
|
||||
:param bias_default:
|
||||
:param rand_key:
|
||||
:param nodes: (N, 5)
|
||||
:param connections: (2, N, N)
|
||||
@@ -186,7 +182,7 @@ def mutate(rand_key: Array,
|
||||
return n, c
|
||||
|
||||
def m_add_node(rk, n, c):
|
||||
return mutate_add_node(rk, new_node_key, n, c, bias_default, response_default, act_default, agg_default)
|
||||
return mutate_add_node(rk, new_node_key, n, c, bias_mean, response_mean, act_default, agg_default)
|
||||
|
||||
def m_delete_node(rk, n, c):
|
||||
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||
|
||||
Reference in New Issue
Block a user