remove create_func....

This commit is contained in:
wls2002
2023-08-02 13:26:01 +08:00
parent 85318f98f3
commit 1499e062fe
34 changed files with 558 additions and 1022 deletions

View File

@@ -27,15 +27,15 @@ class Pipeline:
self.evaluate_time = 0
self.forward_func = jit(self.algorithm.forward)
self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None)))
self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0)))
self.act_func = jit(self.algorithm.act)
self.batch_act_func = jit(vmap(self.act_func, in_axes=(None, 0, None)))
self.pop_batch_act_func = jit(vmap(self.batch_act_func, in_axes=(None, None, 0)))
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0)))
self.tell_func = jit(self.algorithm.tell)
def ask(self):
pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes)
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
return lambda inputs: self.pop_batch_act_func(self.state, inputs, pop_transforms)
def tell(self, fitness):
# self.state = self.tell_func(self.state, fitness)
@@ -80,8 +80,4 @@ class Pipeline:
print(f"Generation: {self.state.generation}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")