fix bugs
This commit is contained in:
@@ -4,7 +4,7 @@ import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .substrate import *
|
||||
from tensorneat.common import State, Act, Agg
|
||||
from tensorneat.common import State, ACT, AGG
|
||||
from tensorneat.algorithm import BaseAlgorithm, NEAT
|
||||
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
|
||||
|
||||
@@ -16,10 +16,10 @@ class HyperNEAT(BaseAlgorithm):
|
||||
neat: NEAT,
|
||||
weight_threshold: float = 0.3,
|
||||
max_weight: float = 5.0,
|
||||
aggregation: Callable = Agg.sum,
|
||||
activation: Callable = Act.sigmoid,
|
||||
aggregation: Callable = AGG.sum,
|
||||
activation: Callable = ACT.sigmoid,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = Act.standard_sigmoid,
|
||||
output_transform: Callable = ACT.standard_sigmoid,
|
||||
):
|
||||
assert (
|
||||
substrate.query_coors.shape[1] == neat.num_inputs
|
||||
@@ -102,8 +102,8 @@ class HyperNEAT(BaseAlgorithm):
|
||||
class HyperNEATNode(BaseNode):
|
||||
def __init__(
|
||||
self,
|
||||
aggregation=Agg.sum,
|
||||
activation=Act.sigmoid,
|
||||
aggregation=AGG.sum,
|
||||
activation=ACT.sigmoid,
|
||||
):
|
||||
super().__init__()
|
||||
self.aggregation = aggregation
|
||||
|
||||
Reference in New Issue
Block a user