Adjust default parameters; the recurrent-XOR example now runs successfully

This commit is contained in:
root
2024-07-11 10:57:43 +08:00
parent 4a631f9464
commit 9bad577d89
18 changed files with 118 additions and 136 deletions

View File

@@ -1,43 +1,30 @@
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, DefaultNodeGene, DefaultMutation
from tensorneat.genome import DefaultGenome
from tensorneat.problem.func_fit import XOR3d
from tensorneat.common import Act, Agg
from tensorneat.common import Act
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
pop_size=10000,
species_size=20,
compatibility_threshold=2,
survival_threshold=0.01,
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
init_hidden_layers=(),
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=Act.tanh,
aggregation_default=Agg.sum,
aggregation_options=Agg.sum,
),
output_transform=Act.standard_sigmoid, # the activation function for output node
mutation=DefaultMutation(
node_add=0.1,
conn_add=0.1,
node_delete=0,
conn_delete=0,
),
output_transform=Act.standard_sigmoid,
),
),
problem=XOR3d(),
generation_limit=500,
fitness_target=-1e-8,
fitness_target=-1e-6, # float32 precision
seed=42,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result

View File

@@ -1,46 +1,31 @@
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from problem.func_fit import XOR3d
from utils.activation import ACT_ALL, Act
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import RecurrentGenome
from tensorneat.problem.func_fit import XOR3d
from tensorneat.common import Act, Agg
if __name__ == "__main__":
pipeline = Pipeline(
seed=0,
algorithm=NEAT(
species=DefaultSpecies(
pop_size=10000,
species_size=20,
survival_threshold=0.01,
genome=RecurrentGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
activate_time=5,
node_gene=NodeGeneWithoutResponse(
activation_options=ACT_ALL, activation_replace_rate=0.2
),
output_transform=Act.sigmoid,
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
),
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
init_hidden_layers=(),
output_transform=Act.standard_sigmoid,
activate_time=10,
),
),
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8,
generation_limit=500,
fitness_target=-1e-6, # float32 precision
seed=42,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax
from jax import vmap, numpy as jnp
import numpy as np
@@ -18,10 +20,10 @@ class NEAT(BaseAlgorithm):
species_elitism: int = 2,
spawn_number_change_rate: float = 0.5,
genome_elitism: int = 2,
survival_threshold: float = 0.2,
survival_threshold: float = 0.1,
min_species_size: int = 1,
compatibility_threshold: float = 3.0,
species_fitness_func: callable = jnp.max,
compatibility_threshold: float = 2.0,
species_fitness_func: Callable = jnp.max,
):
self.genome = genome
self.pop_size = pop_size

View File

