change a lot a lot a lot!!!!!!!

This commit is contained in:
wls2002
2023-07-24 02:16:02 +08:00
parent 48f90c7eef
commit ac295c1921
49 changed files with 1138 additions and 1460 deletions

View File

@@ -5,7 +5,8 @@ import jax
from jax import vmap, jit
import numpy as np
from algorithm import Algorithm
from config import Config
from core import Algorithm, Genome
class Pipeline:
@@ -13,11 +14,11 @@ class Pipeline:
Neat algorithm pipeline.
"""
def __init__(self, config, algorithm: Algorithm):
def __init__(self, config: Config, algorithm: Algorithm):
self.config = config
self.algorithm = algorithm
randkey = jax.random.PRNGKey(config['random_seed'])
randkey = jax.random.PRNGKey(config.basic.seed)
self.state = algorithm.setup(randkey)
self.best_genome = None
@@ -29,18 +30,18 @@ class Pipeline:
self.forward_func = jit(self.algorithm.forward)
self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None)))
self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0)))
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0, 0)))
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0)))
self.tell_func = jit(self.algorithm.tell)
def ask(self):
pop_transforms = self.forward_transform_func(self.state, self.state.pop_nodes, self.state.pop_conns)
pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes)
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
def tell(self, fitness):
self.state = self.tell_func(self.state, fitness)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']):
for _ in range(self.config.basic.generation_limit):
forward_func = self.ask()
fitnesses = fitness_func(forward_func)
@@ -52,7 +53,7 @@ class Pipeline:
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
analysis(fitnesses)
if max(fitnesses) >= self.config['fitness_threshold']:
if max(fitnesses) >= self.config.basic.fitness_target:
print("Fitness limit reached!")
return self.best_genome
@@ -70,11 +71,11 @@ class Pipeline:
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = (self.state.pop_nodes[max_idx], self.state.pop_conns[max_idx])
self.best_genome = Genome(self.state.pop_genomes.nodes[max_idx], self.state.pop_genomes.conns[max_idx])
member_count = jax.device_get(self.state.species_info[:, 3])
member_count = jax.device_get(self.state.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.state.generation}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")