Merge branch 'main' into advance
This commit is contained in:
@@ -95,11 +95,17 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
(node1[4] != node2[4])
|
||||
)
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
bias, res, act_idx, agg_idx = attrs
|
||||
|
||||
z = agg(agg_idx, inputs, self.aggregation_options)
|
||||
z = bias + res * z
|
||||
z = act(act_idx, z, self.activation_options)
|
||||
|
||||
# the last output node should not be activated
|
||||
z = jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: z,
|
||||
lambda: act(act_idx, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
|
||||
Reference in New Issue
Block a user