add backend="jax" to sympy module

This commit is contained in:
wls2002
2024-06-13 05:55:33 +08:00
parent b3e442c688
commit 69d73aab73
12 changed files with 254 additions and 167 deletions

View File

@@ -39,5 +39,5 @@ class BaseConnGene(BaseGene):
"out": int(out_idx),
}
def sympy_func(self, state, conn_dict, inputs, precision=None):
def sympy_func(self, state, conn_dict, inputs):
raise NotImplementedError

View File

@@ -1,6 +1,7 @@
import jax.numpy as jnp
import jax.random
import numpy as np
import sympy as sp
from utils import mutate_float
from . import BaseConnGene
@@ -81,12 +82,10 @@ class DefaultConnGene(BaseConnGene):
return {
"in": int(conn[0]),
"out": int(conn[1]),
"weight": float(conn[2]),
"weight": np.array(conn[2], dtype=np.float32),
}
def sympy_func(self, state, conn_dict, inputs, precision=None):
weight = conn_dict["weight"]
if precision is not None:
weight = round(weight, precision)
weight = sp.symbols(f"c_{conn_dict['in']}_{conn_dict['out']}_w")
return inputs * weight
return inputs * weight, {weight: conn_dict["weight"]}