complete forward ways
This commit is contained in:
@@ -97,7 +97,9 @@ def create_forward_function(config):
|
|||||||
|
|
||||||
if config['forward_way'] == 'single':
|
if config['forward_way'] == 'single':
|
||||||
return jit(forward)
|
return jit(forward)
|
||||||
# return jit(batch_forward)
|
|
||||||
|
if config['forward_way'] == 'batch':
|
||||||
|
return jit(batch_forward)
|
||||||
|
|
||||||
elif config['forward_way'] == 'pop':
|
elif config['forward_way'] == 'pop':
|
||||||
return jit(pop_batch_forward)
|
return jit(pop_batch_forward)
|
||||||
|
|||||||
19
pipeline.py
19
pipeline.py
@@ -51,9 +51,14 @@ class Pipeline:
|
|||||||
|
|
||||||
single:
|
single:
|
||||||
Create pop_size number of forward functions.
|
Create pop_size number of forward functions.
|
||||||
Each function receive (batch_size, input_size) and returns (batch_size, output_size)
|
Each function receive (input_size, ) and returns (output_size, )
|
||||||
e.g. RL task
|
e.g. RL task
|
||||||
|
|
||||||
|
batch:
|
||||||
|
Create pop_size number of forward functions.
|
||||||
|
Each function receive (input_size, ) and returns (output_size, )
|
||||||
|
some task need to calculate the fitness of a batch of inputs
|
||||||
|
|
||||||
pop:
|
pop:
|
||||||
Create a single forward function, which use only once calculation for the population.
|
Create a single forward function, which use only once calculation for the population.
|
||||||
The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size)
|
||||||
@@ -68,9 +73,19 @@ class Pipeline:
|
|||||||
pop_seqs = self.pop_topological_sort(self.pop_nodes, u_pop_cons)
|
pop_seqs = self.pop_topological_sort(self.pop_nodes, u_pop_cons)
|
||||||
|
|
||||||
# only common mode is supported currently
|
# only common mode is supported currently
|
||||||
assert self.config['forward_way'] == 'common'
|
if self.config['forward_way'] == 'single' or self.config['forward_way'] == 'batch':
|
||||||
|
# carry data to cpu for fast iteration
|
||||||
|
pop_seqs, self.pop_nodes, self.pop_cons = jax.device_get((pop_seqs, self.pop_nodes, self.pop_cons))
|
||||||
|
funcs = [lambda x: self.forward(x, seqs, nodes, u_cons)
|
||||||
|
for seqs, nodes, u_cons in zip(pop_seqs, self.pop_nodes, self.pop_cons)]
|
||||||
|
return funcs
|
||||||
|
|
||||||
|
elif self.config['forward_way'] == 'pop' or self.config['forward_way'] == 'common':
|
||||||
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"forward_way {self.config['forward_way']} is not supported")
|
||||||
|
|
||||||
def tell(self, fitness):
|
def tell(self, fitness):
|
||||||
(
|
(
|
||||||
self.randkey,
|
self.randkey,
|
||||||
|
|||||||
Reference in New Issue
Block a user