adjust default parameter; successful run recurrent-xor example

This commit is contained in:
root
2024-07-11 10:57:43 +08:00
parent 4a631f9464
commit 9bad577d89
18 changed files with 118 additions and 136 deletions

View File

@@ -1,43 +1,30 @@
from tensorneat.pipeline import Pipeline from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, DefaultNodeGene, DefaultMutation from tensorneat.genome import DefaultGenome
from tensorneat.problem.func_fit import XOR3d from tensorneat.problem.func_fit import XOR3d
from tensorneat.common import Act, Agg from tensorneat.common import Act
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
pop_size=10000, pop_size=10000,
species_size=20, species_size=20,
compatibility_threshold=2,
survival_threshold=0.01, survival_threshold=0.01,
genome=DefaultGenome( genome=DefaultGenome(
num_inputs=3, num_inputs=3,
num_outputs=1, num_outputs=1,
init_hidden_layers=(), init_hidden_layers=(),
node_gene=DefaultNodeGene( output_transform=Act.standard_sigmoid,
activation_default=Act.tanh,
activation_options=Act.tanh,
aggregation_default=Agg.sum,
aggregation_options=Agg.sum,
),
output_transform=Act.standard_sigmoid, # the activation function for output node
mutation=DefaultMutation(
node_add=0.1,
conn_add=0.1,
node_delete=0,
conn_delete=0,
),
), ),
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=500, generation_limit=500,
fitness_target=-1e-8, fitness_target=-1e-6, # float32 precision
seed=42,
) )
# initialize state # initialize state
state = pipeline.setup() state = pipeline.setup()
# print(state)
# run until terminate # run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)
# show result # show result

View File

@@ -1,46 +1,31 @@
from pipeline import Pipeline from tensorneat.pipeline import Pipeline
from algorithm.neat import * from tensorneat.algorithm.neat import NEAT
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse from tensorneat.genome import RecurrentGenome
from tensorneat.problem.func_fit import XOR3d
from problem.func_fit import XOR3d from tensorneat.common import Act, Agg
from utils.activation import ACT_ALL, Act
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
seed=0,
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( pop_size=10000,
genome=RecurrentGenome( species_size=20,
num_inputs=3, survival_threshold=0.01,
num_outputs=1, genome=RecurrentGenome(
max_nodes=50, num_inputs=3,
max_conns=100, num_outputs=1,
activate_time=5, init_hidden_layers=(),
node_gene=NodeGeneWithoutResponse( output_transform=Act.standard_sigmoid,
activation_options=ACT_ALL, activation_replace_rate=0.2 activate_time=10,
),
output_transform=Act.sigmoid,
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
),
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
), ),
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=10000, generation_limit=500,
fitness_target=-1e-8, fitness_target=-1e-6, # float32 precision
seed=42,
) )
# initialize state # initialize state
state = pipeline.setup() state = pipeline.setup()
# print(state)
# run until terminate # run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)
# show result # show result

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax import jax
from jax import vmap, numpy as jnp from jax import vmap, numpy as jnp
import numpy as np import numpy as np
@@ -18,10 +20,10 @@ class NEAT(BaseAlgorithm):
species_elitism: int = 2, species_elitism: int = 2,
spawn_number_change_rate: float = 0.5, spawn_number_change_rate: float = 0.5,
genome_elitism: int = 2, genome_elitism: int = 2,
survival_threshold: float = 0.2, survival_threshold: float = 0.1,
min_species_size: int = 1, min_species_size: int = 1,
compatibility_threshold: float = 3.0, compatibility_threshold: float = 2.0,
species_fitness_func: callable = jnp.max, species_fitness_func: Callable = jnp.max,
): ):
self.genome = genome self.genome = genome
self.pop_size = pop_size self.pop_size = pop_size

View File

