add "update_by_batch" in gene;
add flatten_conns as an inverse function for unflatten_conns; add "test_flatten.ipynb" as test for them.
This commit is contained in:
@@ -13,24 +13,55 @@ def unflatten_conns(nodes, conns):
|
||||
connection length, N means the number of nodes, C means the number of connections
|
||||
returns the un_flattened connections with shape (CL-2, N, N)
|
||||
"""
|
||||
N = nodes.shape[0]
|
||||
CL = conns.shape[1]
|
||||
N = nodes.shape[0] # max_nodes
|
||||
CL = conns.shape[1] # connection length = (fix_attrs + custom_attrs)
|
||||
node_keys = nodes[:, 0]
|
||||
i_keys, o_keys = conns[:, 0], conns[:, 1]
|
||||
|
||||
def key_to_indices(key, keys):
|
||||
return fetch_first(key == keys)
|
||||
|
||||
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
|
||||
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
|
||||
res = jnp.full((CL - 2, N, N), jnp.nan)
|
||||
unflatten = jnp.full((CL - 2, N, N), jnp.nan)
|
||||
|
||||
# Is interesting that jax use clip when attach data in array
|
||||
# however, it will do nothing set values in an array
|
||||
# put all attributes include enable in res
|
||||
res = res.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
|
||||
unflatten = unflatten.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
|
||||
assert unflatten.shape == (CL - 2, N, N)
|
||||
|
||||
return res
|
||||
return unflatten
|
||||
|
||||
|
||||
def key_to_indices(key, keys):
|
||||
return fetch_first(key == keys)
|
||||
def flatten_conns(nodes, unflatten, C):
|
||||
"""
|
||||
the inverse function of unflatten_conns
|
||||
transform the unflatten conn (CL-2, N, N) to (C, CL)
|
||||
"""
|
||||
N = nodes.shape[0]
|
||||
CL = unflatten.shape[0] + 2
|
||||
node_keys = nodes[:, 0]
|
||||
|
||||
def extract_conn(i, j):
|
||||
return jnp.where(
|
||||
jnp.isnan(unflatten[0, i, j]),
|
||||
jnp.nan,
|
||||
jnp.concatenate([jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]),
|
||||
)
|
||||
|
||||
x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
|
||||
conns = vmap(extract_conn)(x.flatten(), y.flatten())
|
||||
assert conns.shape == (N * N, CL)
|
||||
|
||||
# put nan to the tail of the conns
|
||||
sorted_idx = jnp.argsort(conns[:, 0])
|
||||
sorted_conn = conns[sorted_idx]
|
||||
|
||||
# truncate the conns to the number of connections
|
||||
conns = sorted_conn[:C]
|
||||
assert conns.shape == (C, CL)
|
||||
return conns
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
Reference in New Issue
Block a user