diff --git a/src/tensorneat/pipeline.py b/src/tensorneat/pipeline.py index f85516c..0ddf3ed 100644 --- a/src/tensorneat/pipeline.py +++ b/src/tensorneat/pipeline.py @@ -22,6 +22,7 @@ class Pipeline(StatefulBaseClass): is_save: bool = False, save_dir=None, show_problem_details: bool = False, + using_multidevice: bool = False, ): assert problem.jitable, "Currently, problem must be jitable" @@ -58,6 +59,11 @@ class Pipeline(StatefulBaseClass): self.show_problem_details = show_problem_details + self.using_multidevice = using_multidevice + if self.using_multidevice: + assert jax.device_count() > 1, f"using_multidevice requires more than 1 device, but {jax.device_count()=} devices are available" + print(f"Using {jax.device_count()} devices!") + def setup(self, state=State()): print("initializing") state = state.register(randkey=jax.random.PRNGKey(self.seed)) @@ -86,12 +92,12 @@ class Pipeline(StatefulBaseClass): state, pop ) - if jax.device_count() == 1: + if not self.using_multidevice: keys = jax.random.split(randkey_, self.pop_size) fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))( state, keys, self.algorithm.forward, pop_transformed ) - else: + else: # using_multidevice num_devices = jax.device_count() assert self.pop_size % num_devices == 0, "if you want to use multiple gpus, pop_size must be divisible by jax.device_count()" pop_size_per_device = self.pop_size // num_devices