From acc9eab64ab572589ad93c52a14db214fb52ca3c Mon Sep 17 00:00:00 2001 From: wls2002 Date: Thu, 11 May 2023 08:15:06 +0800 Subject: [PATCH] change fitness from list to array optimize the code of reproduction. --- algorithms/neat/pipeline.py | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index ac3a2bc..c8009a5 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -53,9 +53,9 @@ class Pipeline: self.species_controller.update_species_fitnesses(fitnesses) - crossover_pair = self.species_controller.reproduce(self.generation) + winner_part, loser_part, elite_mask = self.species_controller.reproduce(fitnesses, self.generation) - self.update_next_generation(crossover_pair) + self.update_next_generation(winner_part, loser_part, elite_mask) self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation, self.o2o_distance, self.o2m_distance) @@ -82,36 +82,29 @@ class Pipeline: print("Generation limit reached!") return self.best_genome - def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None: + def update_next_generation(self, winner_part, loser_part, elite_mask) -> None: """ - create the next generation - :param crossover_pair: created from self.reproduce() + create next generation + :param winner_part: + :param loser_part: + :param elite_mask: + :return: """ assert self.pop_nodes.shape[0] == self.pop_size k1, k2, self.randkey = jax.random.split(self.randkey, 3) - # crossover - # prepare elitism mask and crossover pair - elitism_mask = np.full(self.pop_size, False) - - for i, pair in enumerate(crossover_pair): - if not isinstance(pair, tuple): # elitism - elitism_mask[i] = True - crossover_pair[i] = (pair, pair) - crossover_pair = np.array(crossover_pair) - crossover_rand_keys = jax.random.split(k1, self.pop_size) mutate_rand_keys = jax.random.split(k2, self.pop_size) # batch crossover - wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes - wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections - lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes - lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections + wpn = self.pop_nodes[winner_part] # winner pop nodes + wpc = self.pop_connections[winner_part] # winner pop connections + lpn = self.pop_nodes[loser_part] # loser pop nodes + lpc = self.pop_connections[loser_part] # loser pop connections + npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections - # mutate new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size) @@ -120,8 +113,8 @@ class Pipeline: # elitism don't mutate npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc]) - self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) - self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) + self.pop_nodes = np.where(elite_mask[:, None, None], npn, m_npn) + self.pop_connections = np.where(elite_mask[:, None, None, None], npc, m_npc) def expand(self): """