bug down! Here it can solve xor successfully!
This commit is contained in:
@@ -403,13 +403,13 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
||||
# add two new connections
|
||||
weight = new_connections[0, from_idx, to_idx]
|
||||
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
|
||||
new_nodes, new_connections, weight=0, enabled=True)
|
||||
new_nodes, new_connections, weight=1., enabled=True)
|
||||
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
|
||||
new_nodes, new_connections, weight=weight, enabled=True)
|
||||
return new_nodes, new_connections
|
||||
|
||||
# if from_idx == I_INT, that means no connection exist, do nothing
|
||||
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successful_add_node)
|
||||
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node)
|
||||
|
||||
return nodes, connections
|
||||
|
||||
@@ -430,16 +430,20 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
||||
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
|
||||
allow_input_keys=False, allow_output_keys=False)
|
||||
|
||||
# delete the node
|
||||
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
|
||||
def nothing():
|
||||
return nodes, connections
|
||||
|
||||
# delete connections
|
||||
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
|
||||
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
|
||||
|
||||
# check node_key valid
|
||||
nodes = jnp.where(jnp.isnan(node_key), nodes, aux_nodes) # if node_key is nan, do not delete the node
|
||||
connections = jnp.where(jnp.isnan(node_key), connections, aux_connections)
|
||||
# delete connections
|
||||
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
|
||||
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
|
||||
|
||||
return aux_nodes, aux_connections
|
||||
|
||||
nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
|
||||
|
||||
return nodes, connections
|
||||
|
||||
@@ -501,7 +505,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
||||
def successfully_delete_connection():
|
||||
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
|
||||
|
||||
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successfully_delete_connection)
|
||||
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection)
|
||||
|
||||
return nodes, connections
|
||||
|
||||
@@ -544,16 +548,22 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
|
||||
:param connection:
|
||||
:return: from_key, to_key, from_idx, to_idx
|
||||
"""
|
||||
|
||||
k1, k2 = jax.random.split(rand_key, num=2)
|
||||
|
||||
has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1)
|
||||
from_idx = fetch_random(k1, has_connections_row)
|
||||
col = connection[0, from_idx, :]
|
||||
to_idx = fetch_random(k2, ~jnp.isnan(col))
|
||||
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
|
||||
|
||||
from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan)
|
||||
to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan)
|
||||
def nothing():
|
||||
return jnp.nan, jnp.nan, I_INT, I_INT
|
||||
|
||||
def has_connection():
|
||||
f_idx = fetch_random(k1, has_connections_row)
|
||||
col = connection[0, f_idx, :]
|
||||
t_idx = fetch_random(k2, ~jnp.isnan(col))
|
||||
f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0]
|
||||
return f_key, t_key, f_idx, t_idx
|
||||
|
||||
from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing)
|
||||
return from_key, to_key, from_idx, to_idx
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user