Current Progress: After final design presentation
This commit is contained in:
@@ -32,9 +32,6 @@ def unflatten_connections(nodes, cons):
|
||||
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
||||
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
||||
|
||||
# (2, N, N), (2, N, N), (2, N, N)
|
||||
# res = jnp.where(res[1, :, :] == 0, jnp.nan, res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@@ -88,6 +85,7 @@ def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
a = jnp.array([1, 2, 3, 4, 5])
|
||||
|
||||
Reference in New Issue
Block a user