add backend="jax" to sympy module
This commit is contained in:
@@ -39,5 +39,5 @@ class BaseConnGene(BaseGene):
|
|||||||
"out": int(out_idx),
|
"out": int(out_idx),
|
||||||
}
|
}
|
||||||
|
|
||||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
def sympy_func(self, state, conn_dict, inputs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax.random
|
import jax.random
|
||||||
|
import numpy as np
|
||||||
|
import sympy as sp
|
||||||
from utils import mutate_float
|
from utils import mutate_float
|
||||||
from . import BaseConnGene
|
from . import BaseConnGene
|
||||||
|
|
||||||
@@ -81,12 +82,10 @@ class DefaultConnGene(BaseConnGene):
|
|||||||
return {
|
return {
|
||||||
"in": int(conn[0]),
|
"in": int(conn[0]),
|
||||||
"out": int(conn[1]),
|
"out": int(conn[1]),
|
||||||
"weight": float(conn[2]),
|
"weight": np.array(conn[2], dtype=np.float32),
|
||||||
}
|
}
|
||||||
|
|
||||||
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
def sympy_func(self, state, conn_dict, inputs, precision=None):
|
||||||
weight = conn_dict["weight"]
|
weight = sp.symbols(f"c_{conn_dict['in']}_{conn_dict['out']}_w")
|
||||||
if precision is not None:
|
|
||||||
weight = round(weight, precision)
|
|
||||||
|
|
||||||
return inputs * weight
|
return inputs * weight, {weight: conn_dict["weight"]}
|
||||||
|
|||||||
@@ -54,5 +54,5 @@ class BaseNodeGene(BaseGene):
|
|||||||
"idx": int(idx),
|
"idx": int(idx),
|
||||||
}
|
}
|
||||||
|
|
||||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False, precision=None):
|
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
|
import sympy as sp
|
||||||
|
|
||||||
from utils import (
|
from utils import (
|
||||||
Act,
|
Act,
|
||||||
@@ -160,33 +161,36 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
|
|
||||||
def to_dict(self, state, node):
|
def to_dict(self, state, node):
|
||||||
idx, bias, res, agg, act = node
|
idx, bias, res, agg, act = node
|
||||||
|
|
||||||
|
idx = int(idx)
|
||||||
|
bias = np.array(bias, dtype=np.float32)
|
||||||
|
res = np.array(res, dtype=np.float32)
|
||||||
|
agg = int(agg)
|
||||||
|
act = int(act)
|
||||||
|
|
||||||
|
if act == -1:
|
||||||
|
act_func = Act.identity
|
||||||
|
else:
|
||||||
|
act_func = self.activation_options[act]
|
||||||
return {
|
return {
|
||||||
"idx": int(idx),
|
"idx": idx,
|
||||||
"bias": float(bias),
|
"bias": bias,
|
||||||
"res": float(res),
|
"res": res,
|
||||||
"agg": self.aggregation_options[int(agg)].__name__,
|
"agg": self.aggregation_options[int(agg)].__name__,
|
||||||
"act": self.activation_options[int(act)].__name__,
|
"act": act_func.__name__,
|
||||||
}
|
}
|
||||||
|
|
||||||
def sympy_func(
|
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
nd = node_dict
|
||||||
):
|
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||||
|
res = sp.symbols(f"n_{nd['idx']}_r")
|
||||||
|
|
||||||
bias = node_dict["bias"]
|
z = convert_to_sympy(nd["agg"])(inputs)
|
||||||
res = node_dict["res"]
|
|
||||||
agg = node_dict["agg"]
|
|
||||||
act = node_dict["act"]
|
|
||||||
|
|
||||||
if precision is not None:
|
|
||||||
bias = round(bias, precision)
|
|
||||||
res = round(res, precision)
|
|
||||||
|
|
||||||
z = convert_to_sympy(agg)(inputs)
|
|
||||||
z = bias + z * res
|
z = bias + z * res
|
||||||
|
|
||||||
if is_output_node:
|
if is_output_node:
|
||||||
return z
|
pass
|
||||||
else:
|
else:
|
||||||
z = convert_to_sympy(act)(z)
|
z = convert_to_sympy(nd["act"])(z)
|
||||||
|
|
||||||
return z
|
return z, {bias: nd["bias"], res: nd["res"]}
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import sympy as sp
|
||||||
from utils import (
|
from utils import (
|
||||||
Act,
|
Act,
|
||||||
Agg,
|
Agg,
|
||||||
@@ -133,30 +134,36 @@ class NodeGeneWithoutResponse(BaseNodeGene):
|
|||||||
|
|
||||||
def to_dict(self, state, node):
|
def to_dict(self, state, node):
|
||||||
idx, bias, agg, act = node
|
idx, bias, agg, act = node
|
||||||
|
|
||||||
|
idx = int(idx)
|
||||||
|
|
||||||
|
bias = np.array(bias, dtype=np.float32)
|
||||||
|
agg = int(agg)
|
||||||
|
act = int(act)
|
||||||
|
|
||||||
|
if act == -1:
|
||||||
|
act_func = Act.identity
|
||||||
|
else:
|
||||||
|
act_func = self.activation_options[act]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"idx": int(idx),
|
"idx": idx,
|
||||||
"bias": float(bias),
|
"bias": bias,
|
||||||
"agg": self.aggregation_options[int(agg)].__name__,
|
"agg": self.aggregation_options[int(agg)].__name__,
|
||||||
"act": self.activation_options[int(act)].__name__,
|
"act": act_func.__name__,
|
||||||
}
|
}
|
||||||
|
|
||||||
def sympy_func(
|
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
nd = node_dict
|
||||||
):
|
|
||||||
|
|
||||||
bias = node_dict["bias"]
|
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||||
agg = node_dict["agg"]
|
|
||||||
act = node_dict["act"]
|
|
||||||
|
|
||||||
if precision is not None:
|
z = convert_to_sympy(nd["agg"])(inputs)
|
||||||
bias = round(bias, precision)
|
|
||||||
|
|
||||||
z = convert_to_sympy(agg)(inputs)
|
|
||||||
z = bias + z
|
z = bias + z
|
||||||
|
|
||||||
if is_output_node:
|
if is_output_node:
|
||||||
return z
|
pass
|
||||||
else:
|
else:
|
||||||
z = convert_to_sympy(act)(z)
|
z = convert_to_sympy(nd["act"])(z)
|
||||||
|
|
||||||
return z
|
return z, {bias: nd["bias"]}
|
||||||
|
|||||||
@@ -164,7 +164,6 @@ class BaseGenome(StatefulBaseClass):
|
|||||||
continue
|
continue
|
||||||
cd = self.conn_gene.to_dict(state, conn)
|
cd = self.conn_gene.to_dict(state, conn)
|
||||||
in_idx, out_idx = cd["in"], cd["out"]
|
in_idx, out_idx = cd["in"], cd["out"]
|
||||||
del cd["in"], cd["out"]
|
|
||||||
conn_dict[(in_idx, out_idx)] = cd
|
conn_dict[(in_idx, out_idx)] = cd
|
||||||
return conn_dict
|
return conn_dict
|
||||||
|
|
||||||
@@ -176,7 +175,6 @@ class BaseGenome(StatefulBaseClass):
|
|||||||
continue
|
continue
|
||||||
nd = self.node_gene.to_dict(state, node)
|
nd = self.node_gene.to_dict(state, node)
|
||||||
idx = nd["idx"]
|
idx = nd["idx"]
|
||||||
del nd["idx"]
|
|
||||||
node_dict[idx] = nd
|
node_dict[idx] = nd
|
||||||
return node_dict
|
return node_dict
|
||||||
|
|
||||||
@@ -192,7 +190,7 @@ class BaseGenome(StatefulBaseClass):
|
|||||||
def get_output_idx(self):
|
def get_output_idx(self):
|
||||||
return self.output_idx.tolist()
|
return self.output_idx.tolist()
|
||||||
|
|
||||||
def sympy_func(self, state, network, precision=3):
|
def sympy_func(self, state, network, sympy_output_transform=None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def visualize(
|
def visualize(
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import warnings
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
@@ -12,7 +13,8 @@ from utils import (
|
|||||||
set_node_attrs,
|
set_node_attrs,
|
||||||
set_conn_attrs,
|
set_conn_attrs,
|
||||||
attach_with_inf,
|
attach_with_inf,
|
||||||
FUNCS_MODULE,
|
SYMPY_FUNCS_MODULE_NP,
|
||||||
|
SYMPY_FUNCS_MODULE_JNP,
|
||||||
)
|
)
|
||||||
from . import BaseGenome
|
from . import BaseGenome
|
||||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||||
@@ -191,7 +193,16 @@ class DefaultGenome(BaseGenome):
|
|||||||
new_transformed,
|
new_transformed,
|
||||||
)
|
)
|
||||||
|
|
||||||
def sympy_func(self, state, network, precision=3):
|
def sympy_func(self, state, network, sympy_output_transform=None, backend="jax"):
|
||||||
|
|
||||||
|
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||||
|
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
|
||||||
|
|
||||||
|
if sympy_output_transform is None and self.output_transform is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"genome.output_transform is not None but sympy_output_transform is None!"
|
||||||
|
)
|
||||||
|
|
||||||
input_idx = self.get_input_idx()
|
input_idx = self.get_input_idx()
|
||||||
output_idx = self.get_output_idx()
|
output_idx = self.get_output_idx()
|
||||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||||
@@ -206,6 +217,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
nodes_exprs = {}
|
nodes_exprs = {}
|
||||||
|
|
||||||
|
args_symbols = {}
|
||||||
for i in order:
|
for i in order:
|
||||||
|
|
||||||
if i in input_idx:
|
if i in input_idx:
|
||||||
@@ -215,20 +227,25 @@ class DefaultGenome(BaseGenome):
|
|||||||
node_inputs = []
|
node_inputs = []
|
||||||
for conn in in_conns:
|
for conn in in_conns:
|
||||||
val_represent = symbols[conn[0]]
|
val_represent = symbols[conn[0]]
|
||||||
val = self.conn_gene.sympy_func(
|
# a_s -> args_symbols
|
||||||
|
val, a_s = self.conn_gene.sympy_func(
|
||||||
state,
|
state,
|
||||||
network["conns"][conn],
|
network["conns"][conn],
|
||||||
val_represent,
|
val_represent,
|
||||||
precision=precision,
|
|
||||||
)
|
)
|
||||||
|
args_symbols.update(a_s)
|
||||||
node_inputs.append(val)
|
node_inputs.append(val)
|
||||||
nodes_exprs[symbols[i]] = self.node_gene.sympy_func(
|
nodes_exprs[symbols[i]], a_s = self.node_gene.sympy_func(
|
||||||
state,
|
state,
|
||||||
network["nodes"][i],
|
network["nodes"][i],
|
||||||
node_inputs,
|
node_inputs,
|
||||||
is_output_node=(i in output_idx),
|
is_output_node=(i in output_idx),
|
||||||
precision=precision,
|
|
||||||
)
|
)
|
||||||
|
args_symbols.update(a_s)
|
||||||
|
if i in output_idx and sympy_output_transform is not None:
|
||||||
|
nodes_exprs[symbols[i]] = sympy_output_transform(
|
||||||
|
nodes_exprs[symbols[i]]
|
||||||
|
)
|
||||||
|
|
||||||
input_symbols = [v for k, v in symbols.items() if k in input_idx]
|
input_symbols = [v for k, v in symbols.items() if k in input_idx]
|
||||||
reduced_exprs = nodes_exprs.copy()
|
reduced_exprs = nodes_exprs.copy()
|
||||||
@@ -236,10 +253,31 @@ class DefaultGenome(BaseGenome):
|
|||||||
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
|
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
|
||||||
|
|
||||||
output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]
|
output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]
|
||||||
|
|
||||||
lambdify_output_funcs = [
|
lambdify_output_funcs = [
|
||||||
sp.lambdify(input_symbols, exprs, modules=["numpy", FUNCS_MODULE])
|
sp.lambdify(
|
||||||
|
input_symbols + list(args_symbols.keys()),
|
||||||
|
exprs,
|
||||||
|
modules=[backend, module],
|
||||||
|
)
|
||||||
for exprs in output_exprs
|
for exprs in output_exprs
|
||||||
]
|
]
|
||||||
forward_func = lambda inputs: [f(*inputs) for f in lambdify_output_funcs]
|
|
||||||
|
|
||||||
return symbols, input_symbols, nodes_exprs, output_exprs, forward_func
|
fixed_args_output_funcs = []
|
||||||
|
for i in range(len(output_idx)):
|
||||||
|
|
||||||
|
def f(inputs, i=i):
|
||||||
|
return lambdify_output_funcs[i](*inputs, *args_symbols.values())
|
||||||
|
|
||||||
|
fixed_args_output_funcs.append(f)
|
||||||
|
|
||||||
|
forward_func = lambda inputs: [f(inputs) for f in fixed_args_output_funcs]
|
||||||
|
|
||||||
|
return (
|
||||||
|
symbols,
|
||||||
|
args_symbols,
|
||||||
|
input_symbols,
|
||||||
|
nodes_exprs,
|
||||||
|
output_exprs,
|
||||||
|
forward_func,
|
||||||
|
)
|
||||||
|
|||||||
@@ -11,13 +11,15 @@
|
|||||||
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
|
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
|
||||||
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
|
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
|
||||||
"from utils.graph import topological_sort_python\n",
|
"from utils.graph import topological_sort_python\n",
|
||||||
"from utils import Act, Agg"
|
"from utils import Act, Agg\n",
|
||||||
|
"\n",
|
||||||
|
"import numpy as np"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false,
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2024-06-12T11:35:46.886073700Z",
|
"end_time": "2024-06-12T21:48:58.065855900Z",
|
||||||
"start_time": "2024-06-12T11:35:46.042288800Z"
|
"start_time": "2024-06-12T21:48:57.292767Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"id": "9531a569d9ecf774"
|
"id": "9531a569d9ecf774"
|
||||||
@@ -29,8 +31,8 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"genome = AdvanceInitialize(\n",
|
"genome = AdvanceInitialize(\n",
|
||||||
" num_inputs=3,\n",
|
" num_inputs=3,\n",
|
||||||
" num_outputs=1,\n",
|
" num_outputs=3,\n",
|
||||||
" hidden_cnt=1,\n",
|
" hidden_cnt=2,\n",
|
||||||
" max_nodes=50,\n",
|
" max_nodes=50,\n",
|
||||||
" max_conns=500,\n",
|
" max_conns=500,\n",
|
||||||
" node_gene=NodeGeneWithoutResponse(\n",
|
" node_gene=NodeGeneWithoutResponse(\n",
|
||||||
@@ -38,7 +40,8 @@
|
|||||||
" aggregation_default=Agg.sum,\n",
|
" aggregation_default=Agg.sum,\n",
|
||||||
" # activation_options=(Act.tanh,),\n",
|
" # activation_options=(Act.tanh,),\n",
|
||||||
" aggregation_options=(Agg.sum,),\n",
|
" aggregation_options=(Agg.sum,),\n",
|
||||||
" )\n",
|
" ),\n",
|
||||||
|
" output_transform=jnp.tanh,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"state = genome.setup()\n",
|
"state = genome.setup()\n",
|
||||||
@@ -51,8 +54,8 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false,
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2024-06-12T11:35:52.274062400Z",
|
"end_time": "2024-06-12T21:49:03.858545Z",
|
||||||
"start_time": "2024-06-12T11:35:46.892042200Z"
|
"start_time": "2024-06-12T21:48:58.071859800Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"id": "4013c9f9d5472eb7"
|
"id": "4013c9f9d5472eb7"
|
||||||
@@ -63,7 +66,7 @@
|
|||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": "[-0.535*sigmoid(0.346*i0 + 0.044*i1 - 0.482*i2 + 0.875) - 0.264]"
|
"text/plain": "{'nodes': {0: {'idx': 0,\n 'bias': array(0.22059791, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 1: {'idx': 1,\n 'bias': array(0.7715081, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 2: {'idx': 2,\n 'bias': array(1.1184921, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 3: {'idx': 3,\n 'bias': array(0.6967973, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 4: {'idx': 4,\n 'bias': array(0.85948837, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 5: {'idx': 5,\n 'bias': array(0.19332138, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 6: {'idx': 6,\n 'bias': array(-0.31763914, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'},\n 7: {'idx': 7,\n 'bias': array(0.05656302, dtype=float32),\n 'agg': 'sum',\n 'act': 'sigmoid'}},\n 'conns': {(0, 6): {'in': 0,\n 'out': 6,\n 'weight': array(1.6676894, dtype=float32)},\n (0, 7): {'in': 0, 'out': 7, 'weight': array(-0.05250553, dtype=float32)},\n (1, 6): {'in': 1, 'out': 6, 'weight': array(0.10137014, dtype=float32)},\n (1, 7): {'in': 1, 'out': 7, 'weight': array(-0.12093307, dtype=float32)},\n (2, 6): {'in': 2, 'out': 6, 'weight': array(-1.8677292, dtype=float32)},\n (2, 7): {'in': 2, 'out': 7, 'weight': array(-0.4195783, dtype=float32)},\n (6, 3): {'in': 6, 'out': 3, 'weight': array(1.2615877, dtype=float32)},\n (6, 4): {'in': 6, 'out': 4, 'weight': array(-0.27593768, dtype=float32)},\n (6, 5): {'in': 6, 'out': 5, 'weight': array(-0.5819819, dtype=float32)},\n (7, 3): {'in': 7, 'out': 3, 'weight': array(0.59301573, dtype=float32)},\n (7, 4): {'in': 7, 'out': 4, 'weight': array(0.19493186, dtype=float32)},\n (7, 5): {'in': 7, 'out': 5, 'weight': array(0.18183969, dtype=float32)}}}"
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -71,89 +74,72 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import sympy as sp\n",
|
"network"
|
||||||
"\n",
|
|
||||||
"symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network, precision=3, )\n",
|
|
||||||
"output_exprs"
|
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false,
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2024-06-12T11:35:52.325161800Z",
|
"end_time": "2024-06-12T21:49:03.873543600Z",
|
||||||
"start_time": "2024-06-12T11:35:52.282008300Z"
|
"start_time": "2024-06-12T21:49:03.867543Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "188006cebb04847"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import sympy as sp\n",
|
||||||
|
"\n",
|
||||||
|
"# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network)\n",
|
||||||
|
"symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh)\n",
|
||||||
|
"symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, np_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh, backend='numpy')\n"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-06-12T21:50:37.527882500Z",
|
||||||
|
"start_time": "2024-06-12T21:50:37.518559400Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"id": "addea793fc002900"
|
"id": "addea793fc002900"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 12,
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"- 0.535 \\mathrm{sigmoid}\\left(0.346 i_{0} + 0.044 i_{1} - 0.482 i_{2} + 0.875\\right) - 0.264\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"print(sp.latex(output_exprs[0]))"
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"ExecuteTime": {
|
|
||||||
"end_time": "2024-06-12T11:35:52.341639700Z",
|
|
||||||
"start_time": "2024-06-12T11:35:52.323163700Z"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"id": "967cb87e24373f77"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"source": [],
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false
|
|
||||||
},
|
|
||||||
"id": "88eee4db9eb857cd"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": "[-0.7940936986556304]"
|
"text/plain": "(array([1.0719017 , 0.09353136, 0.22664611], dtype=float32), dtype('float32'))"
|
||||||
},
|
},
|
||||||
"execution_count": 5,
|
"execution_count": 12,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import numpy as np\n",
|
"random_inputs = np.random.randn(3).astype(np.float32)\n",
|
||||||
"random_inputs = np.random.randn(3)\n",
|
"random_inputs, random_inputs.dtype"
|
||||||
"res = forward_func(random_inputs)\n",
|
|
||||||
"res "
|
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false,
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2024-06-12T11:35:52.342638Z",
|
"end_time": "2024-06-12T21:50:38.178769100Z",
|
||||||
"start_time": "2024-06-12T11:35:52.330160600Z"
|
"start_time": "2024-06-12T21:50:38.155744Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"id": "c5581201d990ba1c"
|
"id": "3aa7c874f3a5743f"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 13,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": "Array([-0.7934886], dtype=float32, weak_type=True)"
|
"text/plain": "Array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32, weak_type=True)"
|
||||||
},
|
},
|
||||||
"execution_count": 6,
|
"execution_count": 13,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -166,25 +152,89 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false,
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2024-06-12T11:35:53.273851900Z",
|
"end_time": "2024-06-12T21:50:48.747287900Z",
|
||||||
"start_time": "2024-06-12T11:35:52.384588600Z"
|
"start_time": "2024-06-12T21:50:48.560675400Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"id": "fe3449a5bc688bc3"
|
"id": "fe3449a5bc688bc3"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 14,
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
"source": [],
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "(array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32),\n array([ 0.9743453 , 0.57646036, -0.3080282 ], dtype=float32),\n array([ 0.9743453, 0.5764604, -0.3080282], dtype=float32))"
|
||||||
|
},
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"res1 = np.array(jax_forward_func(random_inputs), dtype=np.float32)\n",
|
||||||
|
"res2 = np.array(np_forward_func(random_inputs), dtype=np.float32)\n",
|
||||||
|
"res = np.array(genome.forward(state, transformed, random_inputs))\n",
|
||||||
|
"res1, res2, res"
|
||||||
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false,
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2024-06-12T11:35:53.274854100Z",
|
"end_time": "2024-06-12T21:51:15.098948600Z",
|
||||||
"start_time": "2024-06-12T11:35:53.265856700Z"
|
"start_time": "2024-06-12T21:51:14.908948500Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"id": "174c7dc3d9499f95"
|
"id": "a874d434509f1092"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "(array([ True, True, True]), array([ True, False, True]))"
|
||||||
|
},
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"res1 == res, res2 == res"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-06-12T21:51:25.857465200Z",
|
||||||
|
"start_time": "2024-06-12T21:51:25.851465300Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "d226e5bd6e2d44d6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "array([False, False, True])"
|
||||||
|
},
|
||||||
|
"execution_count": 23,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"np.floor(res1 * 10000000) / 10000000 == np.floor(res2 * 10000000) / 10000000"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-06-12T21:00:19.851215800Z",
|
||||||
|
"start_time": "2024-06-12T21:00:19.836443700Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "2a36ce6afc59ee8a"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
@@ -28,3 +28,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
print(genome.repr(state, nodes, conns))
|
print(genome.repr(state, nodes, conns))
|
||||||
print(network)
|
print(network)
|
||||||
|
|
||||||
|
res = genome.sympy_func(state, network, precision=3)
|
||||||
|
print(res)
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||||
from .tools import *
|
from .tools import *
|
||||||
from .graph import *
|
from .graph import *
|
||||||
@@ -29,6 +31,7 @@ name2sympy = {
|
|||||||
"min": SympyMin,
|
"min": SympyMin,
|
||||||
"maxabs": SympyMaxabs,
|
"maxabs": SympyMaxabs,
|
||||||
"mean": SympyMean,
|
"mean": SympyMean,
|
||||||
|
"clip": SympyClip,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -45,7 +48,9 @@ def convert_to_sympy(func: Union[str, callable]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
FUNCS_MODULE = {}
|
SYMPY_FUNCS_MODULE_NP = {}
|
||||||
|
SYMPY_FUNCS_MODULE_JNP = {}
|
||||||
for cls in name2sympy.values():
|
for cls in name2sympy.values():
|
||||||
if hasattr(cls, "numerical_eval"):
|
if hasattr(cls, "numerical_eval"):
|
||||||
FUNCS_MODULE[cls.__name__] = cls.numerical_eval
|
SYMPY_FUNCS_MODULE_NP[cls.__name__] = cls.numerical_eval
|
||||||
|
SYMPY_FUNCS_MODULE_JNP[cls.__name__] = partial(cls.numerical_eval, backend=jnp)
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from typing import Union
|
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -13,8 +12,8 @@ class SympyClip(sp.Function):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(val, min_val, max_val):
|
def numerical_eval(val, min_val, max_val, backend=np):
|
||||||
return np.clip(val, min_val, max_val)
|
return backend.clip(val, min_val, max_val)
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
|
return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"
|
||||||
@@ -32,9 +31,9 @@ class SympySigmoid(sp.Function):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z):
|
def numerical_eval(z, backend=np):
|
||||||
z = np.clip(5 * z, -10, 10)
|
z = backend.clip(5 * z, -10, 10)
|
||||||
return 1 / (1 + np.exp(-z))
|
return 1 / (1 + backend.exp(-z))
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
return f"sigmoid({self.args[0]})"
|
return f"sigmoid({self.args[0]})"
|
||||||
@@ -49,8 +48,8 @@ class SympyTanh(sp.Function):
|
|||||||
return sp.tanh(0.6 * z)
|
return sp.tanh(0.6 * z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z):
|
def numerical_eval(z, backend=np):
|
||||||
return np.tanh(0.6 * z)
|
return backend.tanh(0.6 * z)
|
||||||
|
|
||||||
|
|
||||||
class SympySin(sp.Function):
|
class SympySin(sp.Function):
|
||||||
@@ -59,8 +58,8 @@ class SympySin(sp.Function):
|
|||||||
return sp.sin(z)
|
return sp.sin(z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z):
|
def numerical_eval(z, backend=np):
|
||||||
return np.sin(z)
|
return backend.sin(z)
|
||||||
|
|
||||||
|
|
||||||
class SympyRelu(sp.Function):
|
class SympyRelu(sp.Function):
|
||||||
@@ -71,8 +70,8 @@ class SympyRelu(sp.Function):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z):
|
def numerical_eval(z, backend=np):
|
||||||
return np.maximum(z, 0)
|
return backend.maximum(z, 0)
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
return f"relu({self.args[0]})"
|
return f"relu({self.args[0]})"
|
||||||
@@ -90,9 +89,9 @@ class SympyLelu(sp.Function):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z):
|
def numerical_eval(z, backend=np):
|
||||||
leaky = 0.005
|
leaky = 0.005
|
||||||
return np.maximum(z, leaky * z)
|
return backend.maximum(z, leaky * z)
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
return f"lelu({self.args[0]})"
|
return f"lelu({self.args[0]})"
|
||||||
@@ -107,7 +106,7 @@ class SympyIdentity(sp.Function):
|
|||||||
return z
|
return z
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z):
|
def numerical_eval(z, backend=np):
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
@@ -117,8 +116,8 @@ class SympyClamped(sp.Function):
|
|||||||
return SympyClip(z, -1, 1)
|
return SympyClip(z, -1, 1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z):
|
def numerical_eval(z, backend=np):
|
||||||
return np.clip(z, -1, 1)
|
return backend.clip(z, -1, 1)
|
||||||
|
|
||||||
|
|
||||||
class SympyInv(sp.Function):
|
class SympyInv(sp.Function):
|
||||||
@@ -130,8 +129,8 @@ class SympyInv(sp.Function):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z):
|
def numerical_eval(z, backend=np):
|
||||||
z = np.maximum(z, 1e-7)
|
z = backend.maximum(z, 1e-7)
|
||||||
return 1 / z
|
return 1 / z
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
@@ -150,9 +149,9 @@ class SympyLog(sp.Function):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z):
|
def numerical_eval(z, backend=np):
|
||||||
z = np.maximum(z, 1e-7)
|
z = backend.maximum(z, 1e-7)
|
||||||
return np.log(z)
|
return backend.log(z)
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
return f"log({self.args[0]})"
|
return f"log({self.args[0]})"
|
||||||
@@ -169,11 +168,6 @@ class SympyExp(sp.Function):
|
|||||||
return sp.exp(z)
|
return sp.exp(z)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(z):
|
|
||||||
z = np.clip(z, -10, 10)
|
|
||||||
return np.exp(z)
|
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
return f"exp({self.args[0]})"
|
return f"exp({self.args[0]})"
|
||||||
|
|
||||||
@@ -185,7 +179,3 @@ class SympyAbs(sp.Function):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, z):
|
def eval(cls, z):
|
||||||
return sp.Abs(z)
|
return sp.Abs(z)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(z):
|
|
||||||
return np.abs(z)
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import numpy as np
|
||||||
import sympy as sp
|
import sympy as sp
|
||||||
|
|
||||||
|
|
||||||
@@ -51,15 +52,6 @@ class SympyMedian(sp.Function):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def numerical_eval(args):
|
|
||||||
sorted_args = sorted(args)
|
|
||||||
n = len(sorted_args)
|
|
||||||
if n % 2 == 1:
|
|
||||||
return sorted_args[n // 2]
|
|
||||||
else:
|
|
||||||
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
|
|
||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
return f"median({', '.join(map(str, self.args))})"
|
return f"median({', '.join(map(str, self.args))})"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user