Current Progress: After final design presentation

This commit is contained in:
wls2002
2023-06-19 15:17:56 +08:00
parent acedd67617
commit 5cbe3c14bb
34 changed files with 533 additions and 558 deletions

View File

@@ -44,7 +44,6 @@ def maxabs_agg(z):
@jit
def median_agg(z):
non_zero_mask = ~jnp.isnan(z)
n = jnp.sum(non_zero_mask, axis=0)
@@ -71,19 +70,6 @@ def mean_agg(z):
return mean_without_zeros
AGG_TOTAL_LIST = [sum_agg, product_agg, max_agg, min_agg, maxabs_agg, median_agg, mean_agg]
agg_name2key = {
'sum': 0,
'product': 1,
'max': 2,
'min': 3,
'maxabs': 4,
'median': 5,
'mean': 6,
}
@jit
def agg(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
@@ -97,7 +83,6 @@ def agg(idx, z):
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
vectorized_agg = jax.vmap(agg, in_axes=(0, 0))
if __name__ == '__main__':
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)