add backend="jax" to sympy module

This commit is contained in:
wls2002
2024-06-13 05:55:33 +08:00
parent b3e442c688
commit 69d73aab73
12 changed files with 254 additions and 167 deletions

View File

@@ -2,6 +2,7 @@ from typing import Tuple
import numpy as np
import jax, jax.numpy as jnp
import sympy as sp
from utils import (
Act,
@@ -160,33 +161,36 @@ class DefaultNodeGene(BaseNodeGene):
def to_dict(self, state, node):
idx, bias, res, agg, act = node
idx = int(idx)
bias = np.array(bias, dtype=np.float32)
res = np.array(res, dtype=np.float32)
agg = int(agg)
act = int(act)
if act == -1:
act_func = Act.identity
else:
act_func = self.activation_options[act]
return {
"idx": int(idx),
"bias": float(bias),
"res": float(res),
"idx": idx,
"bias": bias,
"res": res,
"agg": self.aggregation_options[int(agg)].__name__,
"act": self.activation_options[int(act)].__name__,
"act": act_func.__name__,
}
def sympy_func(
self, state, node_dict, inputs, is_output_node=False, precision=None
):
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
nd = node_dict
bias = sp.symbols(f"n_{nd['idx']}_b")
res = sp.symbols(f"n_{nd['idx']}_r")
bias = node_dict["bias"]
res = node_dict["res"]
agg = node_dict["agg"]
act = node_dict["act"]
if precision is not None:
bias = round(bias, precision)
res = round(res, precision)
z = convert_to_sympy(agg)(inputs)
z = convert_to_sympy(nd["agg"])(inputs)
z = bias + z * res
if is_output_node:
return z
pass
else:
z = convert_to_sympy(act)(z)
z = convert_to_sympy(nd["act"])(z)
return z
return z, {bias: nd["bias"], res: nd["res"]}