optimize import
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
|
||||
def sum_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
return jnp.sum(z, axis=0)
|
||||
|
||||
Reference in New Issue
Block a user