add backend="jax" to sympy module
This commit is contained in:
@@ -39,5 +39,5 @@ class BaseConnGene(BaseGene):
|
||||
"out": int(out_idx),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
||||
def sympy_func(self, state, conn_dict, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
import jax.random
|
||||
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
from utils import mutate_float
|
||||
from . import BaseConnGene
|
||||
|
||||
@@ -81,12 +82,10 @@ class DefaultConnGene(BaseConnGene):
|
||||
return {
|
||||
"in": int(conn[0]),
|
||||
"out": int(conn[1]),
|
||||
"weight": float(conn[2]),
|
||||
"weight": np.array(conn[2], dtype=np.float32),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
||||
weight = conn_dict["weight"]
|
||||
if precision is not None:
|
||||
weight = round(weight, precision)
|
||||
weight = sp.symbols(f"c_{conn_dict['in']}_{conn_dict['out']}_w")
|
||||
|
||||
return inputs * weight
|
||||
return inputs * weight, {weight: conn_dict["weight"]}
|
||||
|
||||
Reference in New Issue
Block a user