disable activation in the output node of the network;

we recommend using output_transform;
significantly change hyperparameters in the XOR example;
This commit is contained in:
wls2002
2024-05-22 11:09:25 +08:00
parent bb80f12640
commit 1fe5d5fca2
7 changed files with 59 additions and 14 deletions

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import State, Act, Agg
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
activation=Act.sigmoid,
aggregation=Agg.sum,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
):
assert substrate.query_coors.shape[1] == neat.num_inputs, \
"Substrate input size should be equal to NEAT input size"
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
node_gene=HyperNodeGene(activation, aggregation),
conn_gene=HyperNEATConnGene(),
activate_time=activate_time,
output_transform=output_transform
)
def setup(self, randkey):
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
self.activation = activation
self.aggregation = aggregation
def forward(self, attrs, inputs):
return self.activation(
self.aggregation(inputs)
)
def forward(self, attrs, inputs, is_output_node=False):
    """Compute this node's output.

    Incoming values are combined with ``self.aggregation``; for hidden
    nodes the result is then passed through ``self.activation``, while
    output nodes return the raw aggregation (activation is applied later
    via the network-level output_transform instead).

    ``attrs`` is accepted for interface compatibility but unused here.
    ``is_output_node`` may be a traced JAX boolean, hence ``lax.cond``
    (which traces both branches but executes only one) rather than a
    Python ``if``, which would fail on traced values under jit.
    """
    return jax.lax.cond(
        is_output_node,
        lambda: self.aggregation(inputs),  # output node does not need activation
        lambda: self.activation(self.aggregation(inputs))
    )
class HyperNEATConnGene(BaseConnGene):
custom_attrs = ['weight']