add using_multidevice option in pipeline
This commit is contained in:
@@ -22,6 +22,7 @@ class Pipeline(StatefulBaseClass):
|
|||||||
is_save: bool = False,
|
is_save: bool = False,
|
||||||
save_dir=None,
|
save_dir=None,
|
||||||
show_problem_details: bool = False,
|
show_problem_details: bool = False,
|
||||||
|
using_multidevice: bool = False,
|
||||||
):
|
):
|
||||||
assert problem.jitable, "Currently, problem must be jitable"
|
assert problem.jitable, "Currently, problem must be jitable"
|
||||||
|
|
||||||
@@ -58,6 +59,11 @@ class Pipeline(StatefulBaseClass):
|
|||||||
|
|
||||||
self.show_problem_details = show_problem_details
|
self.show_problem_details = show_problem_details
|
||||||
|
|
||||||
|
self.using_multidevice = using_multidevice
|
||||||
|
if self.using_multidevice:
|
||||||
|
assert jax.device_count() > 1, f"using_multidevice requires more than 1 device, but {jax.device_count()=} devices are available"
|
||||||
|
print(f"Using {jax.device_count()} devices!")
|
||||||
|
|
||||||
def setup(self, state=State()):
|
def setup(self, state=State()):
|
||||||
print("initializing")
|
print("initializing")
|
||||||
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
||||||
@@ -86,12 +92,12 @@ class Pipeline(StatefulBaseClass):
|
|||||||
state, pop
|
state, pop
|
||||||
)
|
)
|
||||||
|
|
||||||
if jax.device_count() == 1:
|
if not self.using_multidevice:
|
||||||
keys = jax.random.split(randkey_, self.pop_size)
|
keys = jax.random.split(randkey_, self.pop_size)
|
||||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||||
state, keys, self.algorithm.forward, pop_transformed
|
state, keys, self.algorithm.forward, pop_transformed
|
||||||
)
|
)
|
||||||
else:
|
else: # using_multidevice
|
||||||
num_devices = jax.device_count()
|
num_devices = jax.device_count()
|
||||||
assert self.pop_size % num_devices == 0, "if you want to use multiple gpus, pop_size must be divisible by jax.device_count()"
|
assert self.pop_size % num_devices == 0, "if you want to use multiple gpus, pop_size must be divisible by jax.device_count()"
|
||||||
pop_size_per_device = self.pop_size // num_devices
|
pop_size_per_device = self.pop_size // num_devices
|
||||||
|
|||||||
Reference in New Issue
Block a user