commit 5fc63fdaf1 (parent 3194678a15)
Author: root
Date: 2024-07-12 02:25:57 +08:00

28 changed files with 351 additions and 142 deletions

View File

@@ -4,7 +4,7 @@ import jax
 from jax import vmap, numpy as jnp

 from .substrate import *
-from tensorneat.common import State, Act, Agg
+from tensorneat.common import State, ACT, AGG
 from tensorneat.algorithm import BaseAlgorithm, NEAT
 from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
@@ -16,10 +16,10 @@ class HyperNEAT(BaseAlgorithm):
         neat: NEAT,
         weight_threshold: float = 0.3,
         max_weight: float = 5.0,
-        aggregation: Callable = Agg.sum,
-        activation: Callable = Act.sigmoid,
+        aggregation: Callable = AGG.sum,
+        activation: Callable = ACT.sigmoid,
         activate_time: int = 10,
-        output_transform: Callable = Act.standard_sigmoid,
+        output_transform: Callable = ACT.standard_sigmoid,
     ):
         assert (
             substrate.query_coors.shape[1] == neat.num_inputs
@@ -102,8 +102,8 @@ class HyperNEAT(BaseAlgorithm):
 class HyperNEATNode(BaseNode):
     def __init__(
         self,
-        aggregation=Agg.sum,
-        activation=Act.sigmoid,
+        aggregation=AGG.sum,
+        activation=ACT.sigmoid,
     ):
         super().__init__()
         self.aggregation = aggregation
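
The renamed namespaces are drop-in replacements for the old ones. A minimal construction sketch after this commit; the substrate and neat values are placeholder objects, not taken from this diff:

# Hedged sketch: HyperNEAT with the renamed ACT/AGG namespaces.
from tensorneat.common import ACT, AGG

algo = HyperNEAT(
    substrate=substrate,                    # any BaseSubstrate exposing query_coors
    neat=neat,                              # a configured NEAT instance
    aggregation=AGG.sum,                    # was Agg.sum before this commit
    activation=ACT.sigmoid,                 # was Act.sigmoid
    output_transform=ACT.standard_sigmoid,  # was Act.standard_sigmoid
)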

View File

@@ -1,56 +1,6 @@
-from tensorneat.common.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
 from .tools import *
 from .graph import *
 from .state import State
 from .stateful_class import StatefulBaseClass
-from .aggregation.agg_jnp import Agg, AGG_ALL, agg_func
-from .activation.act_jnp import Act, ACT_ALL, act_func
-from .aggregation.agg_sympy import *
-from .activation.act_sympy import *
-from typing import Callable, Union
-
-name2sympy = {
-    "sigmoid": SympySigmoid,
-    "standard_sigmoid": SympyStandardSigmoid,
-    "tanh": SympyTanh,
-    "standard_tanh": SympyStandardTanh,
-    "sin": SympySin,
-    "relu": SympyRelu,
-    "lelu": SympyLelu,
-    "identity": SympyIdentity,
-    "inv": SympyInv,
-    "log": SympyLog,
-    "exp": SympyExp,
-    "abs": SympyAbs,
-    "sum": SympySum,
-    "product": SympyProduct,
-    "max": SympyMax,
-    "min": SympyMin,
-    "maxabs": SympyMaxabs,
-    "mean": SympyMean,
-    "clip": SympyClip,
-    "square": SympySquare,
-}
-
-
-def convert_to_sympy(func: Union[str, Callable]):
-    if isinstance(func, str):
-        name = func
-    else:
-        name = func.__name__
-    if name in name2sympy:
-        return name2sympy[name]
-    else:
-        raise ValueError(
-            f"Can not convert to sympy! Function {name} not found in name2sympy"
-        )
-
-
-SYMPY_FUNCS_MODULE_NP = {}
-SYMPY_FUNCS_MODULE_JNP = {}
-for cls in name2sympy.values():
-    if hasattr(cls, "numerical_eval"):
-        SYMPY_FUNCS_MODULE_NP[cls.__name__] = cls.numerical_eval
-        SYMPY_FUNCS_MODULE_JNP[cls.__name__] = partial(cls.numerical_eval, backend=jnp)
+from .functions import ACT, AGG, apply_activation, apply_aggregation
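
After this cleanup the sympy machinery moves out of the package __init__ and the public names come from .functions. A hedged usage sketch; the index-into-options convention mirrors the node forward() calls below, but the exact semantics of the apply_* helpers are assumed, not shown in this commit:

import jax.numpy as jnp
from tensorneat.common import ACT, AGG, apply_activation, apply_aggregation

agg_options = (AGG.sum,)
act_options = (ACT.sigmoid,)
inputs = jnp.array([0.5, -1.0, 2.0])

z = apply_aggregation(jnp.asarray(0), inputs, agg_options)  # index 0 -> AGG.sum
out = apply_activation(jnp.asarray(0), z, act_options)      # index 0 -> ACT.sigmoid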

View File

@@ -4,10 +4,10 @@ import numpy as np
 import jax, jax.numpy as jnp
 import sympy as sp
 from tensorneat.common import (
-    Act,
-    Agg,
-    act_func,
-    agg_func,
+    ACT,
+    AGG,
+    apply_activation,
+    apply_aggregation,
     mutate_int,
     mutate_float,
     convert_to_sympy,
@@ -34,10 +34,10 @@ class BiasNode(BaseNode):
         bias_lower_bound: float = -5,
         bias_upper_bound: float = 5,
         aggregation_default: Optional[Callable] = None,
-        aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
+        aggregation_options: Union[Callable, Sequence[Callable]] = AGG.sum,
         aggregation_replace_rate: float = 0.1,
         activation_default: Optional[Callable] = None,
-        activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
+        activation_options: Union[Callable, Sequence[Callable]] = ACT.sigmoid,
         activation_replace_rate: float = 0.1,
     ):
         super().__init__()
@@ -73,7 +73,7 @@ class BiasNode(BaseNode):
     def new_identity_attrs(self, state):
         return jnp.array(
             [0, self.aggregation_default, -1]
-        ) # activation=-1 means Act.identity
+        ) # activation=-1 means ACT.identity

     def new_random_attrs(self, state, randkey):
         k1, k2, k3 = jax.random.split(randkey, num=3)
@@ -115,12 +115,12 @@ class BiasNode(BaseNode):
     def forward(self, state, attrs, inputs, is_output_node=False):
         bias, agg, act = attrs

-        z = agg_func(agg, inputs, self.aggregation_options)
+        z = apply_aggregation(agg, inputs, self.aggregation_options)
         z = bias + z

         # the last output node should not be activated
         z = jax.lax.cond(
-            is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
+            is_output_node, lambda: z, lambda: apply_activation(act, z, self.activation_options)
         )

         return z
@@ -134,7 +134,7 @@ class BiasNode(BaseNode):
         act = int(act)
         if act == -1:
-            act_func = Act.identity
+            act_func = ACT.identity
         else:
             act_func = self.activation_options[act]

         return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
@@ -158,7 +158,7 @@ class BiasNode(BaseNode):
         act = int(act)
         if act == -1:
-            act_func = Act.identity
+            act_func = ACT.identity
         else:
             act_func = self.activation_options[act]
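
BiasNode stores the activation as a float index into activation_options, with -1 meaning identity. The commit renames the dispatcher to apply_activation but does not show its body; a plausible jax.lax.switch-based sketch of that dispatch (an assumption about the helper, not the library's actual code):

import jax, jax.numpy as jnp

def apply_activation_sketch(act, z, activation_options):
    idx = jnp.asarray(act, dtype=jnp.int32)     # attrs store the index as a float
    branches = [lambda x, f=f: f(x) for f in activation_options]
    applied = jax.lax.switch(idx, branches, z)  # switch clamps idx into range
    return jnp.where(idx == -1, z, applied)     # honor the -1 == identity convention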

View File

@@ -5,10 +5,10 @@ import jax, jax.numpy as jnp
 import sympy as sp
 from tensorneat.common import (
-    Act,
-    Agg,
-    act_func,
-    agg_func,
+    ACT,
+    AGG,
+    apply_activation,
+    apply_aggregation,
     mutate_int,
     mutate_float,
     convert_to_sympy,
@@ -39,10 +39,10 @@ class DefaultNode(BaseNode):
         response_lower_bound: float = -5,
         response_upper_bound: float = 5,
         aggregation_default: Optional[Callable] = None,
-        aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
+        aggregation_options: Union[Callable, Sequence[Callable]] = AGG.sum,
         aggregation_replace_rate: float = 0.1,
         activation_default: Optional[Callable] = None,
-        activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
+        activation_options: Union[Callable, Sequence[Callable]] = ACT.sigmoid,
         activation_replace_rate: float = 0.1,
     ):
         super().__init__()
@@ -89,7 +89,7 @@ class DefaultNode(BaseNode):
         agg = self.aggregation_default
         act = self.activation_default
-        return jnp.array([bias, res, agg, act]) # activation=-1 means Act.identity
+        return jnp.array([bias, res, agg, act]) # activation=-1 means ACT.identity

     def new_random_attrs(self, state, randkey):
         k1, k2, k3, k4 = jax.random.split(randkey, num=4)
@@ -148,12 +148,12 @@ class DefaultNode(BaseNode):
     def forward(self, state, attrs, inputs, is_output_node=False):
         bias, res, agg, act = attrs

-        z = agg_func(agg, inputs, self.aggregation_options)
+        z = apply_aggregation(agg, inputs, self.aggregation_options)
         z = bias + res * z

         # the last output node should not be activated
         z = jax.lax.cond(
-            is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
+            is_output_node, lambda: z, lambda: apply_activation(act, z, self.activation_options)
         )

         return z
@@ -168,7 +168,7 @@ class DefaultNode(BaseNode):
         act = int(act)
         if act == -1:
-            act_func = Act.identity
+            act_func = ACT.identity
         else:
             act_func = self.activation_options[act]

         return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, response={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
@@ -193,7 +193,7 @@ class DefaultNode(BaseNode):
         act = int(act)
         if act == -1:
-            act_func = Act.identity
+            act_func = ACT.identity
         else:
             act_func = self.activation_options[act]

         return {
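
DefaultNode's forward pass is BiasNode's plus a per-node response gain (z = bias + res * z instead of z = bias + z). A standalone restatement of the math above, with concrete callables in place of the index-dispatched apply_* helpers and a plain Python branch instead of jax.lax.cond (is_output_node is static here):

import jax, jax.numpy as jnp

def default_node_forward_sketch(bias, res, inputs,
                                agg=jnp.sum, act=jax.nn.sigmoid,
                                is_output_node=False):
    z = agg(inputs)      # real code: apply_aggregation(agg_idx, inputs, options)
    z = bias + res * z   # response scales the aggregated input
    return z if is_output_node else act(z)  # output nodes are not activated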