accelerate: unify jnp and np
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import List, Union, Tuple, Callable
|
||||
import time
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
@@ -16,6 +17,7 @@ class Pipeline:
|
||||
"""
|
||||
|
||||
def __init__(self, config, seed=42):
|
||||
self.generation_timestamp = time.time()
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
|
||||
self.config = config
|
||||
@@ -32,7 +34,6 @@ class Pipeline:
|
||||
self.generation = 0
|
||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
self.best_fitness = float('-inf')
|
||||
|
||||
def ask(self, batch: bool):
|
||||
@@ -87,6 +88,7 @@ class Pipeline:
|
||||
# crossover
|
||||
# prepare elitism mask and crossover pair
|
||||
elitism_mask = np.full(self.pop_size, False)
|
||||
|
||||
for i, pair in enumerate(crossover_pair):
|
||||
if not isinstance(pair, tuple): # elitism
|
||||
elitism_mask[i] = True
|
||||
@@ -94,13 +96,14 @@ class Pipeline:
|
||||
crossover_pair = np.array(crossover_pair)
|
||||
|
||||
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
||||
|
||||
# batch crossover
|
||||
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
|
||||
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
||||
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
||||
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
npn, npc = self.crossover_func(crossover_rand_keys, wpn, wpc, lpn,
|
||||
lpc) # new pop nodes, new pop connections
|
||||
npn, npc = jax.device_get(npn), jax.device_get(npc)
|
||||
|
||||
# mutate
|
||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||
@@ -111,16 +114,12 @@ class Pipeline:
|
||||
# elitism don't mutate
|
||||
# (pop_size, ) to (pop_size, 1, 1)
|
||||
|
||||
def aux_function1():
|
||||
nonlocal m_npn, m_npc
|
||||
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
|
||||
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
||||
# (pop_size, ) to (pop_size, 1, 1, 1)
|
||||
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
||||
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
|
||||
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
||||
# (pop_size, ) to (pop_size, 1, 1, 1)
|
||||
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
||||
|
||||
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
|
||||
|
||||
aux_function1()
|
||||
|
||||
def expand(self):
|
||||
"""
|
||||
|
||||
@@ -68,6 +68,7 @@ class SpeciesController:
|
||||
# calculate the distance between the representative and the population
|
||||
r_nodes, r_connections = species.representative
|
||||
distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
|
||||
distances = jax.device_get(distances)
|
||||
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
|
||||
|
||||
new_representatives[sid] = min_idx
|
||||
@@ -80,7 +81,7 @@ class SpeciesController:
|
||||
if previous_species_list: # exist previous species
|
||||
rid_list = [new_representatives[sid] for sid in previous_species_list]
|
||||
res_pop_distance = [
|
||||
self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
||||
jax.device_get(self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections))
|
||||
for rid in rid_list
|
||||
]
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"basic": {
|
||||
"num_inputs": 2,
|
||||
"num_outputs": 1,
|
||||
"init_maximum_nodes": 10,
|
||||
"init_maximum_nodes": 25,
|
||||
"expands_coe": 2
|
||||
},
|
||||
"neat": {
|
||||
|
||||
Reference in New Issue
Block a user