add show_details in problem;

releated to https://github.com/EMI-Group/tensorneat/issues/15
This commit is contained in:
wls2002
2025-02-12 22:42:05 +08:00
parent de2d906656
commit e4f855b4f6
3 changed files with 126 additions and 28 deletions

View File

@@ -20,6 +20,7 @@ class Pipeline(StatefulBaseClass):
generation_limit: int = 1000,
is_save: bool = False,
save_dir=None,
show_problem_details: bool = False,
):
assert problem.jitable, "Currently, problem must be jitable"
@@ -54,6 +55,8 @@ class Pipeline(StatefulBaseClass):
if not os.path.exists(self.genome_dir):
os.makedirs(self.genome_dir)
self.show_problem_details = show_problem_details
def setup(self, state=State()):
print("initializing")
state = state.register(randkey=jax.random.PRNGKey(self.seed))
@@ -99,6 +102,14 @@ class Pipeline(StatefulBaseClass):
print("start compile")
tic = time.time()
compiled_step = jax.jit(self.step).lower(state).compile()
if self.show_problem_details:
self.compiled_pop_transform_func = (
jax.jit(jax.vmap(self.algorithm.transform, in_axes=(None, 0)))
.lower(self.algorithm.ask(state))
.compile()
)
# compiled_step = self.step
print(
f"compile finished, cost time: {time.time() - tic:.6f}s",
@@ -134,17 +145,20 @@ class Pipeline(StatefulBaseClass):
return state, self.best_genome
def analysis(self, state, pop, fitnesses):
generation = int(state.generation)
valid_fitnesses = fitnesses[~np.isinf(fitnesses)]
max_f, min_f, mean_f, std_f = (
max(valid_fitnesses),
min(valid_fitnesses),
np.mean(valid_fitnesses),
np.std(valid_fitnesses),
)
# avoid there is no valid fitness in the whole population
if len(valid_fitnesses) == 0:
max_f, min_f, mean_f, std_f = ["NaN"] * 4
else:
max_f, min_f, mean_f, std_f = (
max(valid_fitnesses),
min(valid_fitnesses),
np.mean(valid_fitnesses),
np.std(valid_fitnesses),
)
new_timestamp = time.time()
@@ -158,9 +172,7 @@ class Pipeline(StatefulBaseClass):
if self.is_save:
# save best
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
file_name = os.path.join(
self.genome_dir, f"{generation}.npz"
)
file_name = os.path.join(self.genome_dir, f"{generation}.npz")
with open(file_name, "wb") as f:
np.savez(
f,
@@ -171,9 +183,7 @@ class Pipeline(StatefulBaseClass):
# append log
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
f.write(
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)
f.write(f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n")
print(
f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
@@ -182,6 +192,15 @@ class Pipeline(StatefulBaseClass):
self.algorithm.show_details(state, fitnesses)
if self.show_problem_details:
pop_transformed = self.compiled_pop_transform_func(
state, self.algorithm.ask(state)
)
self.problem.show_details(
state, state.randkey, self.algorithm.forward, pop_transformed
)
# show details for problem
def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(state, best)
return self.problem.show(

View File

@@ -33,3 +33,10 @@ class BaseProblem(StatefulBaseClass):
show how a genome perform in this problem
"""
raise NotImplementedError
def show_details(self, state: State, randkey, act_func: Callable, pop_params, *args, **kwargs):
"""
show the running details of the problem
this function will be automaticly call in pipeline.auto_run()
"""
pass