disable activation in the output node of network;
we recommend to use output_transform; change hyperparameters (strong) in XOR example;
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import State, Act, Agg
|
||||
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = Act.sigmoid,
|
||||
):
|
||||
assert substrate.query_coors.shape[1] == neat.num_inputs, \
|
||||
"Substrate input size should be equal to NEAT input size"
|
||||
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
node_gene=HyperNodeGene(activation, aggregation),
|
||||
conn_gene=HyperNEATConnGene(),
|
||||
activate_time=activate_time,
|
||||
output_transform=output_transform
|
||||
)
|
||||
|
||||
def setup(self, randkey):
|
||||
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
|
||||
self.activation = activation
|
||||
self.aggregation = aggregation
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
return self.activation(
|
||||
self.aggregation(inputs)
|
||||
)
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
return jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: self.aggregation(inputs), # output node does not need activation
|
||||
lambda: self.activation(self.aggregation(inputs))
|
||||
|
||||
)
|
||||
|
||||
class HyperNEATConnGene(BaseConnGene):
|
||||
custom_attrs = ['weight']
|
||||
|
||||
Reference in New Issue
Block a user