This commit is contained in:
wls2002
2024-06-20 16:32:52 +08:00
parent 9f72813c35
commit 075460f896
17 changed files with 224 additions and 140 deletions

View File

@@ -71,6 +71,9 @@ class Pipeline(StatefulBaseClass):
print(f"save to {self.save_dir}")
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
self.genome_dir = os.path.join(self.save_dir, "genomes")
if not os.path.exists(self.genome_dir):
os.makedirs(self.genome_dir)
def setup(self, state=State()):
print("initializing")
@@ -165,6 +168,7 @@ class Pipeline(StatefulBaseClass):
print("start compile")
tic = time.time()
compiled_step = jax.jit(self.step).lower(state).compile()
# compiled_step = self.step
print(
f"compile finished, cost time: {time.time() - tic:.6f}s",
)
@@ -181,9 +185,21 @@ class Pipeline(StatefulBaseClass):
if max(fitnesses) >= self.fitness_target:
print("Fitness limit reached!")
return state, self.best_genome
break
if self.algorithm.generation(state) >= self.generation_limit:
print("Generation limit reached!")
if self.is_save:
best_genome = jax.device_get(self.best_genome)
with open(os.path.join(self.genome_dir, f"best_genome.npz"), "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
print("Generation limit reached!")
return state, self.best_genome
def analysis(self, state, pop, fitnesses):
@@ -206,15 +222,15 @@ class Pipeline(StatefulBaseClass):
self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx]
if self.is_save:
best_genome = jax.device_get(self.best_genome)
with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
if self.is_save:
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
with open(os.path.join(self.genome_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
# save best if save path is not None