modify the behavior for mutate_add_node and mutate_add_conn. Currently, this two mutation will just change the structure of the network, but not influence the output for the network.
This commit is contained in:
@@ -47,7 +47,9 @@ def flatten_conns(nodes, unflatten, C):
|
||||
return jnp.where(
|
||||
jnp.isnan(unflatten[0, i, j]),
|
||||
jnp.nan,
|
||||
jnp.concatenate([jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]),
|
||||
jnp.concatenate(
|
||||
[jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]
|
||||
),
|
||||
)
|
||||
|
||||
x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
|
||||
@@ -64,6 +66,40 @@ def flatten_conns(nodes, unflatten, C):
|
||||
return conns
|
||||
|
||||
|
||||
def extract_node_attrs(node):
|
||||
"""
|
||||
node: Array(NL, )
|
||||
extract the attributes of a node
|
||||
"""
|
||||
return node[1:] # 0 is for idx
|
||||
|
||||
|
||||
def set_node_attrs(node, attrs):
|
||||
"""
|
||||
node: Array(NL, )
|
||||
attrs: Array(NL-1, )
|
||||
set the attributes of a node
|
||||
"""
|
||||
return node.at[1:].set(attrs) # 0 is for idx
|
||||
|
||||
|
||||
def extract_conn_attrs(conn):
|
||||
"""
|
||||
conn: Array(CL, )
|
||||
extract the attributes of a connection
|
||||
"""
|
||||
return conn[2:] # 0, 1 is for in-idx and out-idx
|
||||
|
||||
|
||||
def set_conn_attrs(conn, attrs):
|
||||
"""
|
||||
conn: Array(CL, )
|
||||
attrs: Array(CL-2, )
|
||||
set the attributes of a connection
|
||||
"""
|
||||
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_first(mask, default=I_INF) -> Array:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user