debug-branch
This commit is contained in:
@@ -386,18 +386,30 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
||||
# randomly choose a connection
|
||||
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
|
||||
|
||||
# disable the connection
|
||||
connections = connections.at[1, from_idx, to_idx].set(False)
|
||||
def nothing():
|
||||
return nodes, connections
|
||||
|
||||
# add a new node
|
||||
nodes, connections = add_node(new_node_key, nodes, connections,
|
||||
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
|
||||
new_idx = fetch_first(nodes[:, 0] == new_node_key)
|
||||
def successful_add_node():
|
||||
# disable the connection
|
||||
new_nodes, new_connections = nodes, connections
|
||||
new_connections = new_connections.at[1, from_idx, to_idx].set(False)
|
||||
|
||||
# add two new connections
|
||||
weight = connections[0, from_idx, to_idx]
|
||||
nodes, connections = add_connection_by_idx(from_idx, new_idx, nodes, connections, weight=0, enabled=True)
|
||||
nodes, connections = add_connection_by_idx(new_idx, to_idx, nodes, connections, weight=weight, enabled=True)
|
||||
# add a new node
|
||||
new_nodes, new_connections = \
|
||||
add_node(new_node_key, new_nodes, new_connections,
|
||||
bias=default_bias, response=default_response, act=default_act, agg=default_agg)
|
||||
new_idx = fetch_first(new_nodes[:, 0] == new_node_key)
|
||||
|
||||
# add two new connections
|
||||
weight = new_connections[0, from_idx, to_idx]
|
||||
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
|
||||
new_nodes, new_connections, weight=0, enabled=True)
|
||||
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
|
||||
new_nodes, new_connections, weight=weight, enabled=True)
|
||||
return new_nodes, new_connections
|
||||
|
||||
# if from_idx == I_INT, that means no connection exist, do nothing
|
||||
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successful_add_node)
|
||||
|
||||
return nodes, connections
|
||||
|
||||
@@ -482,7 +494,15 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
||||
"""
|
||||
# randomly choose a connection
|
||||
from_key, to_key, from_idx, to_idx = choice_connection_key(rand_key, nodes, connections)
|
||||
nodes, connections = delete_connection_by_idx(from_idx, to_idx, nodes, connections)
|
||||
|
||||
def nothing():
|
||||
return nodes, connections
|
||||
|
||||
def successfully_delete_connection():
|
||||
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
|
||||
|
||||
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successfully_delete_connection)
|
||||
|
||||
return nodes, connections
|
||||
|
||||
|
||||
@@ -530,6 +550,10 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
|
||||
col = connection[0, from_idx, :]
|
||||
to_idx = fetch_random(k2, ~jnp.isnan(col))
|
||||
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
|
||||
|
||||
from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan)
|
||||
to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan)
|
||||
|
||||
return from_key, to_key, from_idx, to_idx
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user