remove create_func....
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Aggregation:
|
||||
|
||||
name2func = {}
|
||||
|
||||
@staticmethod
|
||||
@@ -52,12 +52,16 @@ class Aggregation:
|
||||
return mean_without_zeros
|
||||
|
||||
|
||||
Aggregation.name2func = {
|
||||
'sum': Aggregation.sum_agg,
|
||||
'product': Aggregation.product_agg,
|
||||
'max': Aggregation.max_agg,
|
||||
'min': Aggregation.min_agg,
|
||||
'maxabs': Aggregation.maxabs_agg,
|
||||
'median': Aggregation.median_agg,
|
||||
'mean': Aggregation.mean_agg,
|
||||
}
|
||||
def agg(idx, z, agg_funcs):
|
||||
"""
|
||||
calculate activation function for inputs of node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
|
||||
def all_nan():
|
||||
return 0.
|
||||
|
||||
def not_all_nan():
|
||||
return jax.lax.switch(idx, agg_funcs, z)
|
||||
|
||||
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
|
||||
|
||||
Reference in New Issue
Block a user