debug-branch
This commit is contained in:
@@ -1,15 +1,12 @@
|
||||
from typing import List, Union, Tuple, Callable
|
||||
import time
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
||||
from .genome import batch_crossover
|
||||
from .genome.crossover import crossover
|
||||
from .genome import expand, expand_single
|
||||
from algorithms.neat.genome.genome import pop_analysis, analysis
|
||||
from .genome.numpy import create_initialize_function, create_mutate_function, create_forward_function
|
||||
from .genome.numpy import batch_crossover
|
||||
from .genome.numpy import expand, expand_single
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -18,7 +15,7 @@ class Pipeline:
|
||||
"""
|
||||
|
||||
def __init__(self, config, seed=42):
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
self.config = config
|
||||
self.N = config.basic.init_maximum_nodes
|
||||
@@ -53,14 +50,6 @@ class Pipeline:
|
||||
def tell(self, fitnesses):
|
||||
self.generation += 1
|
||||
|
||||
for i, f in enumerate(fitnesses):
|
||||
if np.isnan(f):
|
||||
print("fuck!!!!!!!!!!!!!!")
|
||||
error_nodes, error_connections = self.pop_nodes[i], self.pop_connections[i]
|
||||
np.save('error_nodes.npy', error_nodes)
|
||||
np.save('error_connections.npy', error_connections)
|
||||
assert False
|
||||
|
||||
self.species_controller.update_species_fitnesses(fitnesses)
|
||||
|
||||
crossover_pair = self.species_controller.reproduce(self.generation)
|
||||
@@ -96,8 +85,6 @@ class Pipeline:
|
||||
|
||||
assert self.pop_nodes.shape[0] == self.pop_size
|
||||
|
||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
||||
|
||||
# crossover
|
||||
# prepare elitism mask and crossover pair
|
||||
elitism_mask = np.full(self.pop_size, False)
|
||||
@@ -112,18 +99,13 @@ class Pipeline:
|
||||
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
||||
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
||||
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
||||
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
||||
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc)
|
||||
npn, npc = batch_crossover(wpn, wpc, lpn, lpc)
|
||||
# print(pop_analysis(npn, npc, self.input_idx, self.output_idx))
|
||||
|
||||
# mutate
|
||||
new_node_keys = np.array(self.fetch_new_node_keys())
|
||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
|
||||
|
||||
# print(pop_analysis(m_npn, m_npc, self.input_idx, self.output_idx))
|
||||
m_npn, m_npc = self.mutate_func(npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||
|
||||
# elitism don't mutate
|
||||
# (pop_size, ) to (pop_size, 1, 1)
|
||||
@@ -180,21 +162,4 @@ class Pipeline:
|
||||
self.generation_timestamp = new_timestamp
|
||||
|
||||
print(f"Generation: {self.generation}",
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||
|
||||
# def crossover_wrapper(self, crossover_rand_keys, wpn, wpc, lpn, lpc):
|
||||
# pop_nodes, pop_connections = [], []
|
||||
# for randkey, wn, wc, ln, lc in zip(crossover_rand_keys, wpn, wpc, lpn, lpc):
|
||||
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
|
||||
# pop_nodes.append(new_nodes)
|
||||
# pop_connections.append(new_connections)
|
||||
# try:
|
||||
# print(analysis(new_nodes, new_connections, self.input_idx, self.output_idx))
|
||||
# except AssertionError:
|
||||
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
|
||||
# return np.stack(pop_nodes), np.stack(pop_connections)
|
||||
|
||||
# return batch_crossover(*args)
|
||||
|
||||
def crossover_wrapper(*args):
|
||||
return batch_crossover(*args)
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||
Reference in New Issue
Block a user