add save function in pipeline
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
import time
|
||||
import datetime, time
|
||||
import numpy as np
|
||||
|
||||
from algorithm import BaseAlgorithm
|
||||
@@ -19,7 +22,8 @@ class Pipeline(StatefulBaseClass):
|
||||
generation_limit: int = 1000,
|
||||
pre_update: bool = False,
|
||||
update_batch_size: int = 10000,
|
||||
save_path=None,
|
||||
save_dir=None,
|
||||
is_save: bool = False,
|
||||
):
|
||||
assert problem.jitable, "Currently, problem must be jitable"
|
||||
|
||||
@@ -56,7 +60,17 @@ class Pipeline(StatefulBaseClass):
|
||||
assert not problem.record_episode, "record_episode must be False"
|
||||
elif isinstance(problem, FuncFit):
|
||||
assert not problem.return_data, "return_data must be False"
|
||||
self.save_path = save_path
|
||||
self.is_save = is_save
|
||||
|
||||
if is_save:
|
||||
if save_dir is None:
|
||||
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.save_dir = f"./{self.__class__.__name__} {now}"
|
||||
else:
|
||||
self.save_dir = save_dir
|
||||
print(f"save to {self.save_dir}")
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
|
||||
def setup(self, state=State()):
|
||||
print("initializing")
|
||||
@@ -72,6 +86,15 @@ class Pipeline(StatefulBaseClass):
|
||||
|
||||
state = self.algorithm.setup(state)
|
||||
state = self.problem.setup(state)
|
||||
|
||||
if self.is_save:
|
||||
# self.save(state=state, path=os.path.join(self.save_dir, "pipeline.pkl"))
|
||||
with open(os.path.join(self.save_dir, "config.txt"), "w") as f:
|
||||
f.write(json.dumps(self.show_config(), indent=4))
|
||||
# create log file
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "w") as f:
|
||||
f.write("Generation,Max,Min,Mean,Std,Cost Time\n")
|
||||
|
||||
print("initializing finished")
|
||||
return state
|
||||
|
||||
@@ -183,16 +206,17 @@ class Pipeline(StatefulBaseClass):
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||
|
||||
if self.is_save:
|
||||
best_genome = jax.device_get(self.best_genome)
|
||||
with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
# save best if save path is not None
|
||||
if self.save_path is not None:
|
||||
best_genome = jax.device_get(self.best_genome)
|
||||
with open(self.save_path, "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
member_count = jax.device_get(self.algorithm.member_count(state))
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
@@ -222,6 +246,13 @@ class Pipeline(StatefulBaseClass):
|
||||
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||
)
|
||||
|
||||
# append log
|
||||
if self.is_save:
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
|
||||
def show(self, state, best, *args, **kwargs):
|
||||
transformed = self.algorithm.transform(state, best)
|
||||
self.problem.show(
|
||||
|
||||
Reference in New Issue
Block a user