use black format all files;

remove "return state" for functions which will be executed in vmap;
recover randkey as args in mutation methods
This commit is contained in:
wls2002
2024-05-26 15:46:04 +08:00
parent 79d53ea7af
commit cf69b916af
38 changed files with 932 additions and 582 deletions

View File

@@ -4,13 +4,14 @@ import numpy as np
import jax
from jax import numpy as jnp, Array, jit, vmap
I_INT = np.iinfo(jnp.int32).max # infinite int
I_INF = np.iinfo(jnp.int32).max # infinite int
def unflatten_conns(nodes, conns):
"""
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index)
:return:
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index), which CL means
connection length, N means the number of nodes, C means the number of connections
returns the un_flattened connections with shape (CL-2, N, N)
"""
N = nodes.shape[0]
CL = conns.shape[1]
@@ -33,7 +34,7 @@ def key_to_indices(key, keys):
@jit
def fetch_first(mask, default=I_INT) -> Array:
def fetch_first(mask, default=I_INF) -> Array:
"""
fetch the first True index
:param mask: array of bool
@@ -45,18 +46,18 @@ def fetch_first(mask, default=I_INT) -> Array:
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
def fetch_random(randkey, mask, default=I_INF) -> Array:
"""
similar to fetch_first, but fetch a random True index
"""
true_cnt = jnp.sum(mask)
cumsum = jnp.cumsum(mask)
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
target = jax.random.randint(randkey, shape=(), minval=1, maxval=true_cnt + 1)
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse'])
@partial(jit, static_argnames=["reverse"])
def rank_elements(array, reverse=False):
"""
rank the element in the array.
@@ -68,8 +69,17 @@ def rank_elements(array, reverse=False):
@jit
def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
k1, k2, k3 = jax.random.split(key, num=3)
def mutate_float(
randkey, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate
):
"""
mutate a float value
uniformly pick r from [0, 1]
r in [0, mutate_rate) -> add noise
r in [mutate_rate, mutate_rate + replace_rate) -> create a new value to replace the original value
otherwise -> keep the original value
"""
k1, k2, k3 = jax.random.split(randkey, num=3)
noise = jax.random.normal(k1, ()) * mutate_power
replace = jax.random.normal(k2, ()) * init_std + init_mean
r = jax.random.uniform(k3, ())
@@ -77,30 +87,32 @@ def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, repla
val = jnp.where(
r < mutate_rate,
val + noise,
jnp.where(
(mutate_rate < r) & (r < mutate_rate + replace_rate),
replace,
val
)
jnp.where((mutate_rate < r) & (r < mutate_rate + replace_rate), replace, val),
)
return val
@jit
def mutate_int(key, val, options, replace_rate):
k1, k2 = jax.random.split(key, num=2)
def mutate_int(randkey, val, options, replace_rate):
"""
mutate an int value
uniformly pick r from [0, 1]
r in [0, replace_rate) -> create a new value to replace the original value
otherwise -> keep the original value
"""
k1, k2 = jax.random.split(randkey, num=2)
r = jax.random.uniform(k1, ())
val = jnp.where(
r < replace_rate,
jax.random.choice(k2, options),
val
)
val = jnp.where(r < replace_rate, jax.random.choice(k2, options), val)
return val
def argmin_with_mask(arr, mask):
"""
find the index of the minimum element in the array, but only consider the element with True mask
"""
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx
return min_idx