modify the behavior for mutate_add_node and mutate_add_conn. Currently, this two mutation will just change the structure of the network, but not influence the output for the network.

This commit is contained in:
wls2002
2024-06-01 20:42:42 +08:00
parent 4ad9f0a85a
commit e65200a94e
14 changed files with 281 additions and 204 deletions

View File

@@ -47,7 +47,9 @@ def flatten_conns(nodes, unflatten, C):
return jnp.where(
jnp.isnan(unflatten[0, i, j]),
jnp.nan,
jnp.concatenate([jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]),
jnp.concatenate(
[jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]
),
)
x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
@@ -64,6 +66,40 @@ def flatten_conns(nodes, unflatten, C):
return conns
def extract_node_attrs(node):
"""
node: Array(NL, )
extract the attributes of a node
"""
return node[1:] # 0 is for idx
def set_node_attrs(node, attrs):
"""
node: Array(NL, )
attrs: Array(NL-1, )
set the attributes of a node
"""
return node.at[1:].set(attrs) # 0 is for idx
def extract_conn_attrs(conn):
"""
conn: Array(CL, )
extract the attributes of a connection
"""
return conn[2:] # 0, 1 is for in-idx and out-idx
def set_conn_attrs(conn, attrs):
"""
conn: Array(CL, )
attrs: Array(CL-2, )
set the attributes of a connection
"""
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
@jit
def fetch_first(mask, default=I_INF) -> Array:
"""