又搞到3点,还是没有找到问题在哪,不过已经排除了是forward的问题

This commit is contained in:
wls2002
2023-05-07 02:59:48 +08:00
parent 414b620dc8
commit d1f54022bd
16 changed files with 772 additions and 58 deletions

View File

@@ -6,7 +6,12 @@ import numpy as np
from .species import SpeciesController
from .genome.numpy import create_initialize_function, create_mutate_function, create_forward_function
from .genome.numpy import batch_crossover
from .genome.numpy import expand, expand_single
from .genome.numpy import expand, expand_single, pop_analysis
from .genome.origin_neat import *
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
xor_outputs = np.array([[0], [1], [1], [0]])
class Pipeline:
@@ -14,8 +19,7 @@ class Pipeline:
Neat algorithm pipeline.
"""
def __init__(self, config, seed=42):
np.random.seed(seed)
def __init__(self, config):
self.config = config
self.N = config.basic.init_maximum_nodes
@@ -48,6 +52,15 @@ class Pipeline:
return func
def tell(self, fitnesses):
# idx = np.argmax(fitnesses)
# print(f"argmax: {idx}, max: {np.max(fitnesses)}, a_max: {fitnesses[idx]}")
# n, c = self.pop_nodes[idx], self.pop_connections[idx]
# func = create_forward_function(n, c, self.N, self.input_idx, self.output_idx, batch=True)
# out = func(xor_inputs)
# print(f"max fitness: {fitnesses[idx]}")
# print(f"real fitness: {4 - np.sum(np.abs(out - xor_outputs), axis=0)}")
# print(f"Out:\n{func(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]))}")
self.generation += 1
self.species_controller.update_species_fitnesses(fitnesses)
@@ -56,12 +69,31 @@ class Pipeline:
self.update_next_generation(crossover_pair)
# print(pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx))
analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
try:
for nodes, connections in zip(self.pop_nodes, self.pop_connections):
g = array2object(self.config, nodes, connections)
print(g)
net = FeedForwardNetwork.create(g)
real_out = [net.activate(x) for x in xor_inputs]
func = create_forward_function(nodes, connections, self.N, self.input_idx, self.output_idx, batch=True)
out = func(xor_inputs)
real_out = np.array(real_out)
out = np.array(out)
print(real_out, out)
assert np.allclose(real_out, out)
except AssertionError:
np.save("err_nodes.npy", self.pop_nodes)
np.save("err_connections.npy", self.pop_connections)
# print(g)
self.species_controller.speciate(self.pop_nodes, self.pop_connections, self.generation)
self.expand()
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config.neat.population.generation_limit):
forward_func = self.ask(batch=True)
@@ -77,6 +109,7 @@ class Pipeline:
self.tell(fitnesses)
print("Generation limit reached!")
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
"""
create the next generation
@@ -105,6 +138,7 @@ class Pipeline:
# mutate
new_node_keys = np.array(self.fetch_new_node_keys())
m_npn, m_npc = self.mutate_func(npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
@@ -122,6 +156,7 @@ class Pipeline:
unused.append(key)
self.new_node_keys_pool = unused + self.new_node_keys_pool
def expand(self):
"""
Expand the population if needed.
@@ -133,14 +168,15 @@ class Pipeline:
pop_node_sizes = np.sum(~np.isnan(pop_node_keys), axis=1)
max_node_size = np.max(pop_node_sizes)
if max_node_size >= self.N:
print(f"expand to {self.N}!")
self.N = int(self.N * self.expand_coe)
print(f"expand to {self.N}!")
self.pop_nodes, self.pop_connections = expand(self.pop_nodes, self.pop_connections, self.N)
# don't forget to expand representation genome in species
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N)
def fetch_new_node_keys(self):
# if remain unused keys are not enough, create new keys
if len(self.new_node_keys_pool) < self.pop_size:
@@ -153,6 +189,7 @@ class Pipeline:
self.new_node_keys_pool = self.new_node_keys_pool[self.pop_size:]
return res
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
@@ -162,4 +199,4 @@ class Pipeline:
self.generation_timestamp = new_timestamp
print(f"Generation: {self.generation}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Species sizes: {species_sizes}, Cost time: {cost_time}")