modify act funcs and sympy act funcs;
add dense and advance initialize genome; add input_transform for genome;
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
from .base import BaseGenome
|
from .base import BaseGenome
|
||||||
from .default import DefaultGenome
|
from .default import DefaultGenome
|
||||||
from .recurrent import RecurrentGenome
|
from .recurrent import RecurrentGenome
|
||||||
from .advance import AdvanceInitialize
|
from .advance import AdvanceInitialize
|
||||||
|
from .dense import DenseInitialize
|
||||||
|
|||||||
70
tensorneat/algorithm/neat/genome/advance.py
Normal file
70
tensorneat/algorithm/neat/genome/advance.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import jax, jax.numpy as jnp
|
||||||
|
from .default import DefaultGenome
|
||||||
|
|
||||||
|
|
||||||
|
class AdvanceInitialize(DefaultGenome):
|
||||||
|
def __init__(self, hidden_cnt=8, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.hidden_cnt = hidden_cnt
|
||||||
|
|
||||||
|
def initialize(self, state, randkey):
|
||||||
|
|
||||||
|
k1, k2 = jax.random.split(randkey, num=2)
|
||||||
|
|
||||||
|
input_idx, output_idx = self.input_idx, self.output_idx
|
||||||
|
input_size = len(input_idx)
|
||||||
|
output_size = len(output_idx)
|
||||||
|
|
||||||
|
hidden_idx = jnp.arange(
|
||||||
|
input_size + output_size, input_size + output_size + self.hidden_cnt
|
||||||
|
)
|
||||||
|
nodes = jnp.full(
|
||||||
|
(self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||||
|
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||||
|
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
|
||||||
|
|
||||||
|
total_idx = input_size + output_size + self.hidden_cnt
|
||||||
|
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||||
|
|
||||||
|
node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
|
||||||
|
node_attrs = node_attr_func(state, rand_keys_n)
|
||||||
|
nodes = nodes.at[:total_idx, 1:].set(node_attrs)
|
||||||
|
|
||||||
|
conns = jnp.full(
|
||||||
|
(self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
input_to_hidden_ids, hidden_ids = jnp.meshgrid(
|
||||||
|
input_idx, hidden_idx, indexing="ij"
|
||||||
|
)
|
||||||
|
total_input_to_hidden_conns = input_size * self.hidden_cnt
|
||||||
|
conns = conns.at[:total_input_to_hidden_conns, :2].set(
|
||||||
|
jnp.column_stack([input_to_hidden_ids.flatten(), hidden_ids.flatten()])
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_to_output_ids, output_ids = jnp.meshgrid(
|
||||||
|
hidden_idx, output_idx, indexing="ij"
|
||||||
|
)
|
||||||
|
total_hidden_to_output_conns = self.hidden_cnt * output_size
|
||||||
|
conns = conns.at[
|
||||||
|
total_input_to_hidden_conns : total_input_to_hidden_conns
|
||||||
|
+ total_hidden_to_output_conns,
|
||||||
|
:2,
|
||||||
|
].set(jnp.column_stack([hidden_to_output_ids.flatten(), output_ids.flatten()]))
|
||||||
|
|
||||||
|
total_conns = total_input_to_hidden_conns + total_hidden_to_output_conns
|
||||||
|
rand_keys_c = jax.random.split(k2, num=total_conns)
|
||||||
|
conns_attr_func = jax.vmap(
|
||||||
|
self.conn_gene.new_random_attrs,
|
||||||
|
in_axes=(
|
||||||
|
None,
|
||||||
|
0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conns_attrs = conns_attr_func(state, rand_keys_c)
|
||||||
|
conns = conns.at[:total_conns, 2:].set(conns_attrs)
|
||||||
|
|
||||||
|
return nodes, conns
|
||||||
@@ -2,6 +2,7 @@ import warnings
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
from utils import (
|
from utils import (
|
||||||
unflatten_conns,
|
unflatten_conns,
|
||||||
@@ -37,6 +38,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
mutation: BaseMutation = DefaultMutation(),
|
mutation: BaseMutation = DefaultMutation(),
|
||||||
crossover: BaseCrossover = DefaultCrossover(),
|
crossover: BaseCrossover = DefaultCrossover(),
|
||||||
output_transform: Callable = None,
|
output_transform: Callable = None,
|
||||||
|
input_transform: Callable = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inputs,
|
num_inputs,
|
||||||
@@ -49,9 +51,16 @@ class DefaultGenome(BaseGenome):
|
|||||||
crossover,
|
crossover,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if input_transform is not None:
|
||||||
|
try:
|
||||||
|
_ = input_transform(np.zeros(num_inputs))
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Output transform function failed: {e}")
|
||||||
|
self.input_transform = input_transform
|
||||||
|
|
||||||
if output_transform is not None:
|
if output_transform is not None:
|
||||||
try:
|
try:
|
||||||
_ = output_transform(jnp.zeros(num_outputs))
|
_ = output_transform(np.zeros(num_outputs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Output transform function failed: {e}")
|
raise ValueError(f"Output transform function failed: {e}")
|
||||||
self.output_transform = output_transform
|
self.output_transform = output_transform
|
||||||
@@ -69,6 +78,10 @@ class DefaultGenome(BaseGenome):
|
|||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
def forward(self, state, transformed, inputs):
|
def forward(self, state, transformed, inputs):
|
||||||
|
|
||||||
|
if self.input_transform is not None:
|
||||||
|
inputs = self.input_transform(inputs)
|
||||||
|
|
||||||
cal_seqs, nodes, conns, u_conns = transformed
|
cal_seqs, nodes, conns, u_conns = transformed
|
||||||
|
|
||||||
ini_vals = jnp.full((self.max_nodes,), jnp.nan)
|
ini_vals = jnp.full((self.max_nodes,), jnp.nan)
|
||||||
@@ -118,6 +131,10 @@ class DefaultGenome(BaseGenome):
|
|||||||
return self.output_transform(vals[self.output_idx])
|
return self.output_transform(vals[self.output_idx])
|
||||||
|
|
||||||
def update_by_batch(self, state, batch_input, transformed):
|
def update_by_batch(self, state, batch_input, transformed):
|
||||||
|
|
||||||
|
if self.input_transform is not None:
|
||||||
|
batch_input = jax.vmap(self.input_transform)(batch_input)
|
||||||
|
|
||||||
cal_seqs, nodes, conns, u_conns = transformed
|
cal_seqs, nodes, conns, u_conns = transformed
|
||||||
|
|
||||||
batch_size = batch_input.shape[0]
|
batch_size = batch_input.shape[0]
|
||||||
@@ -193,11 +210,19 @@ class DefaultGenome(BaseGenome):
|
|||||||
new_transformed,
|
new_transformed,
|
||||||
)
|
)
|
||||||
|
|
||||||
def sympy_func(self, state, network, sympy_output_transform=None, backend="jax"):
|
def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"):
|
||||||
|
|
||||||
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||||
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
|
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
|
||||||
|
|
||||||
|
if sympy_input_transform is None and self.input_transform is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"genome.input_transform is not None but sympy_input_transform is None!"
|
||||||
|
)
|
||||||
|
if sympy_input_transform is not None:
|
||||||
|
if not isinstance(sympy_input_transform, list):
|
||||||
|
sympy_input_transform = [sympy_input_transform] * self.num_inputs
|
||||||
|
|
||||||
if sympy_output_transform is None and self.output_transform is not None:
|
if sympy_output_transform is None and self.output_transform is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"genome.output_transform is not None but sympy_output_transform is None!"
|
"genome.output_transform is not None but sympy_output_transform is None!"
|
||||||
@@ -221,7 +246,10 @@ class DefaultGenome(BaseGenome):
|
|||||||
for i in order:
|
for i in order:
|
||||||
|
|
||||||
if i in input_idx:
|
if i in input_idx:
|
||||||
nodes_exprs[symbols[i]] = symbols[i]
|
if sympy_input_transform is not None:
|
||||||
|
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i])
|
||||||
|
else:
|
||||||
|
nodes_exprs[symbols[i]] = symbols[i]
|
||||||
else:
|
else:
|
||||||
in_conns = [c for c in network["conns"] if c[1] == i]
|
in_conns = [c for c in network["conns"] if c[1] == i]
|
||||||
node_inputs = []
|
node_inputs = []
|
||||||
|
|||||||
56
tensorneat/algorithm/neat/genome/dense.py
Normal file
56
tensorneat/algorithm/neat/genome/dense.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import jax, jax.numpy as jnp
|
||||||
|
from .default import DefaultGenome
|
||||||
|
|
||||||
|
|
||||||
|
class DenseInitialize(DefaultGenome):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
assert self.max_nodes >= self.num_inputs + self.num_outputs
|
||||||
|
assert self.max_conns >= self.num_inputs * self.num_outputs
|
||||||
|
|
||||||
|
def initialize(self, state, randkey):
|
||||||
|
|
||||||
|
k1, k2 = jax.random.split(randkey, num=2)
|
||||||
|
|
||||||
|
input_idx, output_idx = self.input_idx, self.output_idx
|
||||||
|
input_size = len(input_idx)
|
||||||
|
output_size = len(output_idx)
|
||||||
|
|
||||||
|
nodes = jnp.full(
|
||||||
|
(self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||||
|
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||||
|
|
||||||
|
total_idx = input_size + output_size
|
||||||
|
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||||
|
|
||||||
|
node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
|
||||||
|
node_attrs = node_attr_func(state, rand_keys_n)
|
||||||
|
nodes = nodes.at[:total_idx, 1:].set(node_attrs)
|
||||||
|
|
||||||
|
conns = jnp.full(
|
||||||
|
(self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
input_to_output_ids, output_ids = jnp.meshgrid(
|
||||||
|
input_idx, output_idx, indexing="ij"
|
||||||
|
)
|
||||||
|
total_conns = input_size * output_size
|
||||||
|
conns = conns.at[:total_conns, :2].set(
|
||||||
|
jnp.column_stack([input_to_output_ids.flatten(), output_ids.flatten()])
|
||||||
|
)
|
||||||
|
|
||||||
|
rand_keys_c = jax.random.split(k2, num=total_conns)
|
||||||
|
conns_attr_func = jax.vmap(
|
||||||
|
self.conn_gene.new_random_attrs,
|
||||||
|
in_axes=(
|
||||||
|
None,
|
||||||
|
0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conns_attrs = conns_attr_func(state, rand_keys_c)
|
||||||
|
conns = conns.at[:total_conns, 2:].set(conns_attrs)
|
||||||
|
|
||||||
|
return nodes, conns
|
||||||
@@ -20,7 +20,6 @@ name2sympy = {
|
|||||||
"relu": SympyRelu,
|
"relu": SympyRelu,
|
||||||
"lelu": SympyLelu,
|
"lelu": SympyLelu,
|
||||||
"identity": SympyIdentity,
|
"identity": SympyIdentity,
|
||||||
"clamped": SympyClamped,
|
|
||||||
"inv": SympyInv,
|
"inv": SympyInv,
|
||||||
"log": SympyLog,
|
"log": SympyLog,
|
||||||
"exp": SympyExp,
|
"exp": SympyExp,
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
|
sigma_3 = 2.576
|
||||||
|
|
||||||
|
|
||||||
class Act:
|
class Act:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def name2func(name):
|
def name2func(name):
|
||||||
@@ -9,35 +12,42 @@ class Act:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sigmoid(z):
|
def sigmoid(z):
|
||||||
z = jnp.clip(5 * z, -10, 10)
|
z = jnp.clip(5 * z / sigma_3, -5, 5)
|
||||||
return 1 / (1 + jnp.exp(-z))
|
z = 1 / (1 + jnp.exp(-z))
|
||||||
|
|
||||||
|
return z * sigma_3 # (0, sigma_3)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tanh(z):
|
def tanh(z):
|
||||||
z = jnp.clip(0.6*z, -3, 3)
|
z = jnp.clip(5 * z / sigma_3, -5, 5)
|
||||||
return jnp.tanh(z)
|
return jnp.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def standard_tanh(z):
|
||||||
|
z = jnp.clip(5 * z / sigma_3, -5, 5)
|
||||||
|
return jnp.tanh(z) # (-1, 1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sin(z):
|
def sin(z):
|
||||||
return jnp.sin(z)
|
z = jnp.clip(jnp.pi / 2 * z / sigma_3, -jnp.pi / 2, jnp.pi / 2)
|
||||||
|
return jnp.sin(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def relu(z):
|
def relu(z):
|
||||||
return jnp.maximum(z, 0)
|
z = jnp.clip(z, -sigma_3, sigma_3)
|
||||||
|
return jnp.maximum(z, 0) # (0, sigma_3)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def lelu(z):
|
def lelu(z):
|
||||||
leaky = 0.005
|
leaky = 0.005
|
||||||
|
z = jnp.clip(z, -sigma_3, sigma_3)
|
||||||
return jnp.where(z > 0, z, leaky * z)
|
return jnp.where(z > 0, z, leaky * z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def identity(z):
|
def identity(z):
|
||||||
|
z = jnp.clip(z, -sigma_3, sigma_3)
|
||||||
return z
|
return z
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def clamped(z):
|
|
||||||
return jnp.clip(z, -1, 1)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def inv(z):
|
def inv(z):
|
||||||
z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
|
z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
|
||||||
@@ -55,6 +65,7 @@ class Act:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def abs(z):
|
def abs(z):
|
||||||
|
z = jnp.clip(z, -1, 1)
|
||||||
return jnp.abs(z)
|
return jnp.abs(z)
|
||||||
|
|
||||||
|
|
||||||
@@ -65,7 +76,6 @@ ACT_ALL = (
|
|||||||
Act.relu,
|
Act.relu,
|
||||||
Act.lelu,
|
Act.lelu,
|
||||||
Act.identity,
|
Act.identity,
|
||||||
Act.clamped,
|
|
||||||
Act.inv,
|
Act.inv,
|
||||||
Act.log,
|
Act.log,
|
||||||
Act.exp,
|
Act.exp,
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ import sympy as sp
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
sigma_3 = 2.576
|
||||||
|
|
||||||
|
|
||||||
class SympyClip(sp.Function):
|
class SympyClip(sp.Function):
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, val, min_val, max_val):
|
def eval(cls, val, min_val, max_val):
|
||||||
@@ -26,14 +29,17 @@ class SympySigmoid(sp.Function):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, z):
|
def eval(cls, z):
|
||||||
if z.is_Number:
|
if z.is_Number:
|
||||||
z = SympyClip(5 * z, -10, 10)
|
z = SympyClip(5 * z / sigma_3, -5, 5)
|
||||||
return 1 / (1 + sp.exp(-z))
|
z = 1 / (1 + sp.exp(-z))
|
||||||
|
return z * sigma_3
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z, backend=np):
|
def numerical_eval(z, backend=np):
|
||||||
z = backend.clip(5 * z, -10, 10)
|
z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||||
return 1 / (1 + backend.exp(-z))
|
z = 1 / (1 + backend.exp(-z))
|
||||||
|
|
||||||
|
return z * sigma_3 # (0, sigma_3)
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
return f"sigmoid({self.args[0]})"
|
return f"sigmoid({self.args[0]})"
|
||||||
@@ -46,36 +52,56 @@ class SympyTanh(sp.Function):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, z):
|
def eval(cls, z):
|
||||||
if z.is_Number:
|
if z.is_Number:
|
||||||
z = SympyClip(0.6 * z, -3, 3)
|
z = SympyClip(5 * z / sigma_3, -5, 5)
|
||||||
|
return sp.tanh(z) * sigma_3
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def numerical_eval(z, backend=np):
|
||||||
|
z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||||
|
return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||||
|
|
||||||
|
|
||||||
|
class SympyStandardTanh(sp.Function):
|
||||||
|
@classmethod
|
||||||
|
def eval(cls, z):
|
||||||
|
if z.is_Number:
|
||||||
|
z = SympyClip(5 * z / sigma_3, -5, 5)
|
||||||
return sp.tanh(z)
|
return sp.tanh(z)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z, backend=np):
|
def numerical_eval(z, backend=np):
|
||||||
z = backend.clip(0.6*z, -3, 3)
|
z = backend.clip(5 * z / sigma_3, -5, 5)
|
||||||
return backend.tanh(z)
|
return backend.tanh(z) # (-1, 1)
|
||||||
|
|
||||||
|
|
||||||
class SympySin(sp.Function):
|
class SympySin(sp.Function):
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, z):
|
def eval(cls, z):
|
||||||
return sp.sin(z)
|
if z.is_Number:
|
||||||
|
z = SympyClip(sp.pi / 2 * z / sigma_3, -sp.pi / 2, sp.pi / 2)
|
||||||
|
return sp.sin(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z, backend=np):
|
def numerical_eval(z, backend=np):
|
||||||
return backend.sin(z)
|
z = backend.clip(backend.pi / 2 * z / sigma_3, -backend.pi / 2, backend.pi / 2)
|
||||||
|
return backend.sin(z) * sigma_3 # (-sigma_3, sigma_3)
|
||||||
|
|
||||||
|
|
||||||
class SympyRelu(sp.Function):
|
class SympyRelu(sp.Function):
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, z):
|
def eval(cls, z):
|
||||||
if z.is_Number:
|
if z.is_Number:
|
||||||
return sp.Piecewise((z, z > 0), (0, True))
|
z = SympyClip(z, -sigma_3, sigma_3)
|
||||||
|
return sp.Max(z, 0) # (0, sigma_3)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z, backend=np):
|
def numerical_eval(z, backend=np):
|
||||||
return backend.maximum(z, 0)
|
z = backend.clip(z, -sigma_3, sigma_3)
|
||||||
|
return backend.maximum(z, 0) # (0, sigma_3)
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
return f"relu({self.args[0]})"
|
return f"relu({self.args[0]})"
|
||||||
@@ -107,21 +133,14 @@ class SympyLelu(sp.Function):
|
|||||||
class SympyIdentity(sp.Function):
|
class SympyIdentity(sp.Function):
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, z):
|
def eval(cls, z):
|
||||||
return z
|
if z.is_Number:
|
||||||
|
z = SympyClip(z, -sigma_3, sigma_3)
|
||||||
|
return z
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z, backend=np):
|
def numerical_eval(z, backend=np):
|
||||||
return z
|
return backend.clip(z, -sigma_3, sigma_3)
|
||||||
|
|
||||||
|
|
||||||
class SympyClamped(sp.Function):
|
|
||||||
@classmethod
|
|
||||||
def eval(cls, z):
|
|
||||||
return SympyClip(z, -1, 1)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(z, backend=np):
|
|
||||||
return backend.clip(z, -1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
class SympyInv(sp.Function):
|
class SympyInv(sp.Function):
|
||||||
|
|||||||
Reference in New Issue
Block a user