diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index d8074f8..c9d77df 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -32,8 +32,10 @@ class Pipeline: self.compile_functions(debug=True) self.generation = 0 - self.species_controller.speciate(self.pop_nodes, self.pop_connections, - self.generation, self.o2o_distance, self.o2m_distance) + self.species_controller.init_speciate(self.pop_nodes, self.pop_connections) + + # self.species_controller.speciate(self.pop_nodes, self.pop_connections, + # self.generation, self.o2o_distance, self.o2m_distance) self.best_fitness = float('-inf') diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index aa8afd6..c4a5f48 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -45,6 +45,20 @@ class SpeciesController: self.species_idxer = count(0) self.species: Dict[int, Species] = {} # species_id -> species + def init_speciate(self, pop_nodes: NDArray, pop_connections: NDArray): + """ + speciate for the first generation + :param pop_connections: + :param pop_nodes: + :return: + """ + pop_size = pop_nodes.shape[0] + species_id = next(self.species_idxer) + s = Species(species_id, 0) + members = list(range(pop_size)) + s.update((pop_nodes[0], pop_connections[0]), members) + self.species[species_id] = s + def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int, o2o_distance: Callable, o2m_distance: Callable) -> None: """ diff --git a/examples/xor.py b/examples/xor.py index 895bd5f..e8ac80c 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -23,8 +23,8 @@ def evaluate(forward_func: Callable) -> List[float]: return fitnesses.tolist() # returns a list -# @using_cprofile -@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") +@using_cprofile +# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() pipeline = Pipeline(config, seed=11323)