又搞到3点,还是没有找到问题在哪,不过已经排除了是forward的问题

This commit is contained in:
wls2002
2023-05-07 02:59:48 +08:00
parent 414b620dc8
commit d1f54022bd
16 changed files with 772 additions and 58 deletions

View File

@@ -1,4 +1,4 @@
from .genome import create_initialize_function, expand, expand_single, analysis
from .genome import create_initialize_function, expand, expand_single, analysis, pop_analysis
from .distance import distance
from .mutate import create_mutate_function
from .forward import create_forward_function

View File

@@ -69,7 +69,6 @@ agg_name2key = {
def agg(idx, z):
idx = np.asarray(idx, dtype=np.int32)
if np.all(z == 0.):
return 0
else:

View File

@@ -76,7 +76,6 @@ def forward_single(inputs: NDArray, N: int, input_idx: NDArray, output_idx: NDAr
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
ini_vals[i] = z
return ini_vals[output_idx]

View File

@@ -198,8 +198,13 @@ def analysis(nodes: NDArray, connections: NDArray, input_keys, output_keys) -> \
def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
res = []
total_nodes, total_connections = 0, 0
for nodes, connections in zip(pop_nodes, pop_connections):
res.append(analysis(nodes, connections, input_keys, output_keys))
nodes, connections = analysis(nodes, connections, input_keys, output_keys)
res.append((nodes, connections))
total_nodes += len(nodes)
total_connections += len(connections)
print(total_nodes - 200, total_connections)
return res

View File

@@ -9,6 +9,8 @@ from .utils import fetch_random, fetch_first, I_INT
from .genome import add_node, add_connection_by_idx, delete_node_by_idx, delete_connection_by_idx
from .graph import check_cycles
add_node_cnt, delete_node_cnt, add_connection_cnt, delete_connection_cnt = 0, 0, 0, 0
def create_mutate_function(config, input_keys, output_keys, batch: bool):
"""
@@ -79,11 +81,15 @@ def create_mutate_function(config, input_keys, output_keys, batch: bool):
return mutate_func
else:
def batch_mutate_func(pop_nodes, pop_connections, new_node_keys):
global add_node_cnt, delete_node_cnt, add_connection_cnt, delete_connection_cnt
add_node_cnt, delete_node_cnt, add_connection_cnt, delete_connection_cnt = 0, 0, 0, 0
res_nodes, res_connections = [], []
for nodes, connections, new_node_key in zip(pop_nodes, pop_connections, new_node_keys):
nodes, connections = mutate_func(nodes, connections, new_node_key)
res_nodes.append(nodes)
res_connections.append(connections)
# print(f"add_node_cnt: {add_node_cnt}, delete_node_cnt: {delete_node_cnt}, "
# f"add_connection_cnt: {add_connection_cnt}, delete_connection_cnt: {delete_connection_cnt}")
return np.stack(res_nodes, axis=0), np.stack(res_connections, axis=0)
return batch_mutate_func
@@ -161,6 +167,8 @@ def mutate(nodes: NDArray,
:return:
"""
global add_node_cnt, delete_node_cnt, add_connection_cnt, delete_connection_cnt
# mutate_structure
def nothing(n, c):
return n, c
@@ -200,18 +208,22 @@ def mutate(nodes: NDArray,
# mutate add node
if rand() < add_node_rate:
nodes, connections = m_add_node(nodes, connections)
add_node_cnt += 1
# mutate delete node
if rand() < delete_node_rate:
nodes, connections = m_delete_node(nodes, connections)
delete_node_cnt += 1
# mutate add connection
if rand() < add_connection_rate:
nodes, connections = m_add_connection(nodes, connections)
add_connection_cnt += 1
# mutate delete connection
if rand() < delete_connection_rate:
nodes, connections = m_delete_connection(nodes, connections)
delete_connection_cnt += 1
nodes, connections = mutate_values(nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
@@ -220,6 +232,8 @@ def mutate(nodes: NDArray,
weight_mutate_rate, weight_replace_rate, act_range, act_replace_rate, agg_range,
agg_replace_rate, enabled_reverse_rate)
# print(add_node_cnt, delete_node_cnt, add_connection_cnt, delete_connection_cnt)
return nodes, connections
@@ -321,9 +335,9 @@ def mutate_float_values(old_vals: NDArray, mean: float, std: float,
replace = np.random.normal(size=old_vals.shape) * std + mean
r = rand(*old_vals.shape)
new_vals = old_vals
new_vals = np.where(r < mutate_rate, new_vals + noise, new_vals)
new_vals = np.where(r <= mutate_rate, new_vals + noise, new_vals)
new_vals = np.where(
np.logical_and(mutate_rate < r, r < mutate_rate + replace_rate),
(mutate_rate < r) & (r <= mutate_rate + replace_rate),
replace,
new_vals
)
@@ -413,7 +427,7 @@ def mutate_delete_node(nodes: NDArray, connections: NDArray,
node_key, node_idx = choice_node_key(nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
if np.isnan(node_key):
if node_idx == I_INT:
return nodes, connections
# delete the node