add backend="jax" to sympy module

This commit is contained in:
wls2002
2024-06-13 05:55:33 +08:00
parent b3e442c688
commit 69d73aab73
12 changed files with 254 additions and 167 deletions

View File

@@ -164,7 +164,6 @@ class BaseGenome(StatefulBaseClass):
continue
cd = self.conn_gene.to_dict(state, conn)
in_idx, out_idx = cd["in"], cd["out"]
del cd["in"], cd["out"]
conn_dict[(in_idx, out_idx)] = cd
return conn_dict
@@ -176,7 +175,6 @@ class BaseGenome(StatefulBaseClass):
continue
nd = self.node_gene.to_dict(state, node)
idx = nd["idx"]
del nd["idx"]
node_dict[idx] = nd
return node_dict
@@ -192,7 +190,7 @@ class BaseGenome(StatefulBaseClass):
def get_output_idx(self):
return self.output_idx.tolist()
def sympy_func(self, state, network, precision=3):
def sympy_func(self, state, network, sympy_output_transform=None):
raise NotImplementedError
def visualize(

View File

@@ -1,3 +1,4 @@
import warnings
from typing import Callable
import jax, jax.numpy as jnp
@@ -12,7 +13,8 @@ from utils import (
set_node_attrs,
set_conn_attrs,
attach_with_inf,
FUNCS_MODULE,
SYMPY_FUNCS_MODULE_NP,
SYMPY_FUNCS_MODULE_JNP,
)
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
@@ -191,7 +193,16 @@ class DefaultGenome(BaseGenome):
new_transformed,
)
def sympy_func(self, state, network, precision=3):
def sympy_func(self, state, network, sympy_output_transform=None, backend="jax"):
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
if sympy_output_transform is None and self.output_transform is not None:
warnings.warn(
"genome.output_transform is not None but sympy_output_transform is None!"
)
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
@@ -206,6 +217,7 @@ class DefaultGenome(BaseGenome):
nodes_exprs = {}
args_symbols = {}
for i in order:
if i in input_idx:
@@ -215,20 +227,25 @@ class DefaultGenome(BaseGenome):
node_inputs = []
for conn in in_conns:
val_represent = symbols[conn[0]]
val = self.conn_gene.sympy_func(
# a_s -> args_symbols
val, a_s = self.conn_gene.sympy_func(
state,
network["conns"][conn],
val_represent,
precision=precision,
)
args_symbols.update(a_s)
node_inputs.append(val)
nodes_exprs[symbols[i]] = self.node_gene.sympy_func(
nodes_exprs[symbols[i]], a_s = self.node_gene.sympy_func(
state,
network["nodes"][i],
node_inputs,
is_output_node=(i in output_idx),
precision=precision,
)
args_symbols.update(a_s)
if i in output_idx and sympy_output_transform is not None:
nodes_exprs[symbols[i]] = sympy_output_transform(
nodes_exprs[symbols[i]]
)
input_symbols = [v for k, v in symbols.items() if k in input_idx]
reduced_exprs = nodes_exprs.copy()
@@ -236,10 +253,31 @@ class DefaultGenome(BaseGenome):
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]
lambdify_output_funcs = [
sp.lambdify(input_symbols, exprs, modules=["numpy", FUNCS_MODULE])
sp.lambdify(
input_symbols + list(args_symbols.keys()),
exprs,
modules=[backend, module],
)
for exprs in output_exprs
]
forward_func = lambda inputs: [f(*inputs) for f in lambdify_output_funcs]
return symbols, input_symbols, nodes_exprs, output_exprs, forward_func
fixed_args_output_funcs = []
for i in range(len(output_idx)):
def f(inputs, i=i):
return lambdify_output_funcs[i](*inputs, *args_symbols.values())
fixed_args_output_funcs.append(f)
forward_func = lambda inputs: [f(inputs) for f in fixed_args_output_funcs]
return (
symbols,
args_symbols,
input_symbols,
nodes_exprs,
output_exprs,
forward_func,
)