@@ -9,7 +9,7 @@ from .activation.act_jnp import Act, ACT_ALL, act_func
from .aggregation.agg_sympy import * from .aggregation.agg_sympy import *
from .activation.act_sympy import * from .activation.act_sympy import *
from typing import Union from typing import Callable, Union
name2sympy = { name2sympy = {
"sigmoid": SympySigmoid, "sigmoid": SympySigmoid,
@@ -34,7 +34,7 @@ name2sympy = {
} }
def convert_to_sympy(func: Union[str, callable]): def convert_to_sympy(func: Union[str, Callable]):
if isinstance(func, str): if isinstance(func, str):
name = func name = func
else: else:

View File

@@ -31,7 +31,7 @@ class Act:
@staticmethod @staticmethod
def standard_tanh(z): def standard_tanh(z):
z =5 * z / sigma_3 z = 5 * z / sigma_3
return jnp.tanh(z) # (-1, 1) return jnp.tanh(z) # (-1, 1)
@staticmethod @staticmethod
@@ -52,7 +52,6 @@ class Act:
@staticmethod @staticmethod
def identity(z): def identity(z):
z = jnp.clip(z, -sigma_3, sigma_3)
return z return z
@staticmethod @staticmethod

View File

@@ -54,13 +54,6 @@ class SympyStandardSigmoid(sp.Function):
def eval(cls, z): def eval(cls, z):
return SympySigmoid_(5 * z / sigma_3) return SympySigmoid_(5 * z / sigma_3)
# @staticmethod
# def numerical_eval(z, backend=np):
# z = backend.clip(5 * z / sigma_3, -5, 5)
# z = 1 / (1 + backend.exp(-z))
#
# return z # (0, 1)
class SympyTanh(sp.Function): class SympyTanh(sp.Function):
@classmethod @classmethod
@@ -68,11 +61,6 @@ class SympyTanh(sp.Function):
z = 5 * z / sigma_3 z = 5 * z / sigma_3
return sp.tanh(z) * sigma_3 return sp.tanh(z) * sigma_3
# @staticmethod
# def numerical_eval(z, backend=np):
# z = backend.clip(5 * z / sigma_3, -5, 5)
# return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
class SympyStandardTanh(sp.Function): class SympyStandardTanh(sp.Function):
@classmethod @classmethod
@@ -80,11 +68,6 @@ class SympyStandardTanh(sp.Function):
z = 5 * z / sigma_3 z = 5 * z / sigma_3
return sp.tanh(z) return sp.tanh(z)
# @staticmethod
# def numerical_eval(z, backend=np):
# z = backend.clip(5 * z / sigma_3, -5, 5)
# return backend.tanh(z) # (-1, 1)
class SympySin(sp.Function): class SympySin(sp.Function):
@classmethod @classmethod
@@ -143,14 +126,7 @@ class SympyLelu(sp.Function):
class SympyIdentity(sp.Function): class SympyIdentity(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
if z.is_Number: return z
z = SympyClip(z, -sigma_3, sigma_3)
return z
return None
@staticmethod
def numerical_eval(z, backend=np):
return backend.clip(z, -sigma_3, sigma_3)
class SympyInv(sp.Function): class SympyInv(sp.Function):

View File

@@ -3,7 +3,7 @@ from typing import Callable, Sequence
import numpy as np import numpy as np
import jax import jax
from jax import vmap, numpy as jnp from jax import vmap, numpy as jnp
from .gene import BaseNodeGene, BaseConnGene from .gene import BaseNode, BaseConn
from .operations import BaseMutation, BaseCrossover, BaseDistance from .operations import BaseMutation, BaseCrossover, BaseDistance
from tensorneat.common import ( from tensorneat.common import (
State, State,
@@ -22,8 +22,8 @@ class BaseGenome(StatefulBaseClass):
num_outputs: int, num_outputs: int,
max_nodes: int, max_nodes: int,
max_conns: int, max_conns: int,
node_gene: BaseNodeGene, node_gene: BaseNode,
conn_gene: BaseConnGene, conn_gene: BaseConn,
mutation: BaseMutation, mutation: BaseMutation,
crossover: BaseCrossover, crossover: BaseCrossover,
distance: BaseDistance, distance: BaseDistance,
@@ -92,7 +92,6 @@ class BaseGenome(StatefulBaseClass):
self.output_idx = np.array(layer_indices[-1]) self.output_idx = np.array(layer_indices[-1])
self.all_init_nodes = np.array(all_init_nodes) self.all_init_nodes = np.array(all_init_nodes)
self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx] self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx]
print(self.output_idx)
def setup(self, state=State()): def setup(self, state=State()):
state = self.node_gene.setup(state) state = self.node_gene.setup(state)

View File

@@ -6,7 +6,7 @@ import numpy as np
import sympy as sp import sympy as sp
from .base import BaseGenome from .base import BaseGenome
from .gene import DefaultNodeGene, DefaultConnGene from .gene import DefaultNode, DefaultConn
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
@@ -31,8 +31,8 @@ class DefaultGenome(BaseGenome):
num_outputs: int, num_outputs: int,
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
node_gene=DefaultNodeGene(), node_gene=DefaultNode(),
conn_gene=DefaultConnGene(), conn_gene=DefaultConn(),
mutation=DefaultMutation(), mutation=DefaultMutation(),
crossover=DefaultCrossover(), crossover=DefaultCrossover(),
distance=DefaultDistance(), distance=DefaultDistance(),

View File

@@ -1,2 +1,2 @@
from .base import BaseConnGene from .base import BaseConn
from .default import DefaultConnGene from .default import DefaultConn

View File

@@ -1,8 +1,7 @@
import jax from ..base import BaseGene
from .. import BaseGene
class BaseConnGene(BaseGene): class BaseConn(BaseGene):
"Base class for connection genes." "Base class for connection genes."
fixed_attrs = ["input_index", "output_index"] fixed_attrs = ["input_index", "output_index"]

View File

@@ -2,10 +2,10 @@ import jax.numpy as jnp
import jax.random import jax.random
import sympy as sp import sympy as sp
from tensorneat.common import mutate_float from tensorneat.common import mutate_float
from .base import BaseConnGene from .base import BaseConn
class DefaultConnGene(BaseConnGene): class DefaultConn(BaseConn):
"Default connection gene, with the same behavior as in NEAT-python." "Default connection gene, with the same behavior as in NEAT-python."
custom_attrs = ["weight"] custom_attrs = ["weight"]
@@ -14,9 +14,9 @@ class DefaultConnGene(BaseConnGene):
self, self,
weight_init_mean: float = 0.0, weight_init_mean: float = 0.0,
weight_init_std: float = 1.0, weight_init_std: float = 1.0,
weight_mutate_power: float = 0.5, weight_mutate_power: float = 0.15,
weight_mutate_rate: float = 0.8, weight_mutate_rate: float = 0.2,
weight_replace_rate: float = 0.1, weight_replace_rate: float = 0.015,
): ):
super().__init__() super().__init__()
self.weight_init_mean = weight_init_mean self.weight_init_mean = weight_init_mean

