use black format all files;
remove "return state" for functions which will be executed in vmap; recover randkey as args in mutation methods
This commit is contained in:
@@ -3,7 +3,6 @@ import jax.numpy as jnp
|
||||
|
||||
|
||||
class Act:
|
||||
|
||||
@staticmethod
|
||||
def sigmoid(z):
|
||||
z = jnp.clip(5 * z, -10, 10)
|
||||
@@ -36,11 +35,7 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def inv(z):
|
||||
z = jnp.where(
|
||||
z > 0,
|
||||
jnp.maximum(z, 1e-7),
|
||||
jnp.minimum(z, -1e-7)
|
||||
)
|
||||
z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
|
||||
return 1 / z
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -3,7 +3,6 @@ import jax.numpy as jnp
|
||||
|
||||
|
||||
class Agg:
|
||||
|
||||
@staticmethod
|
||||
def sum(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
@@ -63,5 +62,5 @@ def agg(idx, z, agg_funcs):
|
||||
return jax.lax.cond(
|
||||
jnp.all(jnp.isnan(z)),
|
||||
lambda: jnp.nan, # all inputs are nan
|
||||
lambda: jax.lax.switch(idx, agg_funcs, z) # otherwise
|
||||
lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ Only used in feed-forward networks.
|
||||
import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
|
||||
from .tools import fetch_first, I_INT
|
||||
from .tools import fetch_first, I_INF
|
||||
|
||||
|
||||
@jit
|
||||
@@ -17,16 +17,16 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
|
||||
"""
|
||||
|
||||
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
|
||||
res = jnp.full(in_degree.shape, I_INT)
|
||||
res = jnp.full(in_degree.shape, I_INF)
|
||||
|
||||
def cond_fun(carry):
|
||||
res_, idx_, in_degree_ = carry
|
||||
i = fetch_first(in_degree_ == 0.)
|
||||
return i != I_INT
|
||||
i = fetch_first(in_degree_ == 0.0)
|
||||
return i != I_INF
|
||||
|
||||
def body_func(carry):
|
||||
res_, idx_, in_degree_ = carry
|
||||
i = fetch_first(in_degree_ == 0.)
|
||||
i = fetch_first(in_degree_ == 0.0)
|
||||
|
||||
# add to res and flag it is already in it
|
||||
res_ = res_.at[idx_].set(i)
|
||||
@@ -65,4 +65,4 @@ def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
||||
return visited_, new_visited_
|
||||
|
||||
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
|
||||
return visited[from_idx]
|
||||
return visited[from_idx]
|
||||
|
||||
@@ -3,9 +3,8 @@ from jax.tree_util import register_pytree_node_class
|
||||
|
||||
@register_pytree_node_class
|
||||
class State:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__['state_dict'] = kwargs
|
||||
self.__dict__["state_dict"] = kwargs
|
||||
|
||||
def registered_keys(self):
|
||||
return self.state_dict.keys()
|
||||
|
||||
@@ -4,13 +4,14 @@ import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array, jit, vmap
|
||||
|
||||
I_INT = np.iinfo(jnp.int32).max # infinite int
|
||||
I_INF = np.iinfo(jnp.int32).max # infinite int
|
||||
|
||||
|
||||
def unflatten_conns(nodes, conns):
|
||||
"""
|
||||
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index)
|
||||
:return:
|
||||
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index), which CL means
|
||||
connection length, N means the number of nodes, C means the number of connections
|
||||
returns the un_flattened connections with shape (CL-2, N, N)
|
||||
"""
|
||||
N = nodes.shape[0]
|
||||
CL = conns.shape[1]
|
||||
@@ -33,7 +34,7 @@ def key_to_indices(key, keys):
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_first(mask, default=I_INT) -> Array:
|
||||
def fetch_first(mask, default=I_INF) -> Array:
|
||||
"""
|
||||
fetch the first True index
|
||||
:param mask: array of bool
|
||||
@@ -45,18 +46,18 @@ def fetch_first(mask, default=I_INT) -> Array:
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
def fetch_random(randkey, mask, default=I_INF) -> Array:
|
||||
"""
|
||||
similar to fetch_first, but fetch a random True index
|
||||
"""
|
||||
true_cnt = jnp.sum(mask)
|
||||
cumsum = jnp.cumsum(mask)
|
||||
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
|
||||
target = jax.random.randint(randkey, shape=(), minval=1, maxval=true_cnt + 1)
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['reverse'])
|
||||
@partial(jit, static_argnames=["reverse"])
|
||||
def rank_elements(array, reverse=False):
|
||||
"""
|
||||
rank the element in the array.
|
||||
@@ -68,8 +69,17 @@ def rank_elements(array, reverse=False):
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
|
||||
k1, k2, k3 = jax.random.split(key, num=3)
|
||||
def mutate_float(
|
||||
randkey, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate
|
||||
):
|
||||
"""
|
||||
mutate a float value
|
||||
uniformly pick r from [0, 1]
|
||||
r in [0, mutate_rate) -> add noise
|
||||
r in [mutate_rate, mutate_rate + replace_rate) -> create a new value to replace the original value
|
||||
otherwise -> keep the original value
|
||||
"""
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
noise = jax.random.normal(k1, ()) * mutate_power
|
||||
replace = jax.random.normal(k2, ()) * init_std + init_mean
|
||||
r = jax.random.uniform(k3, ())
|
||||
@@ -77,30 +87,32 @@ def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, repla
|
||||
val = jnp.where(
|
||||
r < mutate_rate,
|
||||
val + noise,
|
||||
jnp.where(
|
||||
(mutate_rate < r) & (r < mutate_rate + replace_rate),
|
||||
replace,
|
||||
val
|
||||
)
|
||||
jnp.where((mutate_rate < r) & (r < mutate_rate + replace_rate), replace, val),
|
||||
)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
@jit
|
||||
def mutate_int(key, val, options, replace_rate):
|
||||
k1, k2 = jax.random.split(key, num=2)
|
||||
def mutate_int(randkey, val, options, replace_rate):
|
||||
"""
|
||||
mutate an int value
|
||||
uniformly pick r from [0, 1]
|
||||
r in [0, replace_rate) -> create a new value to replace the original value
|
||||
otherwise -> keep the original value
|
||||
"""
|
||||
k1, k2 = jax.random.split(randkey, num=2)
|
||||
r = jax.random.uniform(k1, ())
|
||||
|
||||
val = jnp.where(
|
||||
r < replace_rate,
|
||||
jax.random.choice(k2, options),
|
||||
val
|
||||
)
|
||||
val = jnp.where(r < replace_rate, jax.random.choice(k2, options), val)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
def argmin_with_mask(arr, mask):
|
||||
"""
|
||||
find the index of the minimum element in the array, but only consider the element with True mask
|
||||
"""
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
return min_idx
|
||||
|
||||
Reference in New Issue
Block a user