add input_transform and update_input_transform;
change the args for genome.forward. Origin: (state, inputs, transformed) New: (state, transformed, inputs)
This commit is contained in:
@@ -9,12 +9,12 @@ I_INF = np.iinfo(jnp.int32).max # infinite int
|
||||
|
||||
def unflatten_conns(nodes, conns):
|
||||
"""
|
||||
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index), which CL means
|
||||
transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns
|
||||
connection length, N means the number of nodes, C means the number of connections
|
||||
returns the un_flattened connections with shape (CL-2, N, N)
|
||||
returns the unflatten connection indices with shape (N, N)
|
||||
"""
|
||||
N = nodes.shape[0] # max_nodes
|
||||
CL = conns.shape[1] # connection length = (fix_attrs + custom_attrs)
|
||||
C = conns.shape[0] # max_conns
|
||||
node_keys = nodes[:, 0]
|
||||
i_keys, o_keys = conns[:, 0], conns[:, 1]
|
||||
|
||||
@@ -23,47 +23,25 @@ def unflatten_conns(nodes, conns):
|
||||
|
||||
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
|
||||
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
|
||||
unflatten = jnp.full((CL - 2, N, N), jnp.nan)
|
||||
|
||||
# Is interesting that jax use clip when attach data in array
|
||||
# however, it will do nothing set values in an array
|
||||
# put all attributes include enable in res
|
||||
unflatten = unflatten.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
|
||||
assert unflatten.shape == (CL - 2, N, N)
|
||||
# however, it will do nothing when setting values in an array
|
||||
# put the index of connections in the unflatten array
|
||||
unflatten = (
|
||||
jnp.full((N, N), I_INF, dtype=jnp.int32)
|
||||
.at[i_idxs, o_idxs]
|
||||
.set(jnp.arange(C, dtype=jnp.int32))
|
||||
)
|
||||
|
||||
return unflatten
|
||||
|
||||
|
||||
def flatten_conns(nodes, unflatten, C):
|
||||
"""
|
||||
the inverse function of unflatten_conns
|
||||
transform the unflatten conn (CL-2, N, N) to (C, CL)
|
||||
"""
|
||||
N = nodes.shape[0]
|
||||
CL = unflatten.shape[0] + 2
|
||||
node_keys = nodes[:, 0]
|
||||
|
||||
def extract_conn(i, j):
|
||||
return jnp.where(
|
||||
jnp.isnan(unflatten[0, i, j]),
|
||||
jnp.nan,
|
||||
jnp.concatenate(
|
||||
[jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]
|
||||
),
|
||||
)
|
||||
|
||||
x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
|
||||
conns = vmap(extract_conn)(x.flatten(), y.flatten())
|
||||
assert conns.shape == (N * N, CL)
|
||||
|
||||
# put nan to the tail of the conns
|
||||
sorted_idx = jnp.argsort(conns[:, 0])
|
||||
sorted_conn = conns[sorted_idx]
|
||||
|
||||
# truncate the conns to the number of connections
|
||||
conns = sorted_conn[:C]
|
||||
assert conns.shape == (C, CL)
|
||||
return conns
|
||||
def attach_with_inf(arr, idx):
|
||||
expand_size = arr.ndim - idx.ndim
|
||||
expand_idx = jnp.expand_dims(
|
||||
idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim))
|
||||
)
|
||||
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])
|
||||
|
||||
|
||||
def extract_node_attrs(node):
|
||||
|
||||
Reference in New Issue
Block a user