complete forward ways

This commit is contained in:
wls2002
2023-07-08 20:52:31 +08:00
parent 6670c4aea1
commit 7265e33c43
2 changed files with 21 additions and 4 deletions

View File

@@ -97,7 +97,9 @@ def create_forward_function(config):
if config['forward_way'] == 'single':
return jit(forward)
# return jit(batch_forward)
if config['forward_way'] == 'batch':
return jit(batch_forward)
elif config['forward_way'] == 'pop':
return jit(pop_batch_forward)

View File

@@ -51,9 +51,14 @@ class Pipeline:
single:
Create pop_size number of forward functions.
Each function receive (batch_size, input_size) and returns (batch_size, output_size)
Each function receive (input_size, ) and returns (output_size, )
e.g. RL task
batch:
Create pop_size number of forward functions.
Each function receive (input_size, ) and returns (output_size, )
some task need to calculate the fitness of a batch of inputs
pop:
Create a single forward function, which use only once calculation for the population.
The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)
@@ -68,8 +73,18 @@ class Pipeline:
pop_seqs = self.pop_topological_sort(self.pop_nodes, u_pop_cons)
# only common mode is supported currently
assert self.config['forward_way'] == 'common'
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
if self.config['forward_way'] == 'single' or self.config['forward_way'] == 'batch':
# carry data to cpu for fast iteration
pop_seqs, self.pop_nodes, self.pop_cons = jax.device_get((pop_seqs, self.pop_nodes, self.pop_cons))
funcs = [lambda x: self.forward(x, seqs, nodes, u_cons)
for seqs, nodes, u_cons in zip(pop_seqs, self.pop_nodes, self.pop_cons)]
return funcs
elif self.config['forward_way'] == 'pop' or self.config['forward_way'] == 'common':
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
else:
raise NotImplementedError(f"forward_way {self.config['forward_way']} is not supported")
def tell(self, fitness):
(