try to accelerate the speed of speciate

This commit is contained in:
wls2002
2023-05-08 18:41:19 +08:00
parent 8653f49826
commit ee6bb01eff
4 changed files with 12 additions and 13 deletions

View File

@@ -2,6 +2,7 @@ from typing import List, Union, Tuple, Callable
import time
import jax
import jax.numpy as jnp
import numpy as np
from .species import SpeciesController
@@ -104,7 +105,6 @@ class Pipeline:
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
lpc) # new pop nodes, new pop connections
npn, npc = jax.device_get(npn), jax.device_get(npc)
# mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size)
@@ -113,11 +113,8 @@ class Pipeline:
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1)
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
npn, npc, m_npn, m_npc = jax.device_get([npn, npc, m_npn, m_npc])
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
# (pop_size, ) to (pop_size, 1, 1, 1)
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
def expand(self):