This commit is contained in:
root
2024-07-12 02:25:57 +08:00
parent 3194678a15
commit 5fc63fdaf1
28 changed files with 351 additions and 142 deletions

View File

@@ -5,10 +5,10 @@ import jax, jax.numpy as jnp
import sympy as sp
from tensorneat.common import (
Act,
Agg,
act_func,
agg_func,
ACT,
AGG,
apply_activation,
apply_aggregation,
mutate_int,
mutate_float,
convert_to_sympy,
@@ -39,10 +39,10 @@ class DefaultNode(BaseNode):
response_lower_bound: float = -5,
response_upper_bound: float = 5,
aggregation_default: Optional[Callable] = None,
aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
aggregation_options: Union[Callable, Sequence[Callable]] = AGG.sum,
aggregation_replace_rate: float = 0.1,
activation_default: Optional[Callable] = None,
activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
activation_options: Union[Callable, Sequence[Callable]] = ACT.sigmoid,
activation_replace_rate: float = 0.1,
):
super().__init__()
@@ -89,7 +89,7 @@ class DefaultNode(BaseNode):
agg = self.aggregation_default
act = self.activation_default
return jnp.array([bias, res, agg, act]) # activation=-1 means Act.identity
return jnp.array([bias, res, agg, act]) # activation=-1 means ACT.identity
def new_random_attrs(self, state, randkey):
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
@@ -148,12 +148,12 @@ class DefaultNode(BaseNode):
def forward(self, state, attrs, inputs, is_output_node=False):
bias, res, agg, act = attrs
z = agg_func(agg, inputs, self.aggregation_options)
z = apply_aggregation(agg, inputs, self.aggregation_options)
z = bias + res * z
# the last output node should not be activated
z = jax.lax.cond(
is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
is_output_node, lambda: z, lambda: apply_activation(act, z, self.activation_options)
)
return z
@@ -168,7 +168,7 @@ class DefaultNode(BaseNode):
act = int(act)
if act == -1:
act_func = Act.identity
act_func = ACT.identity
else:
act_func = self.activation_options[act]
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, response={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
@@ -193,7 +193,7 @@ class DefaultNode(BaseNode):
act = int(act)
if act == -1:
act_func = Act.identity
act_func = ACT.identity
else:
act_func = self.activation_options[act]
return {