remove create_func....
This commit is contained in:
14
pipeline.py
14
pipeline.py
@@ -27,15 +27,15 @@ class Pipeline:
|
||||
|
||||
self.evaluate_time = 0
|
||||
|
||||
self.forward_func = jit(self.algorithm.forward)
|
||||
self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None)))
|
||||
self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0)))
|
||||
self.act_func = jit(self.algorithm.act)
|
||||
self.batch_act_func = jit(vmap(self.act_func, in_axes=(None, 0, None)))
|
||||
self.pop_batch_act_func = jit(vmap(self.batch_act_func, in_axes=(None, None, 0)))
|
||||
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0)))
|
||||
self.tell_func = jit(self.algorithm.tell)
|
||||
|
||||
def ask(self):
|
||||
pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes)
|
||||
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
|
||||
return lambda inputs: self.pop_batch_act_func(self.state, inputs, pop_transforms)
|
||||
|
||||
def tell(self, fitness):
|
||||
# self.state = self.tell_func(self.state, fitness)
|
||||
@@ -80,8 +80,4 @@ class Pipeline:
|
||||
|
||||
print(f"Generation: {self.state.generation}",
|
||||
f"species: {len(species_sizes)}, {species_sizes}",
|
||||
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
|
||||
|
||||
|
||||
|
||||
|
||||
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
|
||||
Reference in New Issue
Block a user