fix bugs
This commit is contained in:
@@ -4,7 +4,7 @@ import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .substrate import *
|
||||
from tensorneat.common import State, Act, Agg
|
||||
from tensorneat.common import State, ACT, AGG
|
||||
from tensorneat.algorithm import BaseAlgorithm, NEAT
|
||||
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
|
||||
|
||||
@@ -16,10 +16,10 @@ class HyperNEAT(BaseAlgorithm):
|
||||
neat: NEAT,
|
||||
weight_threshold: float = 0.3,
|
||||
max_weight: float = 5.0,
|
||||
aggregation: Callable = Agg.sum,
|
||||
activation: Callable = Act.sigmoid,
|
||||
aggregation: Callable = AGG.sum,
|
||||
activation: Callable = ACT.sigmoid,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = Act.standard_sigmoid,
|
||||
output_transform: Callable = ACT.standard_sigmoid,
|
||||
):
|
||||
assert (
|
||||
substrate.query_coors.shape[1] == neat.num_inputs
|
||||
@@ -102,8 +102,8 @@ class HyperNEAT(BaseAlgorithm):
|
||||
class HyperNEATNode(BaseNode):
|
||||
def __init__(
|
||||
self,
|
||||
aggregation=Agg.sum,
|
||||
activation=Act.sigmoid,
|
||||
aggregation=AGG.sum,
|
||||
activation=ACT.sigmoid,
|
||||
):
|
||||
super().__init__()
|
||||
self.aggregation = aggregation
|
||||
|
||||
@@ -1,56 +1,6 @@
|
||||
from tensorneat.common.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
from .state import State
|
||||
from .stateful_class import StatefulBaseClass
|
||||
|
||||
from .aggregation.agg_jnp import Agg, AGG_ALL, agg_func
|
||||
from .activation.act_jnp import Act, ACT_ALL, act_func
|
||||
from .aggregation.agg_sympy import *
|
||||
from .activation.act_sympy import *
|
||||
|
||||
from typing import Callable, Union
|
||||
|
||||
name2sympy = {
|
||||
"sigmoid": SympySigmoid,
|
||||
"standard_sigmoid": SympyStandardSigmoid,
|
||||
"tanh": SympyTanh,
|
||||
"standard_tanh": SympyStandardTanh,
|
||||
"sin": SympySin,
|
||||
"relu": SympyRelu,
|
||||
"lelu": SympyLelu,
|
||||
"identity": SympyIdentity,
|
||||
"inv": SympyInv,
|
||||
"log": SympyLog,
|
||||
"exp": SympyExp,
|
||||
"abs": SympyAbs,
|
||||
"sum": SympySum,
|
||||
"product": SympyProduct,
|
||||
"max": SympyMax,
|
||||
"min": SympyMin,
|
||||
"maxabs": SympyMaxabs,
|
||||
"mean": SympyMean,
|
||||
"clip": SympyClip,
|
||||
"square": SympySquare,
|
||||
}
|
||||
|
||||
|
||||
def convert_to_sympy(func: Union[str, Callable]):
|
||||
if isinstance(func, str):
|
||||
name = func
|
||||
else:
|
||||
name = func.__name__
|
||||
if name in name2sympy:
|
||||
return name2sympy[name]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Can not convert to sympy! Function {name} not found in name2sympy"
|
||||
)
|
||||
|
||||
|
||||
SYMPY_FUNCS_MODULE_NP = {}
|
||||
SYMPY_FUNCS_MODULE_JNP = {}
|
||||
for cls in name2sympy.values():
|
||||
if hasattr(cls, "numerical_eval"):
|
||||
SYMPY_FUNCS_MODULE_NP[cls.__name__] = cls.numerical_eval
|
||||
SYMPY_FUNCS_MODULE_JNP[cls.__name__] = partial(cls.numerical_eval, backend=jnp)
|
||||
from .functions import ACT, AGG, apply_activation, apply_aggregation
|
||||
|
||||
@@ -4,10 +4,10 @@ import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
import sympy as sp
|
||||
from tensorneat.common import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
ACT,
|
||||
AGG,
|
||||
apply_activation,
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
@@ -34,10 +34,10 @@ class BiasNode(BaseNode):
|
||||
bias_lower_bound: float = -5,
|
||||
bias_upper_bound: float = 5,
|
||||
aggregation_default: Optional[Callable] = None,
|
||||
aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
|
||||
aggregation_options: Union[Callable, Sequence[Callable]] = AGG.sum,
|
||||
aggregation_replace_rate: float = 0.1,
|
||||
activation_default: Optional[Callable] = None,
|
||||
activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
|
||||
activation_options: Union[Callable, Sequence[Callable]] = ACT.sigmoid,
|
||||
activation_replace_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -73,7 +73,7 @@ class BiasNode(BaseNode):
|
||||
def new_identity_attrs(self, state):
|
||||
return jnp.array(
|
||||
[0, self.aggregation_default, -1]
|
||||
) # activation=-1 means Act.identity
|
||||
) # activation=-1 means ACT.identity
|
||||
|
||||
def new_random_attrs(self, state, randkey):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
@@ -115,12 +115,12 @@ class BiasNode(BaseNode):
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
bias, agg, act = attrs
|
||||
|
||||
z = agg_func(agg, inputs, self.aggregation_options)
|
||||
z = apply_aggregation(agg, inputs, self.aggregation_options)
|
||||
z = bias + z
|
||||
|
||||
# the last output node should not be activated
|
||||
z = jax.lax.cond(
|
||||
is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
|
||||
is_output_node, lambda: z, lambda: apply_activation(act, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
@@ -134,7 +134,7 @@ class BiasNode(BaseNode):
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
act_func = ACT.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
|
||||
@@ -158,7 +158,7 @@ class BiasNode(BaseNode):
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
act_func = ACT.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ import jax, jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from tensorneat.common import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
ACT,
|
||||
AGG,
|
||||
apply_activation,
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
@@ -39,10 +39,10 @@ class DefaultNode(BaseNode):
|
||||
response_lower_bound: float = -5,
|
||||
response_upper_bound: float = 5,
|
||||
aggregation_default: Optional[Callable] = None,
|
||||
aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
|
||||
aggregation_options: Union[Callable, Sequence[Callable]] = AGG.sum,
|
||||
aggregation_replace_rate: float = 0.1,
|
||||
activation_default: Optional[Callable] = None,
|
||||
activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
|
||||
activation_options: Union[Callable, Sequence[Callable]] = ACT.sigmoid,
|
||||
activation_replace_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -89,7 +89,7 @@ class DefaultNode(BaseNode):
|
||||
agg = self.aggregation_default
|
||||
act = self.activation_default
|
||||
|
||||
return jnp.array([bias, res, agg, act]) # activation=-1 means Act.identity
|
||||
return jnp.array([bias, res, agg, act]) # activation=-1 means ACT.identity
|
||||
|
||||
def new_random_attrs(self, state, randkey):
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
@@ -148,12 +148,12 @@ class DefaultNode(BaseNode):
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
bias, res, agg, act = attrs
|
||||
|
||||
z = agg_func(agg, inputs, self.aggregation_options)
|
||||
z = apply_aggregation(agg, inputs, self.aggregation_options)
|
||||
z = bias + res * z
|
||||
|
||||
# the last output node should not be activated
|
||||
z = jax.lax.cond(
|
||||
is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
|
||||
is_output_node, lambda: z, lambda: apply_activation(act, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
@@ -168,7 +168,7 @@ class DefaultNode(BaseNode):
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
act_func = ACT.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, response={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
|
||||
@@ -193,7 +193,7 @@ class DefaultNode(BaseNode):
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
act_func = ACT.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user