Merge branch 'main' into advance

This commit is contained in:
WLS2002
2024-05-24 19:42:03 +08:00
committed by GitHub
17 changed files with 156 additions and 82 deletions

View File

@@ -95,11 +95,17 @@ class DefaultNodeGene(BaseNodeGene):
(node1[4] != node2[4])
)
def forward(self, attrs, inputs):
def forward(self, attrs, inputs, is_output_node=False):
bias, res, act_idx, agg_idx = attrs
z = agg(agg_idx, inputs, self.aggregation_options)
z = bias + res * z
z = act(act_idx, z, self.activation_options)
# the last output node should not be activated
z = jax.lax.cond(
is_output_node,
lambda: z,
lambda: act(act_idx, z, self.activation_options)
)
return z