add backend="jax" to sympy module
This commit is contained in:
@@ -39,5 +39,5 @@ class BaseConnGene(BaseGene):
|
||||
"out": int(out_idx),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
||||
def sympy_func(self, state, conn_dict, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
import jax.random
|
||||
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
from utils import mutate_float
|
||||
from . import BaseConnGene
|
||||
|
||||
@@ -81,12 +82,10 @@ class DefaultConnGene(BaseConnGene):
|
||||
return {
|
||||
"in": int(conn[0]),
|
||||
"out": int(conn[1]),
|
||||
"weight": float(conn[2]),
|
||||
"weight": np.array(conn[2], dtype=np.float32),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
||||
weight = conn_dict["weight"]
|
||||
if precision is not None:
|
||||
weight = round(weight, precision)
|
||||
weight = sp.symbols(f"c_{conn_dict['in']}_{conn_dict['out']}_w")
|
||||
|
||||
return inputs * weight
|
||||
return inputs * weight, {weight: conn_dict["weight"]}
|
||||
|
||||
@@ -54,5 +54,5 @@ class BaseNodeGene(BaseGene):
|
||||
"idx": int(idx),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False, precision=None):
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from utils import (
|
||||
Act,
|
||||
@@ -160,33 +161,36 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx, bias, res, agg, act = node
|
||||
|
||||
idx = int(idx)
|
||||
bias = np.array(bias, dtype=np.float32)
|
||||
res = np.array(res, dtype=np.float32)
|
||||
agg = int(agg)
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
return {
|
||||
"idx": int(idx),
|
||||
"bias": float(bias),
|
||||
"res": float(res),
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"res": res,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": self.activation_options[int(act)].__name__,
|
||||
"act": act_func.__name__,
|
||||
}
|
||||
|
||||
def sympy_func(
|
||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
||||
):
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
nd = node_dict
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
res = sp.symbols(f"n_{nd['idx']}_r")
|
||||
|
||||
bias = node_dict["bias"]
|
||||
res = node_dict["res"]
|
||||
agg = node_dict["agg"]
|
||||
act = node_dict["act"]
|
||||
|
||||
if precision is not None:
|
||||
bias = round(bias, precision)
|
||||
res = round(res, precision)
|
||||
|
||||
z = convert_to_sympy(agg)(inputs)
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
z = bias + z * res
|
||||
|
||||
if is_output_node:
|
||||
return z
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(act)(z)
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
|
||||
return z
|
||||
return z, {bias: nd["bias"], res: nd["res"]}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Tuple
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
from utils import (
|
||||
Act,
|
||||
Agg,
|
||||
@@ -133,30 +134,36 @@ class NodeGeneWithoutResponse(BaseNodeGene):
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx, bias, agg, act = node
|
||||
|
||||
idx = int(idx)
|
||||
|
||||
bias = np.array(bias, dtype=np.float32)
|
||||
agg = int(agg)
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
|
||||
return {
|
||||
"idx": int(idx),
|
||||
"bias": float(bias),
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": self.activation_options[int(act)].__name__,
|
||||
"act": act_func.__name__,
|
||||
}
|
||||
|
||||
def sympy_func(
|
||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
||||
):
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
nd = node_dict
|
||||
|
||||
bias = node_dict["bias"]
|
||||
agg = node_dict["agg"]
|
||||
act = node_dict["act"]
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
|
||||
if precision is not None:
|
||||
bias = round(bias, precision)
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
|
||||
z = convert_to_sympy(agg)(inputs)
|
||||
z = bias + z
|
||||
|
||||
if is_output_node:
|
||||
return z
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(act)(z)
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
|
||||
return z
|
||||
return z, {bias: nd["bias"]}
|
||||
|
||||
@@ -164,7 +164,6 @@ class BaseGenome(StatefulBaseClass):
|
||||
continue
|
||||
cd = self.conn_gene.to_dict(state, conn)
|
||||
in_idx, out_idx = cd["in"], cd["out"]
|
||||
del cd["in"], cd["out"]
|
||||
conn_dict[(in_idx, out_idx)] = cd
|
||||
return conn_dict
|
||||
|
||||
@@ -176,7 +175,6 @@ class BaseGenome(StatefulBaseClass):
|
||||
continue
|
||||
nd = self.node_gene.to_dict(state, node)
|
||||
idx = nd["idx"]
|
||||
del nd["idx"]
|
||||
node_dict[idx] = nd
|
||||
return node_dict
|
||||
|
||||
@@ -192,7 +190,7 @@ class BaseGenome(StatefulBaseClass):
|
||||
def get_output_idx(self):
|
||||
return self.output_idx.tolist()
|
||||
|
||||
def sympy_func(self, state, network, precision=3):
|
||||
def sympy_func(self, state, network, sympy_output_transform=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def visualize(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
@@ -12,7 +13,8 @@ from utils import (
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
attach_with_inf,
|
||||
FUNCS_MODULE,
|
||||
SYMPY_FUNCS_MODULE_NP,
|
||||
SYMPY_FUNCS_MODULE_JNP,
|
||||
)
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
@@ -191,7 +193,16 @@ class DefaultGenome(BaseGenome):
|
||||
new_transformed,
|
||||
)
|
||||
|
||||
def sympy_func(self, state, network, precision=3):
|
||||
def sympy_func(self, state, network, sympy_output_transform=None, backend="jax"):
|
||||
|
||||
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
|
||||
|
||||
if sympy_output_transform is None and self.output_transform is not None:
|
||||
warnings.warn(
|
||||
"genome.output_transform is not None but sympy_output_transform is None!"
|
||||
)
|
||||
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
@@ -206,6 +217,7 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
nodes_exprs = {}
|
||||
|
||||
args_symbols = {}
|
||||
for i in order:
|
||||
|
||||
if i in input_idx:
|
||||
@@ -215,20 +227,25 @@ class DefaultGenome(BaseGenome):
|
||||
node_inputs = []
|
||||
for conn in in_conns:
|
||||
val_represent = symbols[conn[0]]
|
||||
val = self.conn_gene.sympy_func(
|
||||
# a_s -> args_symbols
|
||||
val, a_s = self.conn_gene.sympy_func(
|
||||
state,
|
||||
network["conns"][conn],
|
||||
val_represent,
|
||||
precision=precision,
|
||||
)
|
||||
args_symbols.update(a_s)
|
||||
node_inputs.append(val)
|
||||
nodes_exprs[symbols[i]] = self.node_gene.sympy_func(
|
||||
nodes_exprs[symbols[i]], a_s = self.node_gene.sympy_func(
|
||||
state,
|
||||
network["nodes"][i],
|
||||
node_inputs,
|
||||
is_output_node=(i in output_idx),
|
||||
precision=precision,
|
||||
)
|
||||
args_symbols.update(a_s)
|
||||
if i in output_idx and sympy_output_transform is not None:
|
||||
nodes_exprs[symbols[i]] = sympy_output_transform(
|
||||
nodes_exprs[symbols[i]]
|
||||
)
|
||||
|
||||
input_symbols = [v for k, v in symbols.items() if k in input_idx]
|
||||
reduced_exprs = nodes_exprs.copy()
|
||||
@@ -236,10 +253,31 @@ class DefaultGenome(BaseGenome):
|
||||
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
|
||||
|
||||
output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]
|
||||
|
||||
lambdify_output_funcs = [
|
||||
sp.lambdify(input_symbols, exprs, modules=["numpy", FUNCS_MODULE])
|
||||
sp.lambdify(
|
||||
input_symbols + list(args_symbols.keys()),
|
||||
exprs,
|
||||
modules=[backend, module],
|
||||
)
|
||||
for exprs in output_exprs
|
||||
]
|
||||
forward_func = lambda inputs: [f(*inputs) for f in lambdify_output_funcs]
|
||||
|
||||
return symbols, input_symbols, nodes_exprs, output_exprs, forward_func
|
||||
fixed_args_output_funcs = []
|
||||
for i in range(len(output_idx)):
|
||||
|
||||
def f(inputs, i=i):
|
||||
return lambdify_output_funcs[i](*inputs, *args_symbols.values())
|
||||
|
||||
fixed_args_output_funcs.append(f)
|
||||
|
||||
forward_func = lambda inputs: [f(inputs) for f in fixed_args_output_funcs]
|
||||
|
||||
return (
|
||||
symbols,
|
||||
args_symbols,
|
||||
input_symbols,
|
||||
nodes_exprs,
|
||||
output_exprs,
|
||||
forward_func,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user