From cf47c5bb3847ab0dcf5d2c384c8b3c1f4c86a88a Mon Sep 17 00:00:00 2001 From: wls2002 Date: Mon, 8 May 2023 00:02:51 +0800 Subject: [PATCH] huge accelerate: delete recycle new keys --- algorithms/neat/pipeline.py | 38 +++++++++++-------------------------- examples/xor.py | 4 ++-- 2 files changed, 13 insertions(+), 29 deletions(-) diff --git a/algorithms/neat/pipeline.py b/algorithms/neat/pipeline.py index 7223f12..4dd8208 100644 --- a/algorithms/neat/pipeline.py +++ b/algorithms/neat/pipeline.py @@ -30,11 +30,8 @@ class Pipeline: self.crossover_func = create_crossover_function(batch=True) self.generation = 0 - self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation) - self.new_node_keys_pool: List[int] = [max(self.output_idx) + 1] - self.generation_timestamp = time.time() self.best_fitness = float('-inf') @@ -107,24 +104,23 @@ class Pipeline: # mutate mutate_rand_keys = jax.random.split(k2, self.pop_size) - new_node_keys = np.array(self.fetch_new_node_keys()) + new_node_keys = np.arange(self.generation * self.pop_size, self.generation * self.pop_size + self.pop_size) m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes - m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc) + # elitism don't mutate # (pop_size, ) to (pop_size, 1, 1) - self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) - # (pop_size, ) to (pop_size, 1, 1, 1) - self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) + + def aux_function1(): + nonlocal m_npn, m_npc + m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc) + self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) + # (pop_size, ) to (pop_size, 1, 1, 1) + self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) + # print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)) - # recycle unused node keys - unused = [] - for i, nodes in enumerate(self.pop_nodes): - node_keys, key = nodes[:, 0], new_node_keys[i] - if not np.isin(key, node_keys): # the new node key is not used - unused.append(key) - self.new_node_keys_pool = unused + self.new_node_keys_pool + aux_function1() def expand(self): """ @@ -145,18 +141,6 @@ class Pipeline: for s in self.species_controller.species.values(): s.representative = expand_single(*s.representative, self.N) - def fetch_new_node_keys(self): - # if remain unused keys are not enough, create new keys - if len(self.new_node_keys_pool) < self.pop_size: - max_unused_key = max(self.new_node_keys_pool) if self.new_node_keys_pool else -1 - new_keys = list(range(max_unused_key + 1, max_unused_key + 1 + 10 * self.pop_size)) - self.new_node_keys_pool.extend(new_keys) - - # fetch keys from pool - res = self.new_node_keys_pool[:self.pop_size] - self.new_node_keys_pool = self.new_node_keys_pool[self.pop_size:] - return res - def default_analysis(self, fitnesses): max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) species_sizes = [len(s.members) for s in self.species_controller.species.values()] diff --git a/examples/xor.py b/examples/xor.py index 895bd5f..e8ac80c 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -23,8 +23,8 @@ def evaluate(forward_func: Callable) -> List[float]: return fitnesses.tolist() # returns a list -# @using_cprofile -@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") +@using_cprofile +# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") def main(): config = Configer.load_config() pipeline = Pipeline(config, seed=11323)