use black format all files;

remove "return state" for functions which will be executed in vmap;
recover randkey as args in mutation methods
This commit is contained in:
wls2002
2024-05-26 15:46:04 +08:00
parent 79d53ea7af
commit cf69b916af
38 changed files with 932 additions and 582 deletions

View File

@@ -3,7 +3,6 @@ import jax.numpy as jnp
class Agg:
@staticmethod
def sum(z):
z = jnp.where(jnp.isnan(z), 0, z)
@@ -63,5 +62,5 @@ def agg(idx, z, agg_funcs):
return jax.lax.cond(
jnp.all(jnp.isnan(z)),
lambda: jnp.nan, # all inputs are nan
lambda: jax.lax.switch(idx, agg_funcs, z) # otherwise
lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise
)