modify act. agg in mutation. they can have option vals
fix a bug in function 'agg'
This commit is contained in:
@@ -88,13 +88,13 @@ agg_name2key = {
|
||||
def agg(idx, z):
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
|
||||
def full_zero():
|
||||
def full_nan():
|
||||
return 0.
|
||||
|
||||
def not_full_zero():
|
||||
def not_full_nan():
|
||||
return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
|
||||
|
||||
return jax.lax.cond(jnp.all(z == 0.), full_zero, not_full_zero)
|
||||
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
|
||||
|
||||
|
||||
vectorized_agg = jax.vmap(agg, in_axes=(0, 0))
|
||||
|
||||
Reference in New Issue
Block a user