View File

@@ -1,3 +1,3 @@
from .base import BaseNodeGene from .base import BaseNode
from .default import DefaultNodeGene from .default import DefaultNode
from .bias import BiasNode from .bias import BiasNode

View File

@@ -2,7 +2,7 @@ import jax, jax.numpy as jnp
from .. import BaseGene from .. import BaseGene
class BaseNodeGene(BaseGene): class BaseNode(BaseGene):
"Base class for node genes." "Base class for node genes."
fixed_attrs = ["index"] fixed_attrs = ["index"]

View File

@@ -1,5 +1,6 @@
from typing import Tuple from typing import Union, Sequence, Callable, Optional
import numpy as np
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
import sympy as sp import sympy as sp
from tensorneat.common import ( from tensorneat.common import (
@@ -12,10 +13,10 @@ from tensorneat.common import (
convert_to_sympy, convert_to_sympy,
) )
from . import BaseNodeGene from . import BaseNode
class BiasNode(BaseNodeGene): class BiasNode(BaseNode):
""" """
Default node gene, with the same behavior as in NEAT-python. Default node gene, with the same behavior as in NEAT-python.
The attribute response is removed. The attribute response is removed.
@@ -27,31 +28,46 @@ class BiasNode(BaseNodeGene):
self, self,
bias_init_mean: float = 0.0, bias_init_mean: float = 0.0,
bias_init_std: float = 1.0, bias_init_std: float = 1.0,
bias_mutate_power: float = 0.5, bias_mutate_power: float = 0.15,
bias_mutate_rate: float = 0.7, bias_mutate_rate: float = 0.2,
bias_replace_rate: float = 0.1, bias_replace_rate: float = 0.015,
aggregation_default: callable = Agg.sum, bias_lower_bound: float = -5,
aggregation_options: Tuple = (Agg.sum,), bias_upper_bound: float = 5,
aggregation_default: Optional[Callable] = None,
aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
aggregation_replace_rate: float = 0.1, aggregation_replace_rate: float = 0.1,
activation_default: callable = Act.sigmoid, activation_default: Optional[Callable] = None,
activation_options: Tuple = (Act.sigmoid,), activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
activation_replace_rate: float = 0.1, activation_replace_rate: float = 0.1,
): ):
super().__init__() super().__init__()
if isinstance(aggregation_options, Callable):
aggregation_options = [aggregation_options]
if isinstance(activation_options, Callable):
activation_options = [activation_options]
if len(aggregation_options) == 1 and aggregation_default is None:
aggregation_default = aggregation_options[0]
if len(activation_options) == 1 and activation_default is None:
activation_default = activation_options[0]
self.bias_init_mean = bias_init_mean self.bias_init_mean = bias_init_mean
self.bias_init_std = bias_init_std self.bias_init_std = bias_init_std
self.bias_mutate_power = bias_mutate_power self.bias_mutate_power = bias_mutate_power
self.bias_mutate_rate = bias_mutate_rate self.bias_mutate_rate = bias_mutate_rate
self.bias_replace_rate = bias_replace_rate self.bias_replace_rate = bias_replace_rate
self.bias_lower_bound = bias_lower_bound
self.bias_upper_bound = bias_upper_bound
self.aggregation_default = aggregation_options.index(aggregation_default) self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options self.aggregation_options = aggregation_options
self.aggregation_indices = jnp.arange(len(aggregation_options)) self.aggregation_indices = np.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate self.aggregation_replace_rate = aggregation_replace_rate
self.activation_default = activation_options.index(activation_default) self.activation_default = activation_options.index(activation_default)
self.activation_options = activation_options self.activation_options = activation_options
self.activation_indices = jnp.arange(len(activation_options)) self.activation_indices = np.arange(len(activation_options))
self.activation_replace_rate = activation_replace_rate self.activation_replace_rate = activation_replace_rate
def new_identity_attrs(self, state): def new_identity_attrs(self, state):
@@ -62,6 +78,7 @@ class BiasNode(BaseNodeGene):
def new_random_attrs(self, state, randkey): def new_random_attrs(self, state, randkey):
k1, k2, k3 = jax.random.split(randkey, num=3) k1, k2, k3 = jax.random.split(randkey, num=3)
bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
agg = jax.random.choice(k2, self.aggregation_indices) agg = jax.random.choice(k2, self.aggregation_indices)
act = jax.random.choice(k3, self.activation_indices) act = jax.random.choice(k3, self.activation_indices)
@@ -80,7 +97,7 @@ class BiasNode(BaseNodeGene):
self.bias_mutate_rate, self.bias_mutate_rate,
self.bias_replace_rate, self.bias_replace_rate,
) )
bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
agg = mutate_int( agg = mutate_int(
k2, agg, self.aggregation_indices, self.aggregation_replace_rate k2, agg, self.aggregation_indices, self.aggregation_replace_rate
) )

