add sympy support; which can transfer your network into sympy expression;

add visualize in genome;
add related tests.
This commit is contained in:
wls2002
2024-06-12 21:36:35 +08:00
parent dfc8f9198e
commit b3e442c688
29 changed files with 6196 additions and 168 deletions

View File

@@ -47,3 +47,12 @@ class BaseNodeGene(BaseGene):
return "{}(idx={:<{idx_width}})".format(
self.__class__.__name__, idx, idx_width=idx_width
)
def to_dict(self, state, node):
idx = node[0]
return {
"idx": int(idx),
}
def sympy_func(self, state, node_dict, inputs, is_output_node=False, precision=None):
raise NotImplementedError

View File

@@ -1,8 +1,18 @@
from typing import Tuple
import numpy as np
import jax, jax.numpy as jnp
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
from utils import (
Act,
Agg,
act_func,
agg_func,
mutate_int,
mutate_float,
convert_to_sympy,
)
from . import BaseNodeGene
@@ -45,12 +55,12 @@ class DefaultNodeGene(BaseNodeGene):
self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_indices = np.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
self.activation_default = activation_options.index(activation_default)
self.activation_options = activation_options
self.activation_indices = jnp.arange(len(activation_options))
self.activation_indices = np.arange(len(activation_options))
self.activation_replace_rate = activation_replace_rate
def new_identity_attrs(self, state):
@@ -145,5 +155,38 @@ class DefaultNodeGene(BaseNodeGene):
act_func.__name__,
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width
func_width=func_width,
)
def to_dict(self, state, node):
idx, bias, res, agg, act = node
return {
"idx": int(idx),
"bias": float(bias),
"res": float(res),
"agg": self.aggregation_options[int(agg)].__name__,
"act": self.activation_options[int(act)].__name__,
}
def sympy_func(
self, state, node_dict, inputs, is_output_node=False, precision=None
):
bias = node_dict["bias"]
res = node_dict["res"]
agg = node_dict["agg"]
act = node_dict["act"]
if precision is not None:
bias = round(bias, precision)
res = round(res, precision)
z = convert_to_sympy(agg)(inputs)
z = bias + z * res
if is_output_node:
return z
else:
z = convert_to_sympy(act)(z)
return z

View File

@@ -2,7 +2,16 @@ from typing import Tuple
import jax, jax.numpy as jnp
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
from utils import (
Act,
Agg,
act_func,
agg_func,
mutate_int,
mutate_float,
convert_to_sympy,
)
from . import BaseNodeGene
@@ -121,3 +130,33 @@ class NodeGeneWithoutResponse(BaseNodeGene):
float_width=precision + 3,
func_width=func_width,
)
def to_dict(self, state, node):
idx, bias, agg, act = node
return {
"idx": int(idx),
"bias": float(bias),
"agg": self.aggregation_options[int(agg)].__name__,
"act": self.activation_options[int(act)].__name__,
}
def sympy_func(
self, state, node_dict, inputs, is_output_node=False, precision=None
):
bias = node_dict["bias"]
agg = node_dict["agg"]
act = node_dict["act"]
if precision is not None:
bias = round(bias, precision)
z = convert_to_sympy(agg)(inputs)
z = bias + z
if is_output_node:
return z
else:
z = convert_to_sympy(act)(z)
return z

View File

@@ -25,8 +25,3 @@ class KANNode(BaseNodeGene):
def forward(self, state, attrs, inputs, is_output_node=False):
return Agg.sum(inputs)
def repr(self, state, node, precision=2):
idx = node[0]
idx = int(idx)
return "{}(idx: {})".format(self.__class__.__name__, idx)