complete normal neat algorithm
This commit is contained in:
@@ -10,24 +10,25 @@ EMPTY_CON = np.full((1, 4), jnp.nan)
|
||||
|
||||
|
||||
@jit
|
||||
def unflatten_connections(nodes: Array, cons: Array):
|
||||
def unflatten_connections(nodes: Array, conns: Array):
|
||||
"""
|
||||
transform the (C, 4) connections to (2, N, N)
|
||||
:param nodes: (N, 5)
|
||||
:param cons: (C, 4)
|
||||
transform the (C, CL) connections to (CL-2, N, N)
|
||||
:param nodes: (N, NL)
|
||||
:param cons: (C, CL)
|
||||
:return:
|
||||
"""
|
||||
N = nodes.shape[0]
|
||||
CL = conns.shape[1]
|
||||
node_keys = nodes[:, 0]
|
||||
i_keys, o_keys = cons[:, 0], cons[:, 1]
|
||||
i_keys, o_keys = conns[:, 0], conns[:, 1]
|
||||
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
|
||||
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
|
||||
res = jnp.full((2, N, N), jnp.nan)
|
||||
res = jnp.full((CL - 2, N, N), jnp.nan)
|
||||
|
||||
# Is interesting that jax use clip when attach data in array
|
||||
# however, it will do nothing set values in an array
|
||||
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
||||
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
||||
# put all attributes include enable in res
|
||||
res = res.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
|
||||
|
||||
return res
|
||||
|
||||
@@ -68,4 +69,4 @@ def rank_elements(array, reverse=False):
|
||||
"""
|
||||
if not reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
|
||||
Reference in New Issue
Block a user