diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index 9bdbc82..d9f2079 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -13,7 +13,7 @@ class Species(object): self.created = generation self.last_improved = generation self.representative: Tuple[NDArray, NDArray] = (None, None) # (nodes, connections) - self.members: List[int] = [] # idx in pop_nodes, pop_connections, + self.members: NDArray = None # idx in pop_nodes, pop_connections, self.fitness = None self.member_fitnesses = None self.adjusted_fitness = None @@ -24,7 +24,7 @@ class Species(object): self.members = members def get_fitnesses(self, fitnesses): - return [fitnesses[m] for m in self.members] + return fitnesses[self.members] class SpeciesController: @@ -55,7 +55,7 @@ class SpeciesController: pop_size = pop_nodes.shape[0] species_id = next(self.species_idxer) s = Species(species_id, 0) - members = list(range(pop_size)) + members = np.array(list(range(pop_size))) s.update((pop_nodes[0], pop_connections[0]), members) self.species[species_id] = s @@ -81,6 +81,7 @@ class SpeciesController: for sid in previous_species_list ]) + # TODO: Use jit to wrapper function find_min_with_mask to accelerate this part for i, sid in enumerate(previous_species_list): distances = total_distances[i] min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance @@ -145,7 +146,7 @@ class SpeciesController: s = Species(sid, generation) self.species[sid] = s - members = new_members[sid] + members = np.array(new_members[sid]) s.update((pop_nodes[rid], pop_connections[rid]), members) def update_species_fitnesses(self, fitnesses): @@ -195,9 +196,10 @@ class SpeciesController: result.append((sid, s, is_stagnant)) return result - def reproduce(self, generation: int) -> List[Union[int, Tuple[int, int]]]: + def reproduce(self, fitnesses: NDArray, generation: int) -> Tuple[NDArray, NDArray, NDArray]: """ code modified from neat-python! + :param fitnesses: :param generation: :return: crossover_pair for next generation. # int -> idx in the pop_nodes, pop_connections of elitism @@ -208,21 +210,22 @@ class SpeciesController: # The average adjusted fitness scheme (normalized to the interval # [0, 1]) allows the use of negative fitness values without # interfering with the shared fitness scheme. - all_fitnesses = [] + + min_fitness = np.inf + max_fitness = -np.inf + remaining_species = [] for stag_sid, stag_s, stagnant in self.stagnation(generation): if not stagnant: - all_fitnesses.extend(stag_s.member_fitnesses) + min_fitness = min(min_fitness, np.min(stag_s.member_fitnesses)) + max_fitness = max(max_fitness, np.max(stag_s.member_fitnesses)) remaining_species.append(stag_s) # No species left. - if not remaining_species: - self.species = {} - return [] + assert remaining_species # Compute each species' member size in the next generation. - min_fitness = min(all_fitnesses) - max_fitness = max(all_fitnesses) + # Do not allow the fitness range to be zero, as we divide by it below. # TODO: The ``1.0`` below is rather arbitrary, and should be configurable. fitness_range = max(1.0, max_fitness - min_fitness) @@ -242,21 +245,23 @@ class SpeciesController: # int -> idx in the pop_nodes, pop_connections of elitism # (int, int) -> the father and mother idx to be crossover - crossover_pair: List[Union[int, Tuple[int, int]]] = [] + part1, part2, elite_mask = [], [], [] for spawn, s in zip(spawn_amounts, remaining_species): assert spawn >= self.genome_elitism # retain remain species to next generation - old_members, fitnesses = s.members, s.member_fitnesses + old_members, member_fitnesses = s.members, s.member_fitnesses s.members = [] self.species[s.key] = s # add elitism genomes to next generation - sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, fitnesses) + sorted_members, sorted_fitnesses = sort_element_with_fitnesses(old_members, member_fitnesses) if self.genome_elitism > 0: for m in sorted_members[:self.genome_elitism]: - crossover_pair.append(m) + part1.append(m) + part2.append(m) + elite_mask.append(True) spawn -= 1 if spawn <= 0: @@ -269,15 +274,16 @@ class SpeciesController: sorted_members = sorted_members[:repro_cutoff] list_idx1, list_idx2 = np.random.choice(len(sorted_members), size=(2, spawn), replace=True) - for c1, c2 in zip(list_idx1, list_idx2): - idx1, fitness1 = sorted_members[c1], sorted_fitnesses[c1] - idx2, fitness2 = sorted_members[c2], sorted_fitnesses[c2] - if fitness1 >= fitness2: - crossover_pair.append((idx1, idx2)) - else: - crossover_pair.append((idx2, idx1)) + part1.extend(sorted_members[list_idx1]) + part2.extend(sorted_members[list_idx2]) + elite_mask.extend([False] * spawn) - return crossover_pair + part1_fitness, part2_fitness = fitnesses[part1], fitnesses[part2] + is_part1_win = part1_fitness >= part2_fitness + winner_part = np.where(is_part1_win, part1, part2) + loser_part = np.where(is_part1_win, part2, part1) + + return winner_part, loser_part, np.array(elite_mask) def compute_spawn(adjusted_fitness, previous_sizes, pop_size, min_species_size): @@ -326,10 +332,7 @@ def find_min_with_mask(arr: NDArray, mask: NDArray) -> int: return min_idx -def sort_element_with_fitnesses(members: List[int], fitnesses: List[float]) \ - -> Tuple[List[int], List[float]]: - combined = zip(members, fitnesses) - sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True) - sorted_members = [item[0] for item in sorted_combined] - sorted_fitnesses = [item[1] for item in sorted_combined] - return sorted_members, sorted_fitnesses +def sort_element_with_fitnesses(members: NDArray, fitnesses: NDArray) \ + -> Tuple[NDArray, NDArray]: + sorted_idx = np.argsort(fitnesses)[::-1] + return members[sorted_idx], fitnesses[sorted_idx] \ No newline at end of file diff --git a/examples/xor.py b/examples/xor.py index 7bba04c..a8dba00 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -25,7 +25,7 @@ from problems import Sin, Xor, DIY @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() - config.neat.population.pop_size = 50 + # config.neat.population.pop_size = 50 problem = Xor() # problem = Sin() # problem = DIY(func=lambda x: (np.sin(x) + np.exp(x) - x ** 2) / (np.cos(x) + np.sqrt(x)) - np.log(x + 1)) diff --git a/problems/function_fitting/function_fitting_problem.py b/problems/function_fitting/function_fitting_problem.py index 6a1bfd9..3ca950b 100644 --- a/problems/function_fitting/function_fitting_problem.py +++ b/problems/function_fitting/function_fitting_problem.py @@ -19,7 +19,7 @@ class FunctionFittingProblem(Problem): outs = pop_batch_forward(self.inputs) outs = jax.device_get(outs) fitnesses = -np.mean((self.target - outs) ** 2, axis=(1, 2)) - return fitnesses.tolist() + return fitnesses def draw(self, batch_func): outs = batch_func(self.inputs) diff --git a/utils/default_config.json b/utils/default_config.json index ef5065e..66492b1 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -13,7 +13,7 @@ "fitness_criterion": "max", "fitness_threshold": -0.001, "generation_limit": 1000, - "pop_size": 30, + "pop_size": 1000, "reset_on_extinction": "False" }, "gene": {