hyper neat

This commit is contained in:
wls2002
2023-07-24 19:25:02 +08:00
parent ac295c1921
commit ebad574431
24 changed files with 542 additions and 103 deletions

View File

@@ -11,7 +11,7 @@ from core import Algorithm, Genome
class Pipeline:
"""
Neat algorithm pipeline.
Simple pipeline.
"""
def __init__(self, config: Config, algorithm: Algorithm):
@@ -38,7 +38,9 @@ class Pipeline:
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
def tell(self, fitness):
self.state = self.tell_func(self.state, fitness)
# self.state = self.tell_func(self.state, fitness)
new_state = self.tell_func(self.state, fitness)
self.state = new_state
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config.basic.generation_limit):
@@ -73,9 +75,9 @@ class Pipeline:
self.best_fitness = fitnesses[max_idx]
self.best_genome = Genome(self.state.pop_genomes.nodes[max_idx], self.state.pop_genomes.conns[max_idx])
member_count = jax.device_get(self.state.member_count)
member_count = jax.device_get(self.state.species_info.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.state.generation}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")