debuging
This commit is contained in:
@@ -7,7 +7,9 @@ import numpy as np
|
||||
from .species import SpeciesController
|
||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
||||
from .genome import batch_crossover
|
||||
from .genome.crossover import crossover
|
||||
from .genome import expand, expand_single
|
||||
from algorithms.neat.genome.genome import pop_analysis, analysis
|
||||
|
||||
|
||||
class Pipeline:
|
||||
@@ -51,12 +53,22 @@ class Pipeline:
|
||||
def tell(self, fitnesses):
|
||||
self.generation += 1
|
||||
|
||||
for i, f in enumerate(fitnesses):
|
||||
if np.isnan(f):
|
||||
print("fuck!!!!!!!!!!!!!!")
|
||||
error_nodes, error_connections = self.pop_nodes[i], self.pop_connections[i]
|
||||
np.save('error_nodes.npy', error_nodes)
|
||||
np.save('error_connections.npy', error_connections)
|
||||
assert False
|
||||
|
||||
self.species_controller.update_species_fitnesses(fitnesses)
|
||||
|
||||
crossover_pair = self.species_controller.reproduce(self.generation)
|
||||
|
||||
self.update_next_generation(crossover_pair)
|
||||
|
||||
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
|
||||
|
||||
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
|
||||
|
||||
self.expand()
|
||||
@@ -103,16 +115,22 @@ class Pipeline:
|
||||
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
||||
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
npn, npc = crossover_wrapper(crossover_rand_keys, wpn, wpc, lpn, lpc)
|
||||
# print(pop_analysis(npn, npc, self.input_idx, self.output_idx))
|
||||
|
||||
# mutate
|
||||
new_node_keys = np.array(self.fetch_new_node_keys())
|
||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys)
|
||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
|
||||
|
||||
# print(pop_analysis(m_npn, m_npc, self.input_idx, self.output_idx))
|
||||
|
||||
# elitism don't mutate
|
||||
# (pop_size, ) to (pop_size, 1, 1)
|
||||
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)
|
||||
# (pop_size, ) to (pop_size, 1, 1, 1)
|
||||
self.pop_connections = np.where(elitism_mask[:, None, None, None], npc, m_npc)
|
||||
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
|
||||
|
||||
# recycle unused node keys
|
||||
unused = []
|
||||
@@ -138,8 +156,8 @@ class Pipeline:
|
||||
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N)
|
||||
|
||||
# don't forget to expand representation genome in species
|
||||
for s in self.species_controller.species:
|
||||
s.representative = expand(*s.representative, self.N)
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.N)
|
||||
|
||||
def fetch_new_node_keys(self):
|
||||
# if remain unused keys are not enough, create new keys
|
||||
@@ -164,6 +182,19 @@ class Pipeline:
|
||||
print(f"Generation: {self.generation}",
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
|
||||
|
||||
# def crossover_wrapper(self, crossover_rand_keys, wpn, wpc, lpn, lpc):
|
||||
# pop_nodes, pop_connections = [], []
|
||||
# for randkey, wn, wc, ln, lc in zip(crossover_rand_keys, wpn, wpc, lpn, lpc):
|
||||
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
|
||||
# pop_nodes.append(new_nodes)
|
||||
# pop_connections.append(new_connections)
|
||||
# try:
|
||||
# print(analysis(new_nodes, new_connections, self.input_idx, self.output_idx))
|
||||
# except AssertionError:
|
||||
# new_nodes, new_connections = crossover(randkey, wn, wc, ln, lc)
|
||||
# return np.stack(pop_nodes), np.stack(pop_connections)
|
||||
|
||||
# return batch_crossover(*args)
|
||||
|
||||
def crossover_wrapper(*args):
|
||||
return batch_crossover(*args)
|
||||
|
||||
Reference in New Issue
Block a user