adjust parameter for xor problem

This commit is contained in:
wls2002
2023-05-07 16:21:41 +08:00
parent a3b9bca866
commit 890c928b0f
6 changed files with 23 additions and 18 deletions

View File

@@ -6,7 +6,7 @@ import jax
from .species import SpeciesController
from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover
from .genome import expand, expand_single, pop_analysis
from .genome import expand, expand_single, distance
from .genome.origin_neat import *
@@ -53,14 +53,6 @@ class Pipeline:
return func
def tell(self, fitnesses):
# idx = np.argmax(fitnesses)
# print(f"argmax: {idx}, max: {np.max(fitnesses)}, a_max: {fitnesses[idx]}")
# n, c = self.pop_nodes[idx], self.pop_connections[idx]
# func = create_forward_function(n, c, self.N, self.input_idx, self.output_idx, batch=True)
# out = func(xor_inputs)
# print(f"max fitness: {fitnesses[idx]}")
# print(f"real fitness: {4 - np.sum(np.abs(out - xor_outputs), axis=0)}")
# print(f"Out:\n{func(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]))}")
self.generation += 1
@@ -70,6 +62,18 @@ class Pipeline:
self.update_next_generation(crossover_pair)
# for i in range(self.pop_size):
# for j in range(self.pop_size):
# n1, c1 = self.pop_nodes[i], self.pop_connections[i]
# n2, c2 = self.pop_nodes[j], self.pop_connections[j]
# g1 = array2object(self.config.neat, n1, c1)
# g2 = array2object(self.config.neat, n2, c2)
# d_real = g1.distance(g2)
# d = distance(n1, c1, n2, c2)
# print(d_real, d)
# assert np.allclose(d_real, d)
# analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
# try: