Try to reduce difference between sympy formula and network.
Spend a whole night. Failed; I'll never try it anymore.
This commit is contained in:
@@ -9,23 +9,19 @@ class Agg:
|
||||
|
||||
@staticmethod
|
||||
def sum(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
return jnp.sum(z, axis=0)
|
||||
return jnp.sum(z, axis=0, where=~jnp.isnan(z))
|
||||
|
||||
@staticmethod
|
||||
def product(z):
|
||||
z = jnp.where(jnp.isnan(z), 1, z)
|
||||
return jnp.prod(z, axis=0)
|
||||
return jnp.prod(z, axis=0, where=~jnp.isnan(z))
|
||||
|
||||
@staticmethod
|
||||
def max(z):
|
||||
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
||||
return jnp.max(z, axis=0)
|
||||
return jnp.max(z, axis=0, where=~jnp.isnan(z))
|
||||
|
||||
@staticmethod
|
||||
def min(z):
|
||||
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||
return jnp.min(z, axis=0)
|
||||
return jnp.min(z, axis=0, where=~jnp.isnan(z))
|
||||
|
||||
@staticmethod
|
||||
def maxabs(z):
|
||||
|
||||
@@ -7,6 +7,10 @@ class SympySum(sp.Function):
|
||||
def eval(cls, z):
|
||||
return sp.Add(*z)
|
||||
|
||||
@classmethod
|
||||
def numerical_eval(cls, z, backend=np):
|
||||
return backend.sum(z)
|
||||
|
||||
|
||||
class SympyProduct(sp.Function):
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user