View File

@@ -1,4 +1,4 @@
from typing import Tuple, Union, Sequence, Callable from typing import Optional, Union, Sequence, Callable
import numpy as np import numpy as np
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
@@ -14,10 +14,10 @@ from tensorneat.common import (
convert_to_sympy, convert_to_sympy,
) )
from . import BaseNodeGene from .base import BaseNode
class DefaultNodeGene(BaseNodeGene): class DefaultNode(BaseNode):
"Default node gene, with the same behavior as in NEAT-python." "Default node gene, with the same behavior as in NEAT-python."
custom_attrs = ["bias", "response", "aggregation", "activation"] custom_attrs = ["bias", "response", "aggregation", "activation"]
@@ -26,18 +26,22 @@ class DefaultNodeGene(BaseNodeGene):
self, self,
bias_init_mean: float = 0.0, bias_init_mean: float = 0.0,
bias_init_std: float = 1.0, bias_init_std: float = 1.0,
bias_mutate_power: float = 0.5, bias_mutate_power: float = 0.15,
bias_mutate_rate: float = 0.7, bias_mutate_rate: float = 0.2,
bias_replace_rate: float = 0.1, bias_replace_rate: float = 0.015,
bias_lower_bound: float = -5,
bias_upper_bound: float = 5,
response_init_mean: float = 1.0, response_init_mean: float = 1.0,
response_init_std: float = 0.0, response_init_std: float = 0.0,
response_mutate_power: float = 0.5, response_mutate_power: float = 0.15,
response_mutate_rate: float = 0.7, response_mutate_rate: float = 0.2,
response_replace_rate: float = 0.1, response_replace_rate: float = 0.015,
aggregation_default: Callable = Agg.sum, response_lower_bound: float = -5,
response_upper_bound: float = 5,
aggregation_default: Optional[Callable] = None,
aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum, aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
aggregation_replace_rate: float = 0.1, aggregation_replace_rate: float = 0.1,
activation_default: Callable = Act.sigmoid, activation_default: Optional[Callable] = None,
activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid, activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
activation_replace_rate: float = 0.1, activation_replace_rate: float = 0.1,
): ):
@@ -48,17 +52,26 @@ class DefaultNodeGene(BaseNodeGene):
if isinstance(activation_options, Callable): if isinstance(activation_options, Callable):
activation_options = [activation_options] activation_options = [activation_options]
if len(aggregation_options) == 1 and aggregation_default is None:
aggregation_default = aggregation_options[0]
if len(activation_options) == 1 and activation_default is None:
activation_default = activation_options[0]
self.bias_init_mean = bias_init_mean self.bias_init_mean = bias_init_mean
self.bias_init_std = bias_init_std self.bias_init_std = bias_init_std
self.bias_mutate_power = bias_mutate_power self.bias_mutate_power = bias_mutate_power
self.bias_mutate_rate = bias_mutate_rate self.bias_mutate_rate = bias_mutate_rate
self.bias_replace_rate = bias_replace_rate self.bias_replace_rate = bias_replace_rate
self.bias_lower_bound = bias_lower_bound
self.bias_upper_bound = bias_upper_bound
self.response_init_mean = response_init_mean self.response_init_mean = response_init_mean
self.response_init_std = response_init_std self.response_init_std = response_init_std
self.response_mutate_power = response_mutate_power self.response_mutate_power = response_mutate_power
self.response_mutate_rate = response_mutate_rate self.response_mutate_rate = response_mutate_rate
self.response_replace_rate = response_replace_rate self.response_replace_rate = response_replace_rate
self.reponse_lower_bound = response_lower_bound
self.response_upper_bound = response_upper_bound
self.aggregation_default = aggregation_options.index(aggregation_default) self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options self.aggregation_options = aggregation_options
@@ -71,16 +84,21 @@ class DefaultNodeGene(BaseNodeGene):
self.activation_replace_rate = activation_replace_rate self.activation_replace_rate = activation_replace_rate
def new_identity_attrs(self, state): def new_identity_attrs(self, state):
return jnp.array( bias = 0
[0, 1, self.aggregation_default, -1] res = 1
) # activation=-1 means Act.identity agg = self.aggregation_default
act = self.activation_default
return jnp.array([bias, res, agg, act]) # activation=-1 means Act.identity
def new_random_attrs(self, state, randkey): def new_random_attrs(self, state, randkey):
k1, k2, k3, k4 = jax.random.split(randkey, num=4) k1, k2, k3, k4 = jax.random.split(randkey, num=4)
bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
res = ( res = (
jax.random.normal(k2, ()) * self.response_init_std + self.response_init_mean jax.random.normal(k2, ()) * self.response_init_std + self.response_init_mean
) )
res = jnp.clip(res, self.reponse_lower_bound, self.response_upper_bound)
agg = jax.random.choice(k3, self.aggregation_indices) agg = jax.random.choice(k3, self.aggregation_indices)
act = jax.random.choice(k4, self.activation_indices) act = jax.random.choice(k4, self.activation_indices)
@@ -98,7 +116,7 @@ class DefaultNodeGene(BaseNodeGene):
self.bias_mutate_rate, self.bias_mutate_rate,
self.bias_replace_rate, self.bias_replace_rate,
) )
bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
res = mutate_float( res = mutate_float(
k2, k2,
res, res,
@@ -108,7 +126,7 @@ class DefaultNodeGene(BaseNodeGene):
self.response_mutate_rate, self.response_mutate_rate,
self.response_replace_rate, self.response_replace_rate,
) )
res = jnp.clip(res, self.reponse_lower_bound, self.response_upper_bound)
agg = mutate_int( agg = mutate_int(
k4, agg, self.aggregation_indices, self.aggregation_replace_rate k4, agg, self.aggregation_indices, self.aggregation_replace_rate
) )

