odify genome for the official release

This commit is contained in:
root
2024-07-10 11:24:11 +08:00
parent 075460f896
commit ee8ec84202
83 changed files with 588 additions and 611 deletions

View File

@@ -0,0 +1,66 @@
import jax
import jax.numpy as jnp
class Agg:
@staticmethod
def name2func(name):
return getattr(Agg, name)
@staticmethod
def sum(z):
return jnp.sum(z, axis=0, where=~jnp.isnan(z), initial=0)
@staticmethod
def product(z):
return jnp.prod(z, axis=0, where=~jnp.isnan(z), initial=1)
@staticmethod
def max(z):
return jnp.max(z, axis=0, where=~jnp.isnan(z), initial=-jnp.inf)
@staticmethod
def min(z):
return jnp.min(z, axis=0, where=~jnp.isnan(z), initial=jnp.inf)
@staticmethod
def maxabs(z):
z = jnp.where(jnp.isnan(z), 0, z)
abs_z = jnp.abs(z)
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
@staticmethod
def median(z):
n = jnp.sum(~jnp.isnan(z), axis=0)
z = jnp.sort(z) # sort
idx1, idx2 = (n - 1) // 2, n // 2
median = (z[idx1] + z[idx2]) / 2
return median
@staticmethod
def mean(z):
aux = jnp.where(jnp.isnan(z), 0, z)
valid_values_sum = jnp.sum(aux, axis=0)
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
mean_without_zeros = valid_values_sum / valid_values_count
return mean_without_zeros
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
def agg_func(idx, z, agg_funcs):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
return jax.lax.cond(
jnp.all(jnp.isnan(z)),
lambda: jnp.nan, # all inputs are nan
lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise
)

View File

@@ -0,0 +1,65 @@
import numpy as np
import sympy as sp
class SympySum(sp.Function):
@classmethod
def eval(cls, z):
return sp.Add(*z)
@classmethod
def numerical_eval(cls, z, backend=np):
return backend.sum(z)
class SympyProduct(sp.Function):
@classmethod
def eval(cls, z):
return sp.Mul(*z)
class SympyMax(sp.Function):
@classmethod
def eval(cls, z):
return sp.Max(*z)
class SympyMin(sp.Function):
@classmethod
def eval(cls, z):
return sp.Min(*z)
class SympyMaxabs(sp.Function):
@classmethod
def eval(cls, z):
return sp.Max(*z, key=sp.Abs)
class SympyMean(sp.Function):
@classmethod
def eval(cls, z):
return sp.Add(*z) / len(z)
class SympyMedian(sp.Function):
@classmethod
def eval(cls, args):
if all(arg.is_number for arg in args):
sorted_args = sorted(args)
n = len(sorted_args)
if n % 2 == 1:
return sorted_args[n // 2]
else:
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
return None
def _sympystr(self, printer):
return f"median({', '.join(map(str, self.args))})"
def _latex(self, printer):
return (
r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
)