The whole NEAT algorithm is written into functional programming.
This commit is contained in:
@@ -59,12 +59,13 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
@partial(jit, static_argnames=['reverse'])
|
||||
def rank_elements(array, reverse=False):
|
||||
"""
|
||||
rank the element in the array.
|
||||
if reverse is True, the rank is from large to small.
|
||||
if reverse is True, the rank is from small to large. default large to small
|
||||
"""
|
||||
if reverse:
|
||||
if not reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
Reference in New Issue
Block a user