remove useless codes

This commit is contained in:
wls2002
2023-05-07 16:30:26 +08:00
parent 890c928b0f
commit cec40b254f
24 changed files with 3 additions and 2601 deletions

View File

@@ -2,16 +2,12 @@ from typing import List, Union, Tuple, Callable
import time
import jax
import numpy as np
from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover
from .genome import expand, expand_single, distance
from .genome.origin_neat import *
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
xor_outputs = np.array([[0], [1], [1], [0]])
from .genome import expand, expand_single
class Pipeline:
@@ -62,38 +58,6 @@ class Pipeline:
self.update_next_generation(crossover_pair)
# for i in range(self.pop_size):
# for j in range(self.pop_size):
# n1, c1 = self.pop_nodes[i], self.pop_connections[i]
# n2, c2 = self.pop_nodes[j], self.pop_connections[j]
# g1 = array2object(self.config.neat, n1, c1)
# g2 = array2object(self.config.neat, n2, c2)
# d_real = g1.distance(g2)
# d = distance(n1, c1, n2, c2)
# print(d_real, d)
# assert np.allclose(d_real, d)
# analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
# try:
# for nodes, connections in zip(self.pop_nodes, self.pop_connections):
# g = array2object(self.config, nodes, connections)
# print(g)
# net = FeedForwardNetwork.create(g)
# real_out = [net.activate(x) for x in xor_inputs]
# func = create_forward_function(nodes, connections, self.N, self.input_idx, self.output_idx, batch=True)
# out = func(xor_inputs)
# real_out = np.array(real_out)
# out = np.array(out)
# print(real_out, out)
# assert np.allclose(real_out, out)
# except AssertionError:
# np.save("err_nodes.npy", self.pop_nodes)
# np.save("err_connections.npy", self.pop_connections)
# print(g)
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.expand()
@@ -145,7 +109,7 @@ class Pipeline:
new_node_keys = np.array(self.fetch_new_node_keys())
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
m_npn, m_npc = jax.device_get(m_npn), jax.device_get(m_npc)
# elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1)
self.pop_nodes = np.where(elitism_mask[:, None, None], npn, m_npn)