diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index c9d77df..4c1085d 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -2,7 +2,6 @@ from typing import List, Union, Tuple, Callable import time import jax -import jax.numpy as jnp import numpy as np from .species import SpeciesController @@ -19,6 +18,7 @@ class Pipeline: def __init__(self, config, seed=42): self.generation_timestamp = time.time() self.randkey = jax.random.PRNGKey(seed) + np.random.seed(seed) self.config = config self.N = config.basic.init_maximum_nodes @@ -34,9 +34,6 @@ class Pipeline: self.generation = 0 self.species_controller.init_speciate(self.pop_nodes, self.pop_connections) - # self.species_controller.speciate(self.pop_nodes, self.pop_connections, - # self.generation, self.o2o_distance, self.o2m_distance) - self.best_fitness = float('-inf') def ask(self, batch: bool): diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index feaf47f..f8b947b 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -239,7 +239,9 @@ class SpeciesController: self.species = {} # int -> idx in the pop_nodes, pop_connections of elitism # (int, int) -> the father and mother idx to be crossover + crossover_pair: List[Union[int, Tuple[int, int]]] = [] + for spawn, s in zip(spawn_amounts, remaining_species): assert spawn >= self.genome_elitism @@ -264,10 +266,8 @@ class SpeciesController: # only use good genomes to crossover sorted_members = sorted_members[:repro_cutoff] - # Randomly choose parents and produce the number of offspring allotted to the species. - for _ in range(spawn): - # allow to replace, for the case that the species only has one genome - c1, c2 = np.random.choice(len(sorted_members), size=2, replace=True) + list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True) + for c1, c2 in zip(list_idx1, list_idx2): idx1, fitness1 = sorted_members[c1], sorted_fitnesses[c1] idx2, fitness2 = sorted_members[c2], sorted_fitnesses[c2] if fitness1 >= fitness2: