complete normal neat algorithm
This commit is contained in:
@@ -3,6 +3,8 @@ import jax.numpy as jnp
|
||||
|
||||
class Aggregation:
|
||||
|
||||
name2func = {}
|
||||
|
||||
@staticmethod
|
||||
def sum_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
@@ -49,12 +51,13 @@ class Aggregation:
|
||||
mean_without_zeros = valid_values_sum / valid_values_count
|
||||
return mean_without_zeros
|
||||
|
||||
name2func = {
|
||||
'sum': sum_agg,
|
||||
'product': product_agg,
|
||||
'max': max_agg,
|
||||
'min': min_agg,
|
||||
'maxabs': maxabs_agg,
|
||||
'median': median_agg,
|
||||
'mean': mean_agg,
|
||||
}
|
||||
|
||||
Aggregation.name2func = {
|
||||
'sum': Aggregation.sum_agg,
|
||||
'product': Aggregation.product_agg,
|
||||
'max': Aggregation.max_agg,
|
||||
'min': Aggregation.min_agg,
|
||||
'maxabs': Aggregation.maxabs_agg,
|
||||
'median': Aggregation.median_agg,
|
||||
'mean': Aggregation.mean_agg,
|
||||
}
|
||||
Reference in New Issue
Block a user