fix bugs
This commit is contained in:
@@ -5,10 +5,10 @@ import jax, jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from tensorneat.common import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
ACT,
|
||||
AGG,
|
||||
apply_activation,
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
@@ -39,10 +39,10 @@ class DefaultNode(BaseNode):
|
||||
response_lower_bound: float = -5,
|
||||
response_upper_bound: float = 5,
|
||||
aggregation_default: Optional[Callable] = None,
|
||||
aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
|
||||
aggregation_options: Union[Callable, Sequence[Callable]] = AGG.sum,
|
||||
aggregation_replace_rate: float = 0.1,
|
||||
activation_default: Optional[Callable] = None,
|
||||
activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
|
||||
activation_options: Union[Callable, Sequence[Callable]] = ACT.sigmoid,
|
||||
activation_replace_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -89,7 +89,7 @@ class DefaultNode(BaseNode):
|
||||
agg = self.aggregation_default
|
||||
act = self.activation_default
|
||||
|
||||
return jnp.array([bias, res, agg, act]) # activation=-1 means Act.identity
|
||||
return jnp.array([bias, res, agg, act]) # activation=-1 means ACT.identity
|
||||
|
||||
def new_random_attrs(self, state, randkey):
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
@@ -148,12 +148,12 @@ class DefaultNode(BaseNode):
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
bias, res, agg, act = attrs
|
||||
|
||||
z = agg_func(agg, inputs, self.aggregation_options)
|
||||
z = apply_aggregation(agg, inputs, self.aggregation_options)
|
||||
z = bias + res * z
|
||||
|
||||
# the last output node should not be activated
|
||||
z = jax.lax.cond(
|
||||
is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
|
||||
is_output_node, lambda: z, lambda: apply_activation(act, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
@@ -168,7 +168,7 @@ class DefaultNode(BaseNode):
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
act_func = ACT.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, response={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
|
||||
@@ -193,7 +193,7 @@ class DefaultNode(BaseNode):
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
act_func = ACT.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user