Merge pull request #20 from WhymustIhaveaname/main
Fix the show error for cartpolev1; add multi-device support to the pipeline
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
import datetime, time
|
||||
@@ -78,7 +79,6 @@ class Pipeline(StatefulBaseClass):
|
||||
def step(self, state):
|
||||
|
||||
randkey_, randkey = jax.random.split(state.randkey)
|
||||
keys = jax.random.split(randkey_, self.pop_size)
|
||||
|
||||
pop = self.algorithm.ask(state)
|
||||
|
||||
@@ -86,9 +86,31 @@ class Pipeline(StatefulBaseClass):
|
||||
state, pop
|
||||
)
|
||||
|
||||
if jax.device_count() == 1:
|
||||
keys = jax.random.split(randkey_, self.pop_size)
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||
state, keys, self.algorithm.forward, pop_transformed
|
||||
)
|
||||
else:
|
||||
num_devices = jax.device_count()
|
||||
assert self.pop_size % num_devices == 0, "if you want to use multiple gpus, pop_size must be divisible by jax.device_count()"
|
||||
pop_size_per_device = self.pop_size // num_devices
|
||||
|
||||
keys = jax.random.split(randkey_, (num_devices, pop_size_per_device))
|
||||
split_pop_transformed = jax.tree_map(
|
||||
lambda x: x.reshape(num_devices, pop_size_per_device, *x.shape[1:]),
|
||||
pop_transformed
|
||||
)
|
||||
|
||||
fitnesses = jax.pmap(
|
||||
lambda key_slice, pop_slice: jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||
state, key_slice, self.algorithm.forward, pop_slice
|
||||
),
|
||||
axis_name='devices',
|
||||
in_axes=(0, 0)
|
||||
)(keys, split_pop_transformed)
|
||||
|
||||
fitnesses = fitnesses.reshape(self.pop_size)
|
||||
|
||||
# replace nan with -inf
|
||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
||||
@@ -101,6 +123,10 @@ class Pipeline(StatefulBaseClass):
|
||||
def auto_run(self, state):
|
||||
print("start compile")
|
||||
tic = time.time()
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore",
|
||||
message=r"The jitted function .* includes a pmap. Using jit-of-pmap can lead to inefficient data movement"
|
||||
)
|
||||
compiled_step = jax.jit(self.step).lower(state).compile()
|
||||
|
||||
if self.show_problem_details:
|
||||
@@ -110,7 +136,6 @@ class Pipeline(StatefulBaseClass):
|
||||
.compile()
|
||||
)
|
||||
|
||||
# compiled_step = self.step
|
||||
print(
|
||||
f"compile finished, cost time: {time.time() - tic:.6f}s",
|
||||
)
|
||||
|
||||
@@ -79,9 +79,14 @@ class BraxEnv(RLEnv):
|
||||
|
||||
print("Total reward: ", reward)
|
||||
|
||||
try:
|
||||
imgs = image.render_array(
|
||||
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
|
||||
)
|
||||
except ValueError:
|
||||
imgs = image.render_array(
|
||||
sys=self.env.sys, trajectory=state_histories, height=height, width=width
|
||||
)
|
||||
|
||||
if output_type == "rgb_array":
|
||||
imgs = np.array(imgs)
|
||||
|
||||
Reference in New Issue
Block a user