add action_policy for problem;

This commit is contained in:
wls2002
2024-06-07 17:09:16 +08:00
parent 10ec1c2df9
commit 3d5b80c6fa
13 changed files with 2417 additions and 1191 deletions

View File

@@ -19,6 +19,7 @@ class Pipeline:
generation_limit: int = 1000,
pre_update: bool = False,
update_batch_size: int = 10000,
save_path=None,
):
assert problem.jitable, "Currently, problem must be jitable"
@@ -55,6 +56,7 @@ class Pipeline:
assert not problem.record_episode, "record_episode must be False"
elif isinstance(problem, FuncFit):
assert not problem.return_data, "return_data must be False"
self.save_path = save_path
def setup(self, state=State()):
print("initializing")
@@ -181,6 +183,17 @@ class Pipeline:
self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx]
# save best if save path is not None
if self.save_path is not None:
best_genome = jax.device_get(self.best_genome)
with open(self.save_path, "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
member_count = jax.device_get(self.algorithm.member_count(state))
species_sizes = [int(i) for i in member_count if i > 0]