finish ask part of the algorithm;

use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
This commit is contained in:
wls2002
2023-06-25 00:26:52 +08:00
parent 86820db5a6
commit 0cb2f9473d
24 changed files with 485 additions and 1623 deletions

View File

@@ -1,5 +1,4 @@
from functools import partial
from typing import Tuple
import jax
from jax import numpy as jnp, Array
@@ -11,20 +10,18 @@ EMPTY_CON = jnp.full((1, 4), jnp.nan)
@jit
def unflatten_connections(nodes, cons):
def unflatten_connections(nodes: Array, cons: Array):
"""
transform the (C, 4) connections to (2, N, N)
this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled
connections to nan, that means we dont consider such connection when forward;
:param cons:
:param nodes:
:param nodes: (N, 5)
:param cons: (C, 4)
:return:
"""
N = nodes.shape[0]
node_keys = nodes[:, 0]
i_keys, o_keys = cons[:, 0], cons[:, 1]
i_idxs = key_to_indices(i_keys, node_keys)
o_idxs = key_to_indices(o_keys, node_keys)
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
res = jnp.full((2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array
@@ -34,8 +31,6 @@ def unflatten_connections(nodes, cons):
return res
@partial(vmap, in_axes=(0, None))
def key_to_indices(key, keys):
return fetch_first(key == keys)
@@ -46,27 +41,12 @@ def fetch_first(mask, default=I_INT) -> Array:
fetch the first True index
:param mask: array of bool
:param default: the default value if no element satisfying the condition
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return I_INT
example:
>>> a = jnp.array([1, 2, 3, 4, 5])
>>> fetch_first(a > 3)
3
>>> fetch_first(a > 30)
I_INT
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value
"""
idx = jnp.argmax(mask)
return jnp.where(mask[idx], idx, default)
@jit
def fetch_last(mask, default=I_INT) -> Array:
"""
similar to fetch_first, but fetch the last True index
"""
reversed_idx = fetch_first(mask[::-1], default)
return jnp.where(reversed_idx == -1, -1, mask.shape[0] - reversed_idx - 1)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
"""
@@ -78,27 +58,8 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@jit
def argmin_with_mask(arr: Array, mask: Array) -> Array:
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx
if __name__ == '__main__':
a = jnp.array([1, 2, 3, 4, 5])
print(fetch_first(a > 3))
print(fetch_first(a > 30))
print(fetch_last(a > 3))
print(fetch_last(a > 30))
rand_key = jax.random.PRNGKey(0)
for t in [-1, 0, 1, 2, 3, 4, 5]:
for _ in range(10):
rand_key, _ = jax.random.split(rand_key)
print(jax.random.randint(rand_key, shape=(), minval=1, maxval=2))
print(t, fetch_random(rand_key, a > t))
return min_idx