Merge pull request #20 from WhymustIhaveaname/main

fix show error of cartpolev1; add multi device support to pipeline
This commit is contained in:
WLS2002
2025-02-18 16:44:58 +08:00
committed by GitHub
2 changed files with 41 additions and 11 deletions

View File

@@ -1,5 +1,6 @@
import json
import os
import warnings
import jax, jax.numpy as jnp
import datetime, time
@@ -78,7 +79,6 @@ class Pipeline(StatefulBaseClass):
def step(self, state):
randkey_, randkey = jax.random.split(state.randkey)
keys = jax.random.split(randkey_, self.pop_size)
pop = self.algorithm.ask(state)
@@ -86,9 +86,31 @@ class Pipeline(StatefulBaseClass):
state, pop
)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, keys, self.algorithm.forward, pop_transformed
)
if jax.device_count() == 1:
keys = jax.random.split(randkey_, self.pop_size)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, keys, self.algorithm.forward, pop_transformed
)
else:
num_devices = jax.device_count()
assert self.pop_size % num_devices == 0, "if you want to use multiple gpus, pop_size must be divisible by jax.device_count()"
pop_size_per_device = self.pop_size // num_devices
keys = jax.random.split(randkey_, (num_devices, pop_size_per_device))
split_pop_transformed = jax.tree_map(
lambda x: x.reshape(num_devices, pop_size_per_device, *x.shape[1:]),
pop_transformed
)
fitnesses = jax.pmap(
lambda key_slice, pop_slice: jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, key_slice, self.algorithm.forward, pop_slice
),
axis_name='devices',
in_axes=(0, 0)
)(keys, split_pop_transformed)
fitnesses = fitnesses.reshape(self.pop_size)
# replace nan with -inf
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
@@ -101,7 +123,11 @@ class Pipeline(StatefulBaseClass):
def auto_run(self, state):
print("start compile")
tic = time.time()
compiled_step = jax.jit(self.step).lower(state).compile()
with warnings.catch_warnings():
warnings.filterwarnings("ignore",
message=r"The jitted function .* includes a pmap. Using jit-of-pmap can lead to inefficient data movement"
)
compiled_step = jax.jit(self.step).lower(state).compile()
if self.show_problem_details:
self.compiled_pop_transform_func = (
@@ -110,7 +136,6 @@ class Pipeline(StatefulBaseClass):
.compile()
)
# compiled_step = self.step
print(
f"compile finished, cost time: {time.time() - tic:.6f}s",
)

View File

@@ -79,9 +79,14 @@ class BraxEnv(RLEnv):
print("Total reward: ", reward)
imgs = image.render_array(
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
)
try:
imgs = image.render_array(
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
)
except ValueError:
imgs = image.render_array(
sys=self.env.sys, trajectory=state_histories, height=height, width=width
)
if output_type == "rgb_array":
imgs = np.array(imgs)