bug down! Here it can solve xor successfully!

This commit is contained in:
wls2002
2023-05-07 16:03:52 +08:00
parent d1f54022bd
commit a3b9bca866
12 changed files with 120 additions and 254 deletions

View File

@@ -403,13 +403,13 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
# add two new connections
weight = new_connections[0, from_idx, to_idx]
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
new_nodes, new_connections, weight=0, enabled=True)
new_nodes, new_connections, weight=1., enabled=True)
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
new_nodes, new_connections, weight=weight, enabled=True)
return new_nodes, new_connections
# if from_idx == I_INT, that means no connection exist, do nothing
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successful_add_node)
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node)
return nodes, connections
@@ -430,16 +430,20 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
# delete the node
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
def nothing():
return nodes, connections
# delete connections
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
def successful_delete_node():
# delete the node
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
# check node_key valid
nodes = jnp.where(jnp.isnan(node_key), nodes, aux_nodes) # if node_key is nan, do not delete the node
connections = jnp.where(jnp.isnan(node_key), connections, aux_connections)
# delete connections
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
return aux_nodes, aux_connections
nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
return nodes, connections
@@ -501,7 +505,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
def successfully_delete_connection():
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successfully_delete_connection)
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection)
return nodes, connections
@@ -544,16 +548,22 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
:param connection:
:return: from_key, to_key, from_idx, to_idx
"""
k1, k2 = jax.random.split(rand_key, num=2)
has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1)
from_idx = fetch_random(k1, has_connections_row)
col = connection[0, from_idx, :]
to_idx = fetch_random(k2, ~jnp.isnan(col))
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan)
to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan)
def nothing():
return jnp.nan, jnp.nan, I_INT, I_INT
def has_connection():
f_idx = fetch_random(k1, has_connections_row)
col = connection[0, f_idx, :]
t_idx = fetch_random(k2, ~jnp.isnan(col))
f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0]
return f_key, t_key, f_idx, t_idx
from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing)
return from_key, to_key, from_idx, to_idx