add sympy support; which can transfer your network into sympy expression;
add visualize in genome; add related tests.
This commit is contained in:
@@ -47,3 +47,12 @@ class BaseNodeGene(BaseGene):
|
||||
return "{}(idx={:<{idx_width}})".format(
|
||||
self.__class__.__name__, idx, idx_width=idx_width
|
||||
)
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx = node[0]
|
||||
return {
|
||||
"idx": int(idx),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False, precision=None):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,8 +1,18 @@
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
|
||||
from utils import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
)
|
||||
|
||||
from . import BaseNodeGene
|
||||
|
||||
|
||||
@@ -45,12 +55,12 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
|
||||
self.aggregation_default = aggregation_options.index(aggregation_default)
|
||||
self.aggregation_options = aggregation_options
|
||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||
self.aggregation_indices = np.arange(len(aggregation_options))
|
||||
self.aggregation_replace_rate = aggregation_replace_rate
|
||||
|
||||
self.activation_default = activation_options.index(activation_default)
|
||||
self.activation_options = activation_options
|
||||
self.activation_indices = jnp.arange(len(activation_options))
|
||||
self.activation_indices = np.arange(len(activation_options))
|
||||
self.activation_replace_rate = activation_replace_rate
|
||||
|
||||
def new_identity_attrs(self, state):
|
||||
@@ -145,5 +155,38 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
act_func.__name__,
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width
|
||||
func_width=func_width,
|
||||
)
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx, bias, res, agg, act = node
|
||||
return {
|
||||
"idx": int(idx),
|
||||
"bias": float(bias),
|
||||
"res": float(res),
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": self.activation_options[int(act)].__name__,
|
||||
}
|
||||
|
||||
def sympy_func(
|
||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
||||
):
|
||||
|
||||
bias = node_dict["bias"]
|
||||
res = node_dict["res"]
|
||||
agg = node_dict["agg"]
|
||||
act = node_dict["act"]
|
||||
|
||||
if precision is not None:
|
||||
bias = round(bias, precision)
|
||||
res = round(res, precision)
|
||||
|
||||
z = convert_to_sympy(agg)(inputs)
|
||||
z = bias + z * res
|
||||
|
||||
if is_output_node:
|
||||
return z
|
||||
else:
|
||||
z = convert_to_sympy(act)(z)
|
||||
|
||||
return z
|
||||
|
||||
@@ -2,7 +2,16 @@ from typing import Tuple
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
|
||||
from utils import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
)
|
||||
|
||||
from . import BaseNodeGene
|
||||
|
||||
|
||||
@@ -121,3 +130,33 @@ class NodeGeneWithoutResponse(BaseNodeGene):
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
)
|
||||
|
||||
def to_dict(self, state, node):
|
||||
idx, bias, agg, act = node
|
||||
return {
|
||||
"idx": int(idx),
|
||||
"bias": float(bias),
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": self.activation_options[int(act)].__name__,
|
||||
}
|
||||
|
||||
def sympy_func(
|
||||
self, state, node_dict, inputs, is_output_node=False, precision=None
|
||||
):
|
||||
|
||||
bias = node_dict["bias"]
|
||||
agg = node_dict["agg"]
|
||||
act = node_dict["act"]
|
||||
|
||||
if precision is not None:
|
||||
bias = round(bias, precision)
|
||||
|
||||
z = convert_to_sympy(agg)(inputs)
|
||||
z = bias + z
|
||||
|
||||
if is_output_node:
|
||||
return z
|
||||
else:
|
||||
z = convert_to_sympy(act)(z)
|
||||
|
||||
return z
|
||||
|
||||
@@ -25,8 +25,3 @@ class KANNode(BaseNodeGene):
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
return Agg.sum(inputs)
|
||||
|
||||
def repr(self, state, node, precision=2):
|
||||
idx = node[0]
|
||||
idx = int(idx)
|
||||
return "{}(idx: {})".format(self.__class__.__name__, idx)
|
||||
|
||||
Reference in New Issue
Block a user