Perfect!
Next is to connect with Evox!
This commit is contained in:
@@ -1,34 +1,27 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax import jit
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
def sum_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
return jnp.sum(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def product_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 1, z)
|
||||
return jnp.prod(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def max_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
||||
return jnp.max(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def min_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||
return jnp.min(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def maxabs_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
abs_z = jnp.abs(z)
|
||||
@@ -36,7 +29,6 @@ def maxabs_agg(z):
|
||||
return z[max_abs_index]
|
||||
|
||||
|
||||
@jit
|
||||
def median_agg(z):
|
||||
non_nan_mask = ~jnp.isnan(z)
|
||||
n = jnp.sum(non_nan_mask, axis=0)
|
||||
@@ -49,7 +41,6 @@ def median_agg(z):
|
||||
return median
|
||||
|
||||
|
||||
@jit
|
||||
def mean_agg(z):
|
||||
non_zero_mask = ~jnp.isnan(z)
|
||||
valid_values_sum = sum_agg(z)
|
||||
|
||||
Reference in New Issue
Block a user