try to accelerate the speed of speciate
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import List, Union, Tuple, Callable
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .species import SpeciesController
|
from .species import SpeciesController
|
||||||
@@ -104,7 +105,6 @@ class Pipeline:
|
|||||||
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
||||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
||||||
lpc) # new pop nodes, new pop connections
|
lpc) # new pop nodes, new pop connections
|
||||||
npn, npc = jax.device_get(npn), jax.device_get(npc)
|
|
||||||
|
|
||||||
# mutate
|
# mutate
|
||||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||||
@@ -113,11 +113,8 @@ class Pipeline:
|
|||||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||||
|
|
||||||
# elitism don't mutate
|
# elitism don't mutate
|
||||||
# (pop_size, ) to (pop_size, 1, 1)
|
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
|
||||||
|
|
||||||
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
|
|
||||||
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
||||||
# (pop_size, ) to (pop_size, 1, 1, 1)
|
|
||||||
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
||||||
|
|
||||||
def expand(self):
|
def expand(self):
|
||||||
|
|||||||
@@ -76,11 +76,13 @@ class SpeciesController:
|
|||||||
new_representatives = {}
|
new_representatives = {}
|
||||||
new_members = {}
|
new_members = {}
|
||||||
|
|
||||||
for sid, species in self.species.items():
|
total_distances = jax.device_get([
|
||||||
# calculate the distance between the representative and the population
|
o2m_distance(*self.species[sid].representative, pop_nodes, pop_connections)
|
||||||
r_nodes, r_connections = species.representative
|
for sid in previous_species_list
|
||||||
distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
|
])
|
||||||
distances = jax.device_get(distances)
|
|
||||||
|
for i, sid in enumerate(previous_species_list):
|
||||||
|
distances = total_distances[i]
|
||||||
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
|
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
|
||||||
|
|
||||||
new_representatives[sid] = min_idx
|
new_representatives[sid] = min_idx
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def evaluate(forward_func: Callable) -> List[float]:
|
|||||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||||
def main():
|
def main():
|
||||||
config = Configer.load_config()
|
config = Configer.load_config()
|
||||||
pipeline = Pipeline(config, seed=11323)
|
pipeline = Pipeline(config, seed=114514)
|
||||||
pipeline.auto_run(evaluate)
|
pipeline.auto_run(evaluate)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
"basic": {
|
"basic": {
|
||||||
"num_inputs": 2,
|
"num_inputs": 2,
|
||||||
"num_outputs": 1,
|
"num_outputs": 1,
|
||||||
"init_maximum_nodes": 20,
|
"init_maximum_nodes": 30,
|
||||||
"expands_coe": 2
|
"expands_coe": 2
|
||||||
},
|
},
|
||||||
"neat": {
|
"neat": {
|
||||||
@@ -59,7 +59,7 @@
|
|||||||
"node_delete_prob": 0.2
|
"node_delete_prob": 0.2
|
||||||
},
|
},
|
||||||
"species": {
|
"species": {
|
||||||
"compatibility_threshold": 2.5,
|
"compatibility_threshold": 3,
|
||||||
"species_fitness_func": "max",
|
"species_fitness_func": "max",
|
||||||
"max_stagnation": 20,
|
"max_stagnation": 20,
|
||||||
"species_elitism": 2,
|
"species_elitism": 2,
|
||||||
|
|||||||
Reference in New Issue
Block a user