modify act. agg in mutation. they can have option vals

fix a bug in function 'agg'
This commit is contained in:
wls2002
2023-05-07 23:00:04 +08:00
parent 47bb593a53
commit b257505bee
4 changed files with 46 additions and 52 deletions

View File

@@ -88,13 +88,13 @@ agg_name2key = {
def agg(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
def full_zero():
def full_nan():
return 0.
def not_full_zero():
def not_full_nan():
return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
return jax.lax.cond(jnp.all(z == 0.), full_zero, not_full_zero)
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
vectorized_agg = jax.vmap(agg, in_axes=(0, 0))