try to accelerate the speed of speciate

This commit is contained in:
wls2002
2023-05-08 18:41:19 +08:00
parent 8653f49826
commit ee6bb01eff
4 changed files with 12 additions and 13 deletions

View File

@@ -2,6 +2,7 @@ from typing import List, Union, Tuple, Callable
import time import time
import jax import jax
import jax.numpy as jnp
import numpy as np import numpy as np
from .species import SpeciesController from .species import SpeciesController
@@ -104,7 +105,6 @@ class Pipeline:
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
lpc) # new pop nodes, new pop connections lpc) # new pop nodes, new pop connections
npn, npc = jax.device_get(npn), jax.device_get(npc)
# mutate # mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size) mutate_rand_keys = jax.random.split(k2, self.pop_size)
@@ -113,11 +113,8 @@ class Pipeline:
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate # elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1) npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
# (pop_size, ) to (pop_size, 1, 1, 1)
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc) self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
def expand(self): def expand(self):

View File

@@ -76,11 +76,13 @@ class SpeciesController:
new_representatives = {} new_representatives = {}
new_members = {} new_members = {}
for sid, species in self.species.items(): total_distances = jax.device_get([
# calculate the distance between the representative and the population o2m_distance(*self.species[sid].representative, pop_nodes, pop_connections)
r_nodes, r_connections = species.representative for sid in previous_species_list
distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections) ])
distances = jax.device_get(distances)
for i, sid in enumerate(previous_species_list):
distances = total_distances[i]
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx new_representatives[sid] = min_idx

View File

@@ -27,7 +27,7 @@ def evaluate(forward_func: Callable) -> List[float]:
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") # @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main(): def main():
config = Configer.load_config() config = Configer.load_config()
pipeline = Pipeline(config, seed=11323) pipeline = Pipeline(config, seed=114514)
pipeline.auto_run(evaluate) pipeline.auto_run(evaluate)

View File

@@ -2,7 +2,7 @@
"basic": { "basic": {
"num_inputs": 2, "num_inputs": 2,
"num_outputs": 1, "num_outputs": 1,
"init_maximum_nodes": 20, "init_maximum_nodes": 30,
"expands_coe": 2 "expands_coe": 2
}, },
"neat": { "neat": {
@@ -59,7 +59,7 @@
"node_delete_prob": 0.2 "node_delete_prob": 0.2
}, },
"species": { "species": {
"compatibility_threshold": 2.5, "compatibility_threshold": 3,
"species_fitness_func": "max", "species_fitness_func": "max",
"max_stagnation": 20, "max_stagnation": 20,
"species_elitism": 2, "species_elitism": 2,