debug-branch

This commit is contained in:
wls2002
2023-05-06 21:04:28 +08:00
parent 14fed83193
commit a85e6eba78
20 changed files with 1719 additions and 233 deletions

View File

@@ -386,18 +386,30 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
# randomly choose a connection
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
# disable the connection
connections = connections.at[1, from_idx, to_idx].set(False)
def nothing():
return nodes, connections
# add a new node
nodes, connections = add_node(new_node_key, nodes, connections,
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
new_idx = fetch_first(nodes[:, 0] == new_node_key)
def successful_add_node():
# disable the connection
new_nodes, new_connections = nodes, connections
new_connections = new_connections.at[1, from_idx, to_idx].set(False)
# add two new connections
weight = connections[0, from_idx, to_idx]
nodes, connections = add_connection_by_idx(from_idx, new_idx, nodes, connections, weight=0, enabled=True)
nodes, connections = add_connection_by_idx(new_idx, to_idx, nodes, connections, weight=weight, enabled=True)
# add a new node
new_nodes, new_connections = \
add_node(new_node_key, new_nodes, new_connections,
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
new_idx = fetch_first(new_nodes[:, 0] == new_node_key)
# add two new connections
weight = new_connections[0, from_idx, to_idx]
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
new_nodes, new_connections, weight=0, enabled=True)
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
new_nodes, new_connections, weight=weight, enabled=True)
return new_nodes, new_connections
# if from_idx == I_INT, that means no connection exist, do nothing
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successful_add_node)
return nodes, connections
@@ -482,7 +494,15 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
"""
# randomly choose a connection
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
nodes, connections = delete_connection_by_idx(from_idx, to_idx, nodes, connections)
def nothing():
return nodes, connections
def successfully_delete_connection():
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successfully_delete_connection)
return nodes, connections
@@ -530,6 +550,10 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
col = connection[0, from_idx, :]
to_idx = fetch_random(k2, ~jnp.isnan(col))
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan)
to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan)
return from_key, to_key, from_idx, to_idx