From 47b1cacb57815429e0fcbc9068c18f0ec0e00468 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 31 May 2024 15:36:47 +0800 Subject: [PATCH] fix bug for record the best genome. --- tensorneat/algorithm/neat/species/default.py | 3 --- tensorneat/examples/func_fit/xor.py | 8 ++++---- tensorneat/pipeline.py | 5 ++++- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorneat/algorithm/neat/species/default.py b/tensorneat/algorithm/neat/species/default.py index 83226f4..5dcdd5a 100644 --- a/tensorneat/algorithm/neat/species/default.py +++ b/tensorneat/algorithm/neat/species/default.py @@ -113,9 +113,6 @@ class DefaultSpecies(BaseSpecies): return state.pop_nodes, state.pop_conns def update_species(self, state, fitness): - # set nan to -inf - fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness) - # update the fitness of each species state, species_fitness = self.update_species_fitness(state, fitness) diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index 8e8ffd8..aaf83a2 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -19,10 +19,10 @@ if __name__ == "__main__": ), output_transform=Act.sigmoid, # the activation function for output node mutation=DefaultMutation( - node_add=0.05, - conn_add=0.05, - node_delete=0.05, - conn_delete=0.05, + node_add=0.1, + conn_add=0.1, + node_delete=0.1, + conn_delete=0.1, ), ), pop_size=1000, diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index 8eca436..6e37744 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -116,6 +116,9 @@ class Pipeline: state, keys, self.algorithm.forward, pop_transformed ) + # replace nan with -inf + fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses) + state = self.algorithm.tell(state, fitnesses) return state.update(randkey=randkey), fitnesses @@ -149,7 +152,7 @@ class Pipeline: def analysis(self, state, pop, fitnesses): - valid_fitnesses = fitnesses[~np.isnan(fitnesses)] + valid_fitnesses = fitnesses[~np.isinf(fitnesses)] max_f, min_f, mean_f, std_f = ( max(valid_fitnesses),