add backend="jax" to sympy module
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
|
||||
|
||||
@@ -51,15 +52,6 @@ class SympyMedian(sp.Function):
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(args):
|
||||
sorted_args = sorted(args)
|
||||
n = len(sorted_args)
|
||||
if n % 2 == 1:
|
||||
return sorted_args[n // 2]
|
||||
else:
|
||||
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"median({', '.join(map(str, self.args))})"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user