View File

@@ -23,9 +23,9 @@ from ...utils import (
class DefaultMutation(BaseMutation): class DefaultMutation(BaseMutation):
def __init__( def __init__(
self, self,
conn_add: float = 0.2, conn_add: float = 0.1,
conn_delete: float = 0, conn_delete: float = 0,
node_add: float = 0.2, node_add: float = 0.1,
node_delete: float = 0, node_delete: float = 0,
): ):
self.conn_add = conn_add self.conn_add = conn_add

View File

@@ -3,7 +3,7 @@ from jax import vmap, numpy as jnp
from .utils import unflatten_conns from .utils import unflatten_conns
from .base import BaseGenome from .base import BaseGenome
from .gene import DefaultNodeGene, DefaultConnGene from .gene import DefaultNode, DefaultConn
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
@@ -20,8 +20,8 @@ class RecurrentGenome(BaseGenome):
num_outputs: int, num_outputs: int,
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
node_gene=DefaultNodeGene(), node_gene=DefaultNode(),
conn_gene=DefaultConnGene(), conn_gene=DefaultConn(),
mutation=DefaultMutation(), mutation=DefaultMutation(),
crossover=DefaultCrossover(), crossover=DefaultCrossover(),
distance=DefaultDistance(), distance=DefaultDistance(),