add input_transform and update_input_transform;

change the args for genome.forward.
Origin: (state, inputs, transformed)
New: (state, transformed, inputs)
This commit is contained in:
wls2002
2024-06-03 10:53:15 +08:00
parent a07a3b1cb2
commit edfb0596e7
16 changed files with 185 additions and 221 deletions

View File

@@ -9,12 +9,12 @@ I_INF = np.iinfo(jnp.int32).max # infinite int
def unflatten_conns(nodes, conns):
"""
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index), which CL means
transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns
connection length, N means the number of nodes, C means the number of connections
returns the un_flattened connections with shape (CL-2, N, N)
returns the unflatten connection indices with shape (N, N)
"""
N = nodes.shape[0] # max_nodes
CL = conns.shape[1] # connection length = (fix_attrs + custom_attrs)
C = conns.shape[0] # max_conns
node_keys = nodes[:, 0]
i_keys, o_keys = conns[:, 0], conns[:, 1]
@@ -23,47 +23,25 @@ def unflatten_conns(nodes, conns):
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
unflatten = jnp.full((CL - 2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array
# however, it will do nothing set values in an array
# put all attributes include enable in res
unflatten = unflatten.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
assert unflatten.shape == (CL - 2, N, N)
# however, it will do nothing when setting values in an array
# put the index of connections in the unflatten array
unflatten = (
jnp.full((N, N), I_INF, dtype=jnp.int32)
.at[i_idxs, o_idxs]
.set(jnp.arange(C, dtype=jnp.int32))
)
return unflatten
def flatten_conns(nodes, unflatten, C):
"""
the inverse function of unflatten_conns
transform the unflatten conn (CL-2, N, N) to (C, CL)
"""
N = nodes.shape[0]
CL = unflatten.shape[0] + 2
node_keys = nodes[:, 0]
def extract_conn(i, j):
return jnp.where(
jnp.isnan(unflatten[0, i, j]),
jnp.nan,
jnp.concatenate(
[jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]
),
)
x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
conns = vmap(extract_conn)(x.flatten(), y.flatten())
assert conns.shape == (N * N, CL)
# put nan to the tail of the conns
sorted_idx = jnp.argsort(conns[:, 0])
sorted_conn = conns[sorted_idx]
# truncate the conns to the number of connections
conns = sorted_conn[:C]
assert conns.shape == (C, CL)
return conns
def attach_with_inf(arr, idx):
expand_size = arr.ndim - idx.ndim
expand_idx = jnp.expand_dims(
idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim))
)
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])
def extract_node_attrs(node):