debug-branch

This commit is contained in:
wls2002
2023-05-06 21:04:28 +08:00
parent 14fed83193
commit a85e6eba78
20 changed files with 1719 additions and 233 deletions

View File

@@ -1,15 +1,12 @@
from typing import List, Union, Tuple, Callable
import time
import jax
import numpy as np
from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover
from .genome.crossover import crossover
from .genome import expand, expand_single
from algorithms.neat.genome.genome import pop_analysis, analysis
from .genome.numpy import create_initialize_function, create_mutate_function, create_forward_function
from .genome.numpy import batch_crossover
from .genome.numpy import expand, expand_single
class Pipeline:
@@ -18,7 +15,7 @@ class Pipeline:
"""
def __init__(self, config, seed=42):
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
self.config = config
self.N = config.basic.init_maximum_nodes
@@ -53,14 +50,6 @@ class Pipeline:
def tell(self, fitnesses):
self.generation += 1
for i, f in enumerate(fitnesses):
if np.isnan(f):
print("fuck!!!!!!!!!!!!!!")
error_nodes, error_connections = self.pop_nodes[i], self.pop_connections[i]
np.save('error_nodes.npy', error_nodes)
np.save('error_connections.npy', error_connections)
assert False
self.species_controller.update_species_fitnesses(fitnesses)
crossover_pair = self.species_controller.reproduce(self.generation)
@@ -96,8 +85,6 @@ class Pipeline:
assert self.pop_nodes.shape[0] == self.pop_size
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
# crossover
# prepare elitism mask and crossover pair
elitism_mask = np.full(self.pop_size, False)
@@ -112,18 +99,13 @@ class Pipeline:
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
crossover_rand_keys = jax.random.split(k1, self.pop_size)
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc)
npn, npc = batch_crossover(wpn, wpc, lpn, lpc)
# print(pop_analysis(npn, npc, self.input_idx, self.output_idx))
# mutate
new_node_keys = np.array(self.fetch_new_node_keys())
mutate_rand_keys = jax.random.split(k2, self.pop_size)
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
# print(pop_analysis(m_npn, m_npc, self.input_idx, self.output_idx))
m_npn, m_npc = self.mutate_func(npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1)
@@ -180,21 +162,4 @@ class Pipeline:
self.generation_timestamp = new_timestamp
print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
# def crossover_wrapper(self, crossover_rand_keys, wpn, wpc, lpn, lpc):
# pop_nodes, pop_connections = [], []
# for randkey, wn, wc, ln, lc in zip(crossover_rand_keys, wpn, wpc, lpn, lpc):
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
# pop_nodes.append(new_nodes)
# pop_connections.append(new_connections)
# try:
# print(analysis(new_nodes, new_connections, self.input_idx, self.output_idx))
# except AssertionError:
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
# return np.stack(pop_nodes), np.stack(pop_connections)
# return batch_crossover(*args)
def crossover_wrapper(*args):
return batch_crossover(*args)
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")