diff --git a/src/tensorneat/pipeline.py b/src/tensorneat/pipeline.py
index 86ee90f..f85516c 100644
--- a/src/tensorneat/pipeline.py
+++ b/src/tensorneat/pipeline.py
@@ -1,5 +1,6 @@
 import json
 import os
+import warnings

 import jax, jax.numpy as jnp
 import datetime, time
@@ -78,7 +79,6 @@ class Pipeline(StatefulBaseClass):

     def step(self, state):
         randkey_, randkey = jax.random.split(state.randkey)
-        keys = jax.random.split(randkey_, self.pop_size)

         pop = self.algorithm.ask(state)

@@ -86,9 +86,34 @@ class Pipeline(StatefulBaseClass):
             state, pop
         )

-        fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
-            state, keys, self.algorithm.forward, pop_transformed
-        )
+        if jax.device_count() == 1:
+            keys = jax.random.split(randkey_, self.pop_size)
+            fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
+                state, keys, self.algorithm.forward, pop_transformed
+            )
+        else:
+            num_devices = jax.device_count()
+            assert self.pop_size % num_devices == 0, "to run on multiple devices, pop_size must be divisible by jax.device_count()"
+            pop_size_per_device = self.pop_size // num_devices
+
+            # split the population across devices: (pop_size, ...) -> (num_devices, pop_size_per_device, ...)
+            keys = jax.random.split(randkey_, (num_devices, pop_size_per_device))
+            split_pop_transformed = jax.tree_util.tree_map(
+                lambda x: x.reshape(num_devices, pop_size_per_device, *x.shape[1:]),
+                pop_transformed
+            )
+
+            # pmap over devices; the inner vmap evaluates each device's slice of the population
+            fitnesses = jax.pmap(
+                lambda key_slice, pop_slice: jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
+                    state, key_slice, self.algorithm.forward, pop_slice
+                ),
+                axis_name='devices',
+                in_axes=(0, 0)
+            )(keys, split_pop_transformed)
+
+            # flatten the per-device results back into a single population axis
+            fitnesses = fitnesses.reshape(self.pop_size)

         # replace nan with -inf
         fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
@@ -101,7 +126,11 @@ class Pipeline(StatefulBaseClass):
     def auto_run(self, state):
         print("start compile")
         tic = time.time()
-        compiled_step = jax.jit(self.step).lower(state).compile()
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore",
+                message=r"The jitted function .* includes a pmap. Using jit-of-pmap can lead to inefficient data movement"
+            )
+            compiled_step = jax.jit(self.step).lower(state).compile()

         if self.show_problem_details:
             self.compiled_pop_transform_func = (
@@ -110,7 +139,6 @@ class Pipeline(StatefulBaseClass):
                 .compile()
             )

-        # compiled_step = self.step
         print(
             f"compile finished, cost time: {time.time() - tic:.6f}s",
         )
diff --git a/src/tensorneat/problem/rl/brax.py b/src/tensorneat/problem/rl/brax.py
index c7b2942..cf6e84d 100644
--- a/src/tensorneat/problem/rl/brax.py
+++ b/src/tensorneat/problem/rl/brax.py
@@ -41,7 +41,7 @@ class BraxEnv(RLEnv):
         *args,
         **kwargs,
     ):
-
+        assert output_type in ["rgb_array", "gif"], f"unsupported output_type: {output_type!r}"

         import jax

@@ -76,12 +76,18 @@ class BraxEnv(RLEnv):
             reward += r
             if done:
                 break
-
+        print("Total reward:", reward)

-        imgs = image.render_array(
-            sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
-        )
+        # some systems do not define a "track" camera; fall back to the default camera
+        try:
+            imgs = image.render_array(
+                sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
+            )
+        except ValueError:
+            imgs = image.render_array(
+                sys=self.env.sys, trajectory=state_histories, height=height, width=width
+            )

         if output_type == "rgb_array":
             imgs = np.array(imgs)
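
Reviewer note: the multi-device branch in step() uses the standard pmap-of-vmap pattern: reshape the population's leading axis from (pop_size, ...) to (num_devices, pop_size_per_device, ...), pmap across devices, vmap within each device's slice, then flatten the fitnesses back. Below is a minimal, self-contained sketch of that pattern, assuming a reasonably recent JAX (tuple num for jax.random.split); the evaluate function here is a hypothetical toy fitness, not TensorNEAT's Problem.evaluate.

import jax
import jax.numpy as jnp

def evaluate(key, params):
    # toy fitness: negative squared norm of a noisy parameter vector
    noise = jax.random.normal(key, params.shape)
    return -jnp.sum((params + 0.01 * noise) ** 2)

pop_size = 8
num_devices = jax.device_count()
assert pop_size % num_devices == 0
per_device = pop_size // num_devices

key = jax.random.PRNGKey(0)
pop = jax.random.normal(key, (pop_size, 3))  # population of parameter vectors

# one PRNG key per individual, laid out as (devices, per_device)
keys = jax.random.split(key, (num_devices, per_device))

# (pop_size, ...) -> (num_devices, per_device, ...): pmap maps over devices,
# the inner vmap maps over each device's slice of the population
split_pop = pop.reshape(num_devices, per_device, *pop.shape[1:])
fitnesses = jax.pmap(jax.vmap(evaluate))(keys, split_pop)

print(fitnesses.reshape(pop_size))  # one fitness per individual

This sketch also runs unchanged on a single device (pmap over one device is valid), so the plain-vmap branch in step() serves as a simpler fast path; the filter added in auto_run only silences JAX's jit-of-pmap data-movement warning that the multi-device branch would otherwise emit at compile time.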