From 5626fddf41e1daac809abb876090ac9e12e7d033 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sat, 25 May 2024 17:00:20 +0800 Subject: [PATCH] add params key into setup. --- tensorneat/algorithm/neat/ga/crossover/base.py | 2 +- tensorneat/algorithm/neat/ga/mutation/base.py | 2 +- tensorneat/algorithm/neat/gene/base.py | 2 +- tensorneat/algorithm/neat/genome/base.py | 2 +- tensorneat/algorithm/neat/species/base.py | 7 ++++--- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tensorneat/algorithm/neat/ga/crossover/base.py b/tensorneat/algorithm/neat/ga/crossover/base.py index c7244d3..1849deb 100644 --- a/tensorneat/algorithm/neat/ga/crossover/base.py +++ b/tensorneat/algorithm/neat/ga/crossover/base.py @@ -3,7 +3,7 @@ from utils import State class BaseCrossover: - def setup(self, state=State()): + def setup(self, key, state=State()): return state def __call__(self, state, key, genome, nodes1, nodes2, conns1, conns2): diff --git a/tensorneat/algorithm/neat/ga/mutation/base.py b/tensorneat/algorithm/neat/ga/mutation/base.py index e56e8e6..2322f85 100644 --- a/tensorneat/algorithm/neat/ga/mutation/base.py +++ b/tensorneat/algorithm/neat/ga/mutation/base.py @@ -3,7 +3,7 @@ from utils import State class BaseMutation: - def setup(self, state=State()): + def setup(self, key, state=State()): return state def __call__(self, state, key, genome, nodes, conns, new_node_key): diff --git a/tensorneat/algorithm/neat/gene/base.py b/tensorneat/algorithm/neat/gene/base.py index 1110074..abb2d52 100644 --- a/tensorneat/algorithm/neat/gene/base.py +++ b/tensorneat/algorithm/neat/gene/base.py @@ -9,7 +9,7 @@ class BaseGene: def __init__(self): pass - def setup(self, state=State()): + def setup(self, key, state=State()): return state def new_attrs(self, state, key): diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index fd711ab..9a3ac81 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -24,7 +24,7 @@ class BaseGenome: self.node_gene = node_gene self.conn_gene = conn_gene - def setup(self, state=State()): + def setup(self, key, state=State()): return state def transform(self, state, nodes, conns): diff --git a/tensorneat/algorithm/neat/species/base.py b/tensorneat/algorithm/neat/species/base.py index f1294f2..682fc8f 100644 --- a/tensorneat/algorithm/neat/species/base.py +++ b/tensorneat/algorithm/neat/species/base.py @@ -1,8 +1,9 @@ from utils import State + class BaseSpecies: - def setup(self, randkey): - raise NotImplementedError + def setup(self, key, state=State()): + return state def ask(self, state: State): raise NotImplementedError @@ -11,4 +12,4 @@ class BaseSpecies: raise NotImplementedError def speciate(self, state, generation): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError