fix bugs
This commit is contained in:
@@ -71,6 +71,9 @@ class Pipeline(StatefulBaseClass):
|
||||
print(f"save to {self.save_dir}")
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
self.genome_dir = os.path.join(self.save_dir, "genomes")
|
||||
if not os.path.exists(self.genome_dir):
|
||||
os.makedirs(self.genome_dir)
|
||||
|
||||
def setup(self, state=State()):
|
||||
print("initializing")
|
||||
@@ -165,6 +168,7 @@ class Pipeline(StatefulBaseClass):
|
||||
print("start compile")
|
||||
tic = time.time()
|
||||
compiled_step = jax.jit(self.step).lower(state).compile()
|
||||
# compiled_step = self.step
|
||||
print(
|
||||
f"compile finished, cost time: {time.time() - tic:.6f}s",
|
||||
)
|
||||
@@ -181,9 +185,21 @@ class Pipeline(StatefulBaseClass):
|
||||
|
||||
if max(fitnesses) >= self.fitness_target:
|
||||
print("Fitness limit reached!")
|
||||
return state, self.best_genome
|
||||
break
|
||||
|
||||
if self.algorithm.generation(state) >= self.generation_limit:
|
||||
print("Generation limit reached!")
|
||||
|
||||
if self.is_save:
|
||||
best_genome = jax.device_get(self.best_genome)
|
||||
with open(os.path.join(self.genome_dir, f"best_genome.npz"), "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
print("Generation limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
def analysis(self, state, pop, fitnesses):
|
||||
@@ -206,15 +222,15 @@ class Pipeline(StatefulBaseClass):
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||
|
||||
if self.is_save:
|
||||
best_genome = jax.device_get(self.best_genome)
|
||||
with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
if self.is_save:
|
||||
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
|
||||
with open(os.path.join(self.genome_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
# save best if save path is not None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user