use black format all files;
remove "return state" for functions which will be executed in vmap; recover randkey as args in mutation methods
This commit is contained in:
@@ -3,7 +3,6 @@ import jax.numpy as jnp
|
||||
|
||||
|
||||
class Agg:
|
||||
|
||||
@staticmethod
|
||||
def sum(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
@@ -63,5 +62,5 @@ def agg(idx, z, agg_funcs):
|
||||
return jax.lax.cond(
|
||||
jnp.all(jnp.isnan(z)),
|
||||
lambda: jnp.nan, # all inputs are nan
|
||||
lambda: jax.lax.switch(idx, agg_funcs, z) # otherwise
|
||||
lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user