add action_policy for problem;
This commit is contained in:
@@ -19,6 +19,7 @@ class Pipeline:
|
||||
generation_limit: int = 1000,
|
||||
pre_update: bool = False,
|
||||
update_batch_size: int = 10000,
|
||||
save_path=None,
|
||||
):
|
||||
assert problem.jitable, "Currently, problem must be jitable"
|
||||
|
||||
@@ -55,6 +56,7 @@ class Pipeline:
|
||||
assert not problem.record_episode, "record_episode must be False"
|
||||
elif isinstance(problem, FuncFit):
|
||||
assert not problem.return_data, "return_data must be False"
|
||||
self.save_path = save_path
|
||||
|
||||
def setup(self, state=State()):
|
||||
print("initializing")
|
||||
@@ -181,6 +183,17 @@ class Pipeline:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||
|
||||
# save best if save path is not None
|
||||
if self.save_path is not None:
|
||||
best_genome = jax.device_get(self.best_genome)
|
||||
with open(self.save_path, "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
member_count = jax.device_get(self.algorithm.member_count(state))
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user