add multi device support to pipeline

This commit is contained in:
Youran on Lambda
2025-02-17 23:06:49 +08:00
parent f17f31bb2a
commit 626afcdf75

View File

@@ -1,5 +1,6 @@
import json import json
import os import os
import warnings
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
import datetime, time import datetime, time
@@ -78,7 +79,6 @@ class Pipeline(StatefulBaseClass):
def step(self, state): def step(self, state):
randkey_, randkey = jax.random.split(state.randkey) randkey_, randkey = jax.random.split(state.randkey)
keys = jax.random.split(randkey_, self.pop_size)
pop = self.algorithm.ask(state) pop = self.algorithm.ask(state)
@@ -86,9 +86,31 @@ class Pipeline(StatefulBaseClass):
state, pop state, pop
) )
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))( if jax.device_count() == 1:
state, keys, self.algorithm.forward, pop_transformed keys = jax.random.split(randkey_, self.pop_size)
) fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, keys, self.algorithm.forward, pop_transformed
)
else:
num_devices = jax.device_count()
assert self.pop_size % num_devices == 0, "if you want to use multiple gpus, pop_size must be divisible by jax.device_count()"
pop_size_per_device = self.pop_size // num_devices
keys = jax.random.split(randkey_, (num_devices, pop_size_per_device))
split_pop_transformed = jax.tree_map(
lambda x: x.reshape(num_devices, pop_size_per_device, *x.shape[1:]),
pop_transformed
)
fitnesses = jax.pmap(
lambda key_slice, pop_slice: jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, key_slice, self.algorithm.forward, pop_slice
),
axis_name='devices',
in_axes=(0, 0)
)(keys, split_pop_transformed)
fitnesses = fitnesses.reshape(self.pop_size)
# replace nan with -inf # replace nan with -inf
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses) fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
@@ -101,7 +123,11 @@ class Pipeline(StatefulBaseClass):
def auto_run(self, state): def auto_run(self, state):
print("start compile") print("start compile")
tic = time.time() tic = time.time()
compiled_step = jax.jit(self.step).lower(state).compile() with warnings.catch_warnings():
warnings.filterwarnings("ignore",
message=r"The jitted function .* includes a pmap. Using jit-of-pmap can lead to inefficient data movement"
)
compiled_step = jax.jit(self.step).lower(state).compile()
if self.show_problem_details: if self.show_problem_details:
self.compiled_pop_transform_func = ( self.compiled_pop_transform_func = (
@@ -110,7 +136,6 @@ class Pipeline(StatefulBaseClass):
.compile() .compile()
) )
# compiled_step = self.step
print( print(
f"compile finished, cost time: {time.time() - tic:.6f}s", f"compile finished, cost time: {time.time() - tic:.6f}s",
) )