diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 4c1085d..8196470 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -2,6 +2,7 @@ from typing import List, Union, Tuple, Callable import time import jax +import jax.numpy as jnp import numpy as np from .species import SpeciesController @@ -104,7 +105,6 @@ class Pipeline: lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections - npn, npc = jax.device_get(npn), jax.device_get(npc) # mutate mutate_rand_keys = jax.random.split(k2, self.pop_size) @@ -113,11 +113,8 @@ class Pipeline: m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes # elitism don't mutate - # (pop_size, ) to (pop_size, 1, 1) - - m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc) + npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc]) self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) - # (pop_size, ) to (pop_size, 1, 1, 1) self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) def expand(self): diff --git a/algorithms/neat/species.py b/algorithms/neat/species.py index f8b947b..035ef58 100644 --- a/algorithms/neat/species.py +++ b/algorithms/neat/species.py @@ -76,11 +76,13 @@ class SpeciesController: new_representatives = {} new_members = {} - for sid, species in self.species.items(): - # calculate the distance between the representative and the population - r_nodes, r_connections = species.representative - distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections) - distances = jax.device_get(distances) + total_distances = jax.device_get([ + o2m_distance(*self.species[sid].representative, pop_nodes, pop_connections) + for sid in previous_species_list + ]) + + for i, sid in enumerate(previous_species_list): + distances = total_distances[i] min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance new_representatives[sid] = min_idx diff --git a/examples/xor.py b/examples/xor.py index e8ac80c..209f169 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -27,7 +27,7 @@ def evaluate(forward_func: Callable) -> List[float]: # @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() - pipeline = Pipeline(config, seed=11323) + pipeline = Pipeline(config, seed=114514) pipeline.auto_run(evaluate) diff --git a/utils/default_config.json b/utils/default_config.json index 9f81bfb..f6e8506 100644 --- a/utils/default_config.json +++ b/utils/default_config.json @@ -2,7 +2,7 @@ "basic": { "num_inputs": 2, "num_outputs": 1, - "init_maximum_nodes": 20, + "init_maximum_nodes": 30, "expands_coe": 2 }, "neat": { @@ -59,7 +59,7 @@ "node_delete_prob": 0.2 }, "species": { - "compatibility_threshold": 2.5, + "compatibility_threshold": 3, "species_fitness_func": "max", "max_stagnation": 20, "species_elitism": 2,