finish ask part of the algorithm;

use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
This commit is contained in:
wls2002
2023-06-25 00:26:52 +08:00
parent 86820db5a6
commit 0cb2f9473d
24 changed files with 485 additions and 1623 deletions

View File

@@ -1,9 +1,3 @@
"""
aggregations, two special case need to consider:
1. extra 0s
2. full of 0s
"""
import jax
import jax.numpy as jnp
import numpy as np
@@ -44,19 +38,13 @@ def maxabs_agg(z):
@jit
def median_agg(z):
non_zero_mask = ~jnp.isnan(z)
n = jnp.sum(non_zero_mask, axis=0)
non_nan_mask = ~jnp.isnan(z)
n = jnp.sum(non_nan_mask, axis=0)
z = jnp.where(jnp.isnan(z), jnp.inf, z)
sorted_valid_values = jnp.sort(z)
z = jnp.sort(z) # sort
def _even_case():
return (sorted_valid_values[n // 2 - 1] + sorted_valid_values[n // 2]) / 2
def _odd_case():
return sorted_valid_values[n // 2]
median = jax.lax.cond(n % 2 == 0, _even_case, _odd_case)
idx1, idx2 = (n - 1) // 2, n // 2
median = (z[idx1] + z[idx2]) / 2
return median
@@ -70,25 +58,12 @@ def mean_agg(z):
return mean_without_zeros
@jit
def agg(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
def full_nan():
return 0.
def not_full_nan():
return jax.lax.switch(idx, AGG_TOTAL_LIST, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), full_nan, not_full_nan)
if __name__ == '__main__':
array = jnp.asarray([1, 2, np.nan, np.nan, 3, 4, 5, np.nan, np.nan, np.nan, np.nan], dtype=jnp.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array))
array2 = jnp.asarray([0, 0, 0, 0], dtype=jnp.float32)
for names in agg_name2key.keys():
print(names, agg(agg_name2key[names], array2))
agg_name2func = {
'sum': sum_agg,
'product': product_agg,
'max': max_agg,
'min': min_agg,
'maxabs': maxabs_agg,
'median': median_agg,
'mean': mean_agg,
}