add "update_by_batch" in gene;

add flatten_conns as an inverse function for unflatten_conns;
add "test_flatten.ipynb" as test for them.
This commit is contained in:
wls2002
2024-05-30 19:44:52 +08:00
parent cd92f411dc
commit 5bd6e5c357
9 changed files with 481 additions and 11 deletions

View File

@@ -13,24 +13,55 @@ def unflatten_conns(nodes, conns):
connection length, N means the number of nodes, C means the number of connections
returns the un_flattened connections with shape (CL-2, N, N)
"""
N = nodes.shape[0]
CL = conns.shape[1]
N = nodes.shape[0] # max_nodes
CL = conns.shape[1] # connection length = (fix_attrs + custom_attrs)
node_keys = nodes[:, 0]
i_keys, o_keys = conns[:, 0], conns[:, 1]
def key_to_indices(key, keys):
return fetch_first(key == keys)
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
res = jnp.full((CL - 2, N, N), jnp.nan)
unflatten = jnp.full((CL - 2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array
# however, it will do nothing set values in an array
# put all attributes include enable in res
res = res.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
unflatten = unflatten.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
assert unflatten.shape == (CL - 2, N, N)
return res
return unflatten
def key_to_indices(key, keys):
return fetch_first(key == keys)
def flatten_conns(nodes, unflatten, C):
"""
the inverse function of unflatten_conns
transform the unflatten conn (CL-2, N, N) to (C, CL)
"""
N = nodes.shape[0]
CL = unflatten.shape[0] + 2
node_keys = nodes[:, 0]
def extract_conn(i, j):
return jnp.where(
jnp.isnan(unflatten[0, i, j]),
jnp.nan,
jnp.concatenate([jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]),
)
x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
conns = vmap(extract_conn)(x.flatten(), y.flatten())
assert conns.shape == (N * N, CL)
# put nan to the tail of the conns
sorted_idx = jnp.argsort(conns[:, 0])
sorted_conn = conns[sorted_idx]
# truncate the conns to the number of connections
conns = sorted_conn[:C]
assert conns.shape == (C, CL)
return conns
@jit