又搞到3点,还是没有找到问题在哪,不过已经排除了是forward的问题

This commit is contained in:
wls2002
2023-05-07 02:59:48 +08:00
parent 414b620dc8
commit d1f54022bd
16 changed files with 772 additions and 58 deletions

View File

@@ -9,6 +9,8 @@ from .utils import fetch_random, fetch_first, I_INT
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx
from .graph import check_cycles
add_node_cnt, delete_node_cnt, add_connection_cnt, delete_connection_cnt = 0, 0, 0, 0
def create_mutate_function(config, input_keys, output_keys, batch: bool):
"""
@@ -79,11 +81,15 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
return mutate_func
else:
def batch_mutate_func(pop_nodes, pop_connections, new_node_keys):
global add_node_cnt, delete_node_cnt, add_connection_cnt, delete_connection_cnt
add_node_cnt, delete_node_cnt, add_connection_cnt, delete_connection_cnt = 0, 0, 0, 0
res_nodes, res_connections = [], []
for nodes, connections, new_node_key in zip(pop_nodes, pop_connections, new_node_keys):
nodes, connections = mutate_func(nodes, connections, new_node_key)
res_nodes.append(nodes)
res_connections.append(connections)
# print(f"add_node_cnt: {add_node_cnt}, delete_node_cnt: {delete_node_cnt}, "
# f"add_connection_cnt: {add_connection_cnt}, delete_connection_cnt: {delete_connection_cnt}")
return np.stack(res_nodes, axis=0), np.stack(res_connections, axis=0)
return batch_mutate_func
@@ -161,6 +167,8 @@ def mutate(nodes: NDArray,
:return:
"""
global add_node_cnt, delete_node_cnt, add_connection_cnt, delete_connection_cnt
# mutate_structure
def nothing(n, c):
return n, c
@@ -200,18 +208,22 @@ def mutate(nodes: NDArray,
# mutate add node
if rand() < add_node_rate:
nodes, connections = m_add_node(nodes, connections)
add_node_cnt += 1
# mutate delete node
if rand() < delete_node_rate:
nodes, connections = m_delete_node(nodes, connections)
delete_node_cnt += 1
# mutate add connection
if rand() < add_connection_rate:
nodes, connections = m_add_connection(nodes, connections)
add_connection_cnt += 1
# mutate delete connection
if rand() < delete_connection_rate:
nodes, connections = m_delete_connection(nodes, connections)
delete_connection_cnt += 1
nodes, connections = mutate_values(nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
@@ -220,6 +232,8 @@ def mutate(nodes: NDArray,
weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range,
agg_replace_rate, enabled_reverse_rate)
# print(add_node_cnt, delete_node_cnt, add_connection_cnt, delete_connection_cnt)
return nodes, connections
@@ -321,9 +335,9 @@ def mutate_float_values(old_vals: NDArray, mean: float, std: float,
replace = np.random.normal(size=old_vals.shape) * std + mean
r = rand(*old_vals.shape)
new_vals = old_vals
new_vals = np.where(r < mutate_rate, new_vals + noise, new_vals)
new_vals = np.where(r <= mutate_rate, new_vals + noise, new_vals)
new_vals = np.where(
np.logical_and(mutate_rate < r, r < mutate_rate + replace_rate),
(mutate_rate < r) & (r <= mutate_rate + replace_rate),
replace,
new_vals
)
@@ -413,7 +427,7 @@ def mutate_delete_node(nodes: NDArray, connections: NDArray,
node_key, node_idx = choice_node_key(nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
if np.isnan(node_key):
if node_idx == I_INT:
return nodes, connections
# delete the node