95 lines
2.0 KiB
Python
95 lines
2.0 KiB
Python
"""
|
|
Aggregation functions. Two special cases need to be considered:
1. extra NaNs: the input contains NaN entries that must be ignored
2. full of NaNs: every entry of the input is NaN
|
|
"""
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
from jax import jit
|
|
|
|
|
|
@jit
def sum_agg(z):
    """Sum ``z`` along axis 0, ignoring NaN entries.

    NaN entries are replaced by 0 (the additive identity) before the
    reduction, so an input made only of NaNs sums to 0.
    """
    nan_mask = jnp.isnan(z)
    cleaned = jnp.where(nan_mask, 0, z)
    return cleaned.sum(axis=0)
|
|
|
|
|
|
@jit
def product_agg(z):
    """Multiply the entries of ``z`` along axis 0, ignoring NaN entries.

    NaN entries are replaced by 1 (the multiplicative identity) before the
    reduction, so an input made only of NaNs yields 1.
    """
    nan_mask = jnp.isnan(z)
    cleaned = jnp.where(nan_mask, 1, z)
    return cleaned.prod(axis=0)
|
|
|
|
|
|
@jit
def max_agg(z):
    """Return the maximum of ``z`` along axis 0, ignoring NaN entries.

    NaN entries are replaced by ``-inf`` so they can never win the
    comparison; an input made only of NaNs therefore yields ``-inf``.
    """
    nan_mask = jnp.isnan(z)
    cleaned = jnp.where(nan_mask, -jnp.inf, z)
    return cleaned.max(axis=0)
|
|
|
|
|
|
@jit
def min_agg(z):
    """Return the minimum of ``z`` along axis 0, ignoring NaN entries.

    NaN entries are replaced by ``+inf`` so they can never win the
    comparison; an input made only of NaNs therefore yields ``+inf``.
    """
    nan_mask = jnp.isnan(z)
    cleaned = jnp.where(nan_mask, jnp.inf, z)
    return cleaned.min(axis=0)
|
|
|
|
|
|
@jit
def maxabs_agg(z):
    """Return the entry of ``z`` with the largest magnitude along axis 0.

    The sign of the selected entry is preserved. NaN entries are replaced
    by 0 first, so they only win when every other entry is 0 as well; an
    input made only of NaNs yields 0. Ties go to the first occurrence.

    The selection runs along axis 0, like the other aggregations in this
    file. Previously the argmax was taken over the flattened array and the
    flat index was then used as a row index, which picked the wrong values
    for inputs with more than one dimension.
    """
    cleaned = jnp.where(jnp.isnan(z), 0, z)
    winner = jnp.argmax(jnp.abs(cleaned), axis=0)
    # Add a leading length-1 axis so the indices line up with `cleaned`
    # along axis 0, gather, then drop that axis again.
    return jnp.take_along_axis(cleaned, winner[None], axis=0)[0]
|
|
|
|
|
|
@jit
def median_agg(z):
    """Return the median of the non-NaN entries of ``z``.

    NaN entries are replaced by ``+inf`` so that, after sorting, the valid
    values occupy the first ``count`` positions. The median is then the
    middle valid value (odd count) or the average of the two middle valid
    values (even count).

    NOTE(review): ``jax.lax.cond`` needs a scalar predicate, so this
    presumably expects a 1-D ``z`` -- confirm against callers.
    """
    is_valid = jnp.logical_not(jnp.isnan(z))
    count = jnp.sum(is_valid, axis=0)

    # Push NaNs to the end of the ordering.
    ordered = jnp.sort(jnp.where(jnp.isnan(z), jnp.inf, z))

    half = count // 2
    upper = ordered[half]
    lower = ordered[half - 1]

    return jax.lax.cond(
        count % 2 == 0,
        lambda: (lower + upper) / 2,  # even count: mean of the two middle values
        lambda: upper,  # odd count: the single middle value
    )
|
|
|
|
|
|
@jit
def mean_agg(z):
    """Return the mean of the non-NaN entries of ``z`` along axis 0.

    The numerator is ``sum_agg(z)`` (NaNs counted as 0) and the denominator
    is the number of non-NaN entries along axis 0.
    """
    total = sum_agg(z)
    count = jnp.sum(jnp.logical_not(jnp.isnan(z)), axis=0)
    return total / count
|
|
|
|
|
|
@jit
def agg(idx, z):
    """Apply the aggregation selected by ``idx`` to ``z``.

    ``idx`` is converted to an int32 array and used with ``jax.lax.switch``
    to pick one function from ``AGG_TOTAL_LIST`` (defined outside this
    block). If every entry of ``z`` is NaN, no aggregation is run and
    ``0.`` is returned instead.
    """
    branch_index = jnp.asarray(idx, dtype=jnp.int32)
    everything_is_nan = jnp.all(jnp.isnan(z))

    return jax.lax.cond(
        everything_is_nan,
        lambda: 0.,
        lambda: jax.lax.switch(branch_index, AGG_TOTAL_LIST, z),
    )
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: print the result of every registered aggregation
    # for an input that mixes values with NaNs, then for an all-zero input.
    samples = (
        jnp.asarray(
            [1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan],
            dtype=jnp.float32,
        ),
        jnp.asarray([0, 0, 0, 0], dtype=jnp.float32),
    )
    for sample in samples:
        for name, key in agg_name2key.items():
            print(name, agg(key, sample))
|