Current Progress: After final design presentation
This commit is contained in:
@@ -44,7 +44,6 @@ def maxabs_agg(z):
|
||||
|
||||
@jit
|
||||
def median_agg(z):
|
||||
|
||||
non_zero_mask = ~jnp.isnan(z)
|
||||
n = jnp.sum(non_zero_mask, axis=0)
|
||||
|
||||
@@ -71,19 +70,6 @@ def mean_agg(z):
|
||||
return mean_without_zeros
|
||||
|
||||
|
||||
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
|
||||
|
||||
agg_name2key = {
|
||||
'sum': 0,
|
||||
'product': 1,
|
||||
'max': 2,
|
||||
'min': 3,
|
||||
'maxabs': 4,
|
||||
'median': 5,
|
||||
'mean': 6,
|
||||
}
|
||||
|
||||
|
||||
@jit
|
||||
def agg(idx, z):
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
@@ -97,7 +83,6 @@ def agg(idx, z):
|
||||
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
|
||||
|
||||
|
||||
vectorized_agg = jax.vmap(agg, in_axes=(0, 0))
|
||||
|
||||
if __name__ == '__main__':
|
||||
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)
|
||||
|
||||
Reference in New Issue
Block a user