complete normal neat algorithm

This commit is contained in:
wls2002
2023-07-18 23:55:36 +08:00
parent 40cf0b6fbe
commit 0a2a9fd1be
26 changed files with 880 additions and 251 deletions

View File

@@ -10,24 +10,25 @@ EMPTY_CON = np.full((1, 4), jnp.nan)
@jit
def unflatten_connections(nodes: Array, cons: Array):
def unflatten_connections(nodes: Array, conns: Array):
"""
transform the (C, 4) connections to (2, N, N)
:param nodes: (N, 5)
:param cons: (C, 4)
transform the (C, CL) connections to (CL-2, N, N)
:param nodes: (N, NL)
:param cons: (C, CL)
:return:
"""
N = nodes.shape[0]
CL = conns.shape[1]
node_keys = nodes[:, 0]
i_keys, o_keys = cons[:, 0], cons[:, 1]
i_keys, o_keys = conns[:, 0], conns[:, 1]
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
res = jnp.full((2, N, N), jnp.nan)
res = jnp.full((CL - 2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array
# however, it will do nothing set values in an array
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
# put all attributes include enable in res
res = res.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
return res
@@ -68,4 +69,4 @@ def rank_elements(array, reverse=False):
"""
if not reverse:
array = -array
return jnp.argsort(jnp.argsort(array))
return jnp.argsort(jnp.argsort(array))