add save function in pipeline

This commit is contained in:
wls2002
2024-06-16 21:47:53 +08:00
parent b9d6482d11
commit fb2ae5d2fa
10 changed files with 94 additions and 164 deletions

View File

@@ -1,5 +1,8 @@
import json
import os
import jax, jax.numpy as jnp
import time
import datetime, time
import numpy as np
from algorithm import BaseAlgorithm
@@ -19,7 +22,8 @@ class Pipeline(StatefulBaseClass):
generation_limit: int = 1000,
pre_update: bool = False,
update_batch_size: int = 10000,
save_path=None,
save_dir=None,
is_save: bool = False,
):
assert problem.jitable, "Currently, problem must be jitable"
@@ -56,7 +60,17 @@ class Pipeline(StatefulBaseClass):
assert not problem.record_episode, "record_episode must be False"
elif isinstance(problem, FuncFit):
assert not problem.return_data, "return_data must be False"
self.save_path = save_path
self.is_save = is_save
if is_save:
if save_dir is None:
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.save_dir = f"./{self.__class__.__name__} {now}"
else:
self.save_dir = save_dir
print(f"save to {self.save_dir}")
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
def setup(self, state=State()):
print("initializing")
@@ -72,6 +86,15 @@ class Pipeline(StatefulBaseClass):
state = self.algorithm.setup(state)
state = self.problem.setup(state)
if self.is_save:
# self.save(state=state, path=os.path.join(self.save_dir, "pipeline.pkl"))
with open(os.path.join(self.save_dir, "config.txt"), "w") as f:
f.write(json.dumps(self.show_config(), indent=4))
# create log file
with open(os.path.join(self.save_dir, "log.txt"), "w") as f:
f.write("Generation,Max,Min,Mean,Std,Cost Time\n")
print("initializing finished")
return state
@@ -183,16 +206,17 @@ class Pipeline(StatefulBaseClass):
self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx]
if self.is_save:
best_genome = jax.device_get(self.best_genome)
with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
# save best if save path is not None
if self.save_path is not None:
best_genome = jax.device_get(self.best_genome)
with open(self.save_path, "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
member_count = jax.device_get(self.algorithm.member_count(state))
species_sizes = [int(i) for i in member_count if i > 0]
@@ -222,6 +246,13 @@ class Pipeline(StatefulBaseClass):
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
)
# append log
if self.is_save:
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
f.write(
f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)
def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(state, best)
self.problem.show(