@@ -9,7 +9,7 @@ from .activation.act_jnp import Act, ACT_ALL, act_func
from .aggregation.agg_sympy import *
from .activation.act_sympy import *
from typing import Union
from typing import Callable, Union
name2sympy = {
"sigmoid": SympySigmoid,
@@ -34,7 +34,7 @@ name2sympy = {
}
def convert_to_sympy(func: Union[str, callable]):
def convert_to_sympy(func: Union[str, Callable]):
if isinstance(func, str):
name = func
else:

View File

@@ -52,7 +52,6 @@ class Act:
@staticmethod
def identity(z):
z = jnp.clip(z, -sigma_3, sigma_3)
return z
@staticmethod

View File

@@ -54,13 +54,6 @@ class SympyStandardSigmoid(sp.Function):
def eval(cls, z):
return SympySigmoid_(5 * z / sigma_3)
# @staticmethod
# def numerical_eval(z, backend=np):
# z = backend.clip(5 * z / sigma_3, -5, 5)
# z = 1 / (1 + backend.exp(-z))
#
# return z # (0, 1)
class SympyTanh(sp.Function):
@classmethod
@@ -68,11 +61,6 @@ class SympyTanh(sp.Function):
z = 5 * z / sigma_3
return sp.tanh(z) * sigma_3
# @staticmethod
# def numerical_eval(z, backend=np):
# z = backend.clip(5 * z / sigma_3, -5, 5)
# return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
class SympyStandardTanh(sp.Function):
@classmethod
@@ -80,11 +68,6 @@ class SympyStandardTanh(sp.Function):
z = 5 * z / sigma_3
return sp.tanh(z)
# @staticmethod
# def numerical_eval(z, backend=np):
# z = backend.clip(5 * z / sigma_3, -5, 5)
# return backend.tanh(z) # (-1, 1)
class SympySin(sp.Function):
@classmethod
@@ -143,14 +126,7 @@ class SympyLelu(sp.Function):
class SympyIdentity(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(z, -sigma_3, sigma_3)
return z
return None
@staticmethod
def numerical_eval(z, backend=np):
return backend.clip(z, -sigma_3, sigma_3)
class SympyInv(sp.Function):

View File

@@ -3,7 +3,7 @@ from typing import Callable, Sequence
import numpy as np
import jax
from jax import vmap, numpy as jnp
from .gene import BaseNodeGene, BaseConnGene
from .gene import BaseNode, BaseConn
from .operations import BaseMutation, BaseCrossover, BaseDistance
from tensorneat.common import (
State,
@@ -22,8 +22,8 @@ class BaseGenome(StatefulBaseClass):
num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene,
conn_gene: BaseConnGene,
node_gene: BaseNode,
conn_gene: BaseConn,
mutation: BaseMutation,
crossover: BaseCrossover,
distance: BaseDistance,
@@ -92,7 +92,6 @@ class BaseGenome(StatefulBaseClass):
self.output_idx = np.array(layer_indices[-1])
self.all_init_nodes = np.array(all_init_nodes)
self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx]
print(self.output_idx)
def setup(self, state=State()):
state = self.node_gene.setup(state)

View File

@@ -6,7 +6,7 @@ import numpy as np
import sympy as sp
from .base import BaseGenome
from .gene import DefaultNodeGene, DefaultConnGene
from .gene import DefaultNode, DefaultConn
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
@@ -31,8 +31,8 @@ class DefaultGenome(BaseGenome):
num_outputs: int,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(),
conn_gene=DefaultConnGene(),
node_gene=DefaultNode(),
conn_gene=DefaultConn(),
mutation=DefaultMutation(),
crossover=DefaultCrossover(),
distance=DefaultDistance(),

View File

@@ -1,2 +1,2 @@
from .base import BaseConnGene
from .default import DefaultConnGene
from .base import BaseConn
from .default import DefaultConn

View File

@@ -1,8 +1,7 @@
import jax
from .. import BaseGene
from ..base import BaseGene
class BaseConnGene(BaseGene):
class BaseConn(BaseGene):
"Base class for connection genes."
fixed_attrs = ["input_index", "output_index"]

View File

@@ -2,10 +2,10 @@ import jax.numpy as jnp
import jax.random
import sympy as sp
from tensorneat.common import mutate_float
from .base import BaseConnGene
from .base import BaseConn
class DefaultConnGene(BaseConnGene):
class DefaultConn(BaseConn):
"Default connection gene, with the same behavior as in NEAT-python."
custom_attrs = ["weight"]
@@ -14,9 +14,9 @@ class DefaultConnGene(BaseConnGene):
self,
weight_init_mean: float = 0.0,
weight_init_std: float = 1.0,
weight_mutate_power: float = 0.5,
weight_mutate_rate: float = 0.8,
weight_replace_rate: float = 0.1,
weight_mutate_power: float = 0.15,
weight_mutate_rate: float = 0.2,
weight_replace_rate: float = 0.015,
):
super().__init__()
self.weight_init_mean = weight_init_mean

View File

@@ -1,3 +1,3 @@
from .base import BaseNodeGene
from .default import DefaultNodeGene
from .base import BaseNode
from .default import DefaultNode
from .bias import BiasNode

View File

@@ -2,7 +2,7 @@ import jax, jax.numpy as jnp
from .. import BaseGene
class BaseNodeGene(BaseGene):
class BaseNode(BaseGene):
"Base class for node genes."
fixed_attrs = ["index"]

View File

@@ -1,5 +1,6 @@
from typing import Tuple
from typing import Union, Sequence, Callable, Optional
import numpy as np
import jax, jax.numpy as jnp
import sympy as sp
from tensorneat.common import (
@@ -12,10 +13,10 @@ from tensorneat.common import (
convert_to_sympy,
)
from . import BaseNodeGene
from . import BaseNode
class BiasNode(BaseNodeGene):
class BiasNode(BaseNode):
"""
Default node gene, with the same behavior as in NEAT-python.
The attribute response is removed.
@@ -27,31 +28,46 @@ class BiasNode(BaseNodeGene):
self,
bias_init_mean: float = 0.0,
bias_init_std: float = 1.0,
bias_mutate_power: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
aggregation_default: callable = Agg.sum,
aggregation_options: Tuple = (Agg.sum,),
bias_mutate_power: float = 0.15,
bias_mutate_rate: float = 0.2,
bias_replace_rate: float = 0.015,
bias_lower_bound: float = -5,
bias_upper_bound: float = 5,
aggregation_default: Optional[Callable] = None,
aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
aggregation_replace_rate: float = 0.1,
activation_default: callable = Act.sigmoid,
activation_options: Tuple = (Act.sigmoid,),
activation_default: Optional[Callable] = None,
activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
activation_replace_rate: float = 0.1,
):
super().__init__()
if isinstance(aggregation_options, Callable):
aggregation_options = [aggregation_options]
if isinstance(activation_options, Callable):
activation_options = [activation_options]
if len(aggregation_options) == 1 and aggregation_default is None:
aggregation_default = aggregation_options[0]
if len(activation_options) == 1 and activation_default is None:
activation_default = activation_options[0]
self.bias_init_mean = bias_init_mean
self.bias_init_std = bias_init_std
self.bias_mutate_power = bias_mutate_power
self.bias_mutate_rate = bias_mutate_rate
self.bias_replace_rate = bias_replace_rate
self.bias_lower_bound = bias_lower_bound
self.bias_upper_bound = bias_upper_bound
self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_indices = np.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
self.activation_default = activation_options.index(activation_default)
self.activation_options = activation_options
self.activation_indices = jnp.arange(len(activation_options))
self.activation_indices = np.arange(len(activation_options))
self.activation_replace_rate = activation_replace_rate
def new_identity_attrs(self, state):
@@ -62,6 +78,7 @@ class BiasNode(BaseNodeGene):
def new_random_attrs(self, state, randkey):
k1, k2, k3 = jax.random.split(randkey, num=3)
bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
agg = jax.random.choice(k2, self.aggregation_indices)
act = jax.random.choice(k3, self.activation_indices)
@@ -80,7 +97,7 @@ class BiasNode(BaseNodeGene):
self.bias_mutate_rate,
self.bias_replace_rate,
)
bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
agg = mutate_int(
k2, agg, self.aggregation_indices, self.aggregation_replace_rate
)

View File

@@ -1,4 +1,4 @@
from typing import Tuple, Union, Sequence, Callable
from typing import Optional, Union, Sequence, Callable
import numpy as np
import jax, jax.numpy as jnp
@@ -14,10 +14,10 @@ from tensorneat.common import (
convert_to_sympy,
)
from . import BaseNodeGene
from .base import BaseNode
class DefaultNodeGene(BaseNodeGene):
class DefaultNode(BaseNode):
"Default node gene, with the same behavior as in NEAT-python."
custom_attrs = ["bias", "response", "aggregation", "activation"]
@@ -26,18 +26,22 @@ class DefaultNodeGene(BaseNodeGene):
self,
bias_init_mean: float = 0.0,
bias_init_std: float = 1.0,
bias_mutate_power: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
bias_mutate_power: float = 0.15,
bias_mutate_rate: float = 0.2,
bias_replace_rate: float = 0.015,
bias_lower_bound: float = -5,
bias_upper_bound: float = 5,
response_init_mean: float = 1.0,
response_init_std: float = 0.0,
response_mutate_power: float = 0.5,
response_mutate_rate: float = 0.7,
response_replace_rate: float = 0.1,
aggregation_default: Callable = Agg.sum,
response_mutate_power: float = 0.15,
response_mutate_rate: float = 0.2,
response_replace_rate: float = 0.015,
response_lower_bound: float = -5,
response_upper_bound: float = 5,
aggregation_default: Optional[Callable] = None,
aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
aggregation_replace_rate: float = 0.1,
activation_default: Callable = Act.sigmoid,
activation_default: Optional[Callable] = None,
activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
activation_replace_rate: float = 0.1,
):
@@ -48,17 +52,26 @@ class DefaultNodeGene(BaseNodeGene):
if isinstance(activation_options, Callable):
activation_options = [activation_options]
if len(aggregation_options) == 1 and aggregation_default is None:
aggregation_default = aggregation_options[0]
if len(activation_options) == 1 and activation_default is None:
activation_default = activation_options[0]
self.bias_init_mean = bias_init_mean
self.bias_init_std = bias_init_std
self.bias_mutate_power = bias_mutate_power
self.bias_mutate_rate = bias_mutate_rate
self.bias_replace_rate = bias_replace_rate
self.bias_lower_bound = bias_lower_bound
self.bias_upper_bound = bias_upper_bound
self.response_init_mean = response_init_mean
self.response_init_std = response_init_std
self.response_mutate_power = response_mutate_power
self.response_mutate_rate = response_mutate_rate
self.response_replace_rate = response_replace_rate
self.reponse_lower_bound = response_lower_bound
self.response_upper_bound = response_upper_bound
self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options
@@ -71,16 +84,21 @@ class DefaultNodeGene(BaseNodeGene):
self.activation_replace_rate = activation_replace_rate
def new_identity_attrs(self, state):
return jnp.array(
[0, 1, self.aggregation_default, -1]
) # activation=-1 means Act.identity
bias = 0
res = 1
agg = self.aggregation_default
act = self.activation_default
return jnp.array([bias, res, agg, act]) # activation is self.activation_default here, not the hard-coded -1 (Act.identity) used previously
def new_random_attrs(self, state, randkey):
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
res = (
jax.random.normal(k2, ()) * self.response_init_std + self.response_init_mean
)
res = jnp.clip(res, self.reponse_lower_bound, self.response_upper_bound)
agg = jax.random.choice(k3, self.aggregation_indices)
act = jax.random.choice(k4, self.activation_indices)
@@ -98,7 +116,7 @@ class DefaultNodeGene(BaseNodeGene):
self.bias_mutate_rate,
self.bias_replace_rate,
)
bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
res = mutate_float(
k2,
res,
@@ -108,7 +126,7 @@ class DefaultNodeGene(BaseNodeGene):
self.response_mutate_rate,
self.response_replace_rate,
)
res = jnp.clip(res, self.reponse_lower_bound, self.response_upper_bound)
agg = mutate_int(
k4, agg, self.aggregation_indices, self.aggregation_replace_rate
)

View File

@@ -23,9 +23,9 @@ from ...utils import (
class DefaultMutation(BaseMutation):
def __init__(
self,
conn_add: float = 0.2,
conn_add: float = 0.1,
conn_delete: float = 0,
node_add: float = 0.2,
node_add: float = 0.1,
node_delete: float = 0,
):
self.conn_add = conn_add

View File

@@ -3,7 +3,7 @@ from jax import vmap, numpy as jnp
from .utils import unflatten_conns
from .base import BaseGenome
from .gene import DefaultNodeGene, DefaultConnGene
from .gene import DefaultNode, DefaultConn
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
@@ -20,8 +20,8 @@ class RecurrentGenome(BaseGenome):
num_outputs: int,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(),
conn_gene=DefaultConnGene(),
node_gene=DefaultNode(),
conn_gene=DefaultConn(),
mutation=DefaultMutation(),
crossover=DefaultCrossover(),
distance=DefaultDistance(),