hyper neat

This commit is contained in:
wls2002
2023-07-24 19:25:02 +08:00
parent ac295c1921
commit ebad574431
24 changed files with 542 additions and 103 deletions

View File

@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Union
@dataclass(frozen=True)
@@ -7,34 +6,31 @@ class BasicConfig:
seed: int = 42
fitness_target: float = 1
generation_limit: int = 1000
num_inputs: int = 2
num_outputs: int = 1
pop_size: int = 100
def __post_init__(self):
assert self.num_inputs > 0, "the inputs number of the problem must be greater than 0"
assert self.num_outputs > 0, "the outputs number of the problem must be greater than 0"
assert self.pop_size > 0, "the population size must be greater than 0"
@dataclass(frozen=True)
class NeatConfig:
network_type: str = "feedforward"
activate_times: Union[int, None] = None # None means the network is feedforward
maximum_nodes: int = 100
maximum_conns: int = 50
inputs: int = 2
outputs: int = 1
maximum_nodes: int = 50
maximum_conns: int = 100
maximum_species: int = 10
# genome config
compatibility_disjoint: float = 1
compatibility_weight: float = 0.5
conn_add: float = 0.4
conn_delete: float = 0.4
conn_delete: float = 0
node_add: float = 0.2
node_delete: float = 0.2
node_delete: float = 0
# species config
compatibility_threshold: float = 3.0
compatibility_threshold: float = 3.5
species_elitism: int = 2
max_stagnation: int = 15
genome_elitism: int = 2
@@ -44,11 +40,9 @@ class NeatConfig:
def __post_init__(self):
assert self.network_type in ["feedforward", "recurrent"], "the network type must be feedforward or recurrent"
if self.network_type == "feedforward":
assert self.activate_times is None, "the activate times of feedforward network must be None"
else:
assert isinstance(self.activate_times, int), "the activate times of recurrent network must be int"
assert self.activate_times > 0, "the activate times of recurrent network must be greater than 0"
assert self.inputs > 0, "the inputs number of neat must be greater than 0"
assert self.outputs > 0, "the outputs number of neat must be greater than 0"
assert self.maximum_nodes > 0, "the maximum nodes must be greater than 0"
assert self.maximum_conns > 0, "the maximum connections must be greater than 0"
@@ -56,10 +50,10 @@ class NeatConfig:
assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0"
assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0"
assert self.conn_add > 0, "the connection add probability must be greater than 0"
assert self.conn_delete > 0, "the connection delete probability must be greater than 0"
assert self.node_add > 0, "the node add probability must be greater than 0"
assert self.node_delete > 0, "the node delete probability must be greater than 0"
assert self.conn_add >= 0, "the connection add probability must be greater than 0"
assert self.conn_delete >= 0, "the connection delete probability must be greater than 0"
assert self.node_add >= 0, "the node add probability must be greater than 0"
assert self.node_delete >= 0, "the node delete probability must be greater than 0"
assert self.compatibility_threshold > 0, "the compatibility threshold must be greater than 0"
assert self.species_elitism > 0, "the species elitism must be greater than 0"
@@ -77,18 +71,21 @@ class HyperNeatConfig:
activation: str = "sigmoid"
aggregation: str = "sum"
activate_times: int = 5
inputs: int = 2
outputs: int = 1
def __post_init__(self):
assert self.below_threshold > 0, "the below threshold must be greater than 0"
assert self.max_weight > 0, "the max weight must be greater than 0"
assert self.activate_times > 0, "the activate times must be greater than 0"
assert self.inputs > 0, "the inputs number of hyper neat must be greater than 0"
assert self.outputs > 0, "the outputs number of hyper neat must be greater than 0"
@dataclass(frozen=True)
class GeneConfig:
pass
@dataclass(frozen=True)
class SubstrateConfig:
pass