fix bugs
This commit is contained in:
@@ -36,6 +36,7 @@ def unflatten_conns(nodes, conns):
|
||||
return unflatten
|
||||
|
||||
|
||||
# TODO: strange implementation
|
||||
def attach_with_inf(arr, idx):
|
||||
expand_size = arr.ndim - idx.ndim
|
||||
expand_idx = jnp.expand_dims(
|
||||
@@ -199,3 +200,14 @@ def delete_conn_by_pos(conns, pos):
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def hash_array(arr: Array):
|
||||
arr = jax.lax.bitcast_convert_type(arr, jnp.uint32)
|
||||
|
||||
def update(i, hash_val):
|
||||
return hash_val ^ (
|
||||
arr[i] + jnp.uint32(0x9E3779B9) + (hash_val << 6) + (hash_val >> 2)
|
||||
)
|
||||
|
||||
return jax.lax.fori_loop(0, arr.size, update, jnp.uint32(0))
|
||||
|
||||
Reference in New Issue
Block a user