complete normal neat algorithm

This commit is contained in:
wls2002
2023-07-18 23:55:36 +08:00
parent 40cf0b6fbe
commit 0a2a9fd1be
26 changed files with 880 additions and 251 deletions

View File

@@ -3,6 +3,8 @@ import jax.numpy as jnp
class Aggregation:
name2func = {}
@staticmethod
def sum_agg(z):
z = jnp.where(jnp.isnan(z), 0, z)
@@ -49,12 +51,13 @@ class Aggregation:
mean_without_zeros = valid_values_sum / valid_values_count
return mean_without_zeros
name2func = {
'sum': sum_agg,
'product': product_agg,
'max': max_agg,
'min': min_agg,
'maxabs': maxabs_agg,
'median': median_agg,
'mean': mean_agg,
}
Aggregation.name2func = {
'sum': Aggregation.sum_agg,
'product': Aggregation.product_agg,
'max': Aggregation.max_agg,
'min': Aggregation.min_agg,
'maxabs': Aggregation.maxabs_agg,
'median': Aggregation.median_agg,
'mean': Aggregation.mean_agg,
}