complete forward ways
This commit is contained in:
@@ -97,7 +97,9 @@ def create_forward_function(config):
|
||||
|
||||
if config['forward_way'] == 'single':
|
||||
return jit(forward)
|
||||
# return jit(batch_forward)
|
||||
|
||||
if config['forward_way'] == 'batch':
|
||||
return jit(batch_forward)
|
||||
|
||||
elif config['forward_way'] == 'pop':
|
||||
return jit(pop_batch_forward)
|
||||
|
||||
21
pipeline.py
21
pipeline.py
@@ -51,9 +51,14 @@ class Pipeline:
|
||||
|
||||
single:
|
||||
Create pop_size number of forward functions.
|
||||
Each function receive (batch_size, input_size) and returns (batch_size, output_size)
|
||||
Each function receive (input_size, ) and returns (output_size, )
|
||||
e.g. RL task
|
||||
|
||||
batch:
|
||||
Create pop_size number of forward functions.
|
||||
Each function receive (input_size, ) and returns (output_size, )
|
||||
some task need to calculate the fitness of a batch of inputs
|
||||
|
||||
pop:
|
||||
Create a single forward function, which use only once calculation for the population.
|
||||
The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
||||
@@ -68,8 +73,18 @@ class Pipeline:
|
||||
pop_seqs = self.pop_topological_sort(self.pop_nodes, u_pop_cons)
|
||||
|
||||
# only common mode is supported currently
|
||||
assert self.config['forward_way'] == 'common'
|
||||
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
if self.config['forward_way'] == 'single' or self.config['forward_way'] == 'batch':
|
||||
# carry data to cpu for fast iteration
|
||||
pop_seqs, self.pop_nodes, self.pop_cons = jax.device_get((pop_seqs, self.pop_nodes, self.pop_cons))
|
||||
funcs = [lambda x: self.forward(x, seqs, nodes, u_cons)
|
||||
for seqs, nodes, u_cons in zip(pop_seqs, self.pop_nodes, self.pop_cons)]
|
||||
return funcs
|
||||
|
||||
elif self.config['forward_way'] == 'pop' or self.config['forward_way'] == 'common':
|
||||
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"forward_way {self.config['forward_way']} is not supported")
|
||||
|
||||
def tell(self, fitness):
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user