accelerate: unify jnp and np

This commit is contained in:
wls2002
2023-05-08 00:46:48 +08:00
parent cf47c5bb38
commit c705b5cfe2
3 changed files with 13 additions and 13 deletions

View File

@@ -2,6 +2,7 @@ from typing import List, Union, Tuple, Callable
import time
import jax
import jax.numpy as jnp
import numpy as np
from .species import SpeciesController
@@ -16,6 +17,7 @@ class Pipeline:
"""
def __init__(self, config, seed=42):
self.generation_timestamp = time.time()
self.randkey = jax.random.PRNGKey(seed)
self.config = config
@@ -32,7 +34,6 @@ class Pipeline:
self.generation = 0
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.generation_timestamp = time.time()
self.best_fitness = float('-inf')
def ask(self, batch: bool):
@@ -87,6 +88,7 @@ class Pipeline:
# crossover
# prepare elitism mask and crossover pair
elitism_mask = np.full(self.pop_size, False)
for i, pair in enumerate(crossover_pair):
if not isinstance(pair, tuple): # elitism
elitism_mask[i] = True
@@ -94,13 +96,14 @@ class Pipeline:
crossover_pair = np.array(crossover_pair)
crossover_rand_keys = jax.random.split(k1, self.pop_size)
# batch crossover
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
lpc) # new pop nodes, new pop connections
npn, npc = jax.device_get(npn), jax.device_get(npc)
# mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size)
@@ -111,16 +114,12 @@ class Pipeline:
# elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1)
def aux_function1():
nonlocal m_npn, m_npc
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
# (pop_size, ) to (pop_size, 1, 1, 1)
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
# (pop_size, ) to (pop_size, 1, 1, 1)
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
aux_function1()
def expand(self):
"""