This commit is contained in:
wls2002
2024-06-20 16:32:52 +08:00
parent 9f72813c35
commit 075460f896
17 changed files with 224 additions and 140 deletions

View File

@@ -36,6 +36,7 @@ def unflatten_conns(nodes, conns):
return unflatten
# TODO: strange implementation
def attach_with_inf(arr, idx):
expand_size = arr.ndim - idx.ndim
expand_idx = jnp.expand_dims(
@@ -199,3 +200,14 @@ def delete_conn_by_pos(conns, pos):
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)
def hash_array(arr: Array):
arr = jax.lax.bitcast_convert_type(arr, jnp.uint32)
def update(i, hash_val):
return hash_val ^ (
arr[i] + jnp.uint32(0x9E3779B9) + (hash_val << 6) + (hash_val >> 2)
)
return jax.lax.fori_loop(0, arr.size, update, jnp.uint32(0))