accelerate: unify jnp and np

This commit is contained in:
wls2002
2023-05-08 00:46:48 +08:00
parent cf47c5bb38
commit c705b5cfe2
3 changed files with 13 additions and 13 deletions

View File

@@ -2,6 +2,7 @@ from typing import List, Union, Tuple, Callable
import time import time
import jax import jax
import jax.numpy as jnp
import numpy as np import numpy as np
from .species import SpeciesController from .species import SpeciesController
@@ -16,6 +17,7 @@ class Pipeline:
""" """
def __init__(self, config, seed=42): def __init__(self, config, seed=42):
self.generation_timestamp = time.time()
self.randkey = jax.random.PRNGKey(seed) self.randkey = jax.random.PRNGKey(seed)
self.config = config self.config = config
@@ -32,7 +34,6 @@ class Pipeline:
self.generation = 0 self.generation = 0
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation) self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.generation_timestamp = time.time()
self.best_fitness = float('-inf') self.best_fitness = float('-inf')
def ask(self, batch: bool): def ask(self, batch: bool):
@@ -87,6 +88,7 @@ class Pipeline:
# crossover # crossover
# prepare elitism mask and crossover pair # prepare elitism mask and crossover pair
elitism_mask = np.full(self.pop_size, False) elitism_mask = np.full(self.pop_size, False)
for i, pair in enumerate(crossover_pair): for i, pair in enumerate(crossover_pair):
if not isinstance(pair, tuple): # elitism if not isinstance(pair, tuple): # elitism
elitism_mask[i] = True elitism_mask[i] = True
@@ -94,13 +96,14 @@ class Pipeline:
crossover_pair = np.array(crossover_pair) crossover_pair = np.array(crossover_pair)
crossover_rand_keys = jax.random.split(k1, self.pop_size) crossover_rand_keys = jax.random.split(k1, self.pop_size)
# batch crossover # batch crossover
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
lpc) # new pop nodes, new pop connections
npn, npc = jax.device_get(npn), jax.device_get(npc)
# mutate # mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size) mutate_rand_keys = jax.random.split(k2, self.pop_size)
@@ -111,16 +114,12 @@ class Pipeline:
# elitism don't mutate # elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1) # (pop_size, ) to (pop_size, 1, 1)
def aux_function1(): m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
nonlocal m_npn, m_npc self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc) # (pop_size, ) to (pop_size, 1, 1, 1)
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn) self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
# (pop_size, ) to (pop_size, 1, 1, 1)
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
aux_function1()
def expand(self): def expand(self):
""" """

View File

@@ -68,6 +68,7 @@ class SpeciesController:
# calculate the distance between the representative and the population # calculate the distance between the representative and the population
r_nodes, r_connections = species.representative r_nodes, r_connections = species.representative
distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections) distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
distances = jax.device_get(distances)
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx new_representatives[sid] = min_idx
@@ -80,7 +81,7 @@ class SpeciesController:
if previous_species_list: # exist previous species if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list] rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [ res_pop_distance = [
self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) jax.device_get(self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections))
for rid in rid_list for rid in rid_list
] ]

View File

@@ -2,7 +2,7 @@
"basic": { "basic": {
"num_inputs": 2, "num_inputs": 2,
"num_outputs": 1, "num_outputs": 1,
"init_maximum_nodes": 10, "init_maximum_nodes": 25,
"expands_coe": 2 "expands_coe": 2
}, },
"neat": { "neat": {