add backend="jax" to sympy module
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from utils import (
|
||||
Act,
|
||||
@@ -160,33 +161,36 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx, bias, res, agg, act = node
|
||||
|
||||
idx = int(idx)
|
||||
bias = np.array(bias, dtype=np.float32)
|
||||
res = np.array(res, dtype=np.float32)
|
||||
agg = int(agg)
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
return {
|
||||
"idx": int(idx),
|
||||
"bias": float(bias),
|
||||
"res": float(res),
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"res": res,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": self.activation_options[int(act)].__name__,
|
||||
"act": act_func.__name__,
|
||||
}
|
||||
|
||||
def sympy_func(
|
||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
||||
):
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
nd = node_dict
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
res = sp.symbols(f"n_{nd['idx']}_r")
|
||||
|
||||
bias = node_dict["bias"]
|
||||
res = node_dict["res"]
|
||||
agg = node_dict["agg"]
|
||||
act = node_dict["act"]
|
||||
|
||||
if precision is not None:
|
||||
bias = round(bias, precision)
|
||||
res = round(res, precision)
|
||||
|
||||
z = convert_to_sympy(agg)(inputs)
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
z = bias + z * res
|
||||
|
||||
if is_output_node:
|
||||
return z
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(act)(z)
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
|
||||
return z
|
||||
return z, {bias: nd["bias"], res: nd["res"]}
|
||||
|
||||
Reference in New Issue
Block a user