This commit is contained in:
root
2024-07-12 02:25:57 +08:00
parent 3194678a15
commit 5fc63fdaf1
28 changed files with 351 additions and 142 deletions

View File

@@ -4,7 +4,7 @@ import jax
from jax import vmap, numpy as jnp
from .substrate import *
from tensorneat.common import State, Act, Agg
from tensorneat.common import State, ACT, AGG
from tensorneat.algorithm import BaseAlgorithm, NEAT
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
@@ -16,10 +16,10 @@ class HyperNEAT(BaseAlgorithm):
neat: NEAT,
weight_threshold: float = 0.3,
max_weight: float = 5.0,
aggregation: Callable = Agg.sum,
activation: Callable = Act.sigmoid,
aggregation: Callable = AGG.sum,
activation: Callable = ACT.sigmoid,
activate_time: int = 10,
output_transform: Callable = Act.standard_sigmoid,
output_transform: Callable = ACT.standard_sigmoid,
):
assert (
substrate.query_coors.shape[1] == neat.num_inputs
@@ -102,8 +102,8 @@ class HyperNEAT(BaseAlgorithm):
class HyperNEATNode(BaseNode):
def __init__(
self,
aggregation=Agg.sum,
activation=Act.sigmoid,
aggregation=AGG.sum,
activation=ACT.sigmoid,
):
super().__init__()
self.aggregation = aggregation