diff --git a/algorithms/neat/genome/forward.py b/algorithms/neat/genome/forward.py index bc37bcb..2d95973 100644 --- a/algorithms/neat/genome/forward.py +++ b/algorithms/neat/genome/forward.py @@ -97,7 +97,9 @@ def create_forward_function(config): if config['forward_way'] == 'single': return jit(forward) - # return jit(batch_forward) + + if config['forward_way'] == 'batch': + return jit(batch_forward) elif config['forward_way'] == 'pop': return jit(pop_batch_forward) diff --git a/pipeline.py b/pipeline.py index 46b0a47..6d0b18a 100644 --- a/pipeline.py +++ b/pipeline.py @@ -51,9 +51,14 @@ class Pipeline: single: Create pop_size number of forward functions. - Each function receive (batch_size, input_size) and returns (batch_size, output_size) + Each function receive (input_size, ) and returns (output_size, ) e.g. RL task + batch: + Create pop_size number of forward functions. + Each function receive (input_size, ) and returns (output_size, ) + some task need to calculate the fitness of a batch of inputs + pop: Create a single forward function, which use only once calculation for the population. The function receives (pop_size, batch_size, input_size) and returns (pop_size, batch_size, output_size) @@ -68,8 +73,18 @@ class Pipeline: pop_seqs = self.pop_topological_sort(self.pop_nodes, u_pop_cons) # only common mode is supported currently - assert self.config['forward_way'] == 'common' - return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons) + if self.config['forward_way'] == 'single' or self.config['forward_way'] == 'batch': + # carry data to cpu for fast iteration + pop_seqs, self.pop_nodes, self.pop_cons = jax.device_get((pop_seqs, self.pop_nodes, self.pop_cons)) + funcs = [lambda x: self.forward(x, seqs, nodes, u_cons) + for seqs, nodes, u_cons in zip(pop_seqs, self.pop_nodes, self.pop_cons)] + return funcs + + elif self.config['forward_way'] == 'pop' or self.config['forward_way'] == 'common': + return lambda x: self.forward(x, pop_seqs, self.pop_nodes, u_pop_cons) + + else: + raise NotImplementedError(f"forward_way {self.config['forward_way']} is not supported") def tell(self, fitness): (