Merge pull request #20 from WhymustIhaveaname/main
Fix show error of cartpolev1; add multi-device support to pipeline
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
import datetime, time
|
import datetime, time
|
||||||
@@ -78,7 +79,6 @@ class Pipeline(StatefulBaseClass):
|
|||||||
def step(self, state):
|
def step(self, state):
|
||||||
|
|
||||||
randkey_, randkey = jax.random.split(state.randkey)
|
randkey_, randkey = jax.random.split(state.randkey)
|
||||||
keys = jax.random.split(randkey_, self.pop_size)
|
|
||||||
|
|
||||||
pop = self.algorithm.ask(state)
|
pop = self.algorithm.ask(state)
|
||||||
|
|
||||||
@@ -86,9 +86,31 @@ class Pipeline(StatefulBaseClass):
|
|||||||
state, pop
|
state, pop
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if jax.device_count() == 1:
|
||||||
|
keys = jax.random.split(randkey_, self.pop_size)
|
||||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||||
state, keys, self.algorithm.forward, pop_transformed
|
state, keys, self.algorithm.forward, pop_transformed
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
num_devices = jax.device_count()
|
||||||
|
assert self.pop_size % num_devices == 0, "if you want to use multiple gpus, pop_size must be divisible by jax.device_count()"
|
||||||
|
pop_size_per_device = self.pop_size // num_devices
|
||||||
|
|
||||||
|
keys = jax.random.split(randkey_, (num_devices, pop_size_per_device))
|
||||||
|
split_pop_transformed = jax.tree_map(
|
||||||
|
lambda x: x.reshape(num_devices, pop_size_per_device, *x.shape[1:]),
|
||||||
|
pop_transformed
|
||||||
|
)
|
||||||
|
|
||||||
|
fitnesses = jax.pmap(
|
||||||
|
lambda key_slice, pop_slice: jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||||
|
state, key_slice, self.algorithm.forward, pop_slice
|
||||||
|
),
|
||||||
|
axis_name='devices',
|
||||||
|
in_axes=(0, 0)
|
||||||
|
)(keys, split_pop_transformed)
|
||||||
|
|
||||||
|
fitnesses = fitnesses.reshape(self.pop_size)
|
||||||
|
|
||||||
# replace nan with -inf
|
# replace nan with -inf
|
||||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
||||||
@@ -101,6 +123,10 @@ class Pipeline(StatefulBaseClass):
|
|||||||
def auto_run(self, state):
|
def auto_run(self, state):
|
||||||
print("start compile")
|
print("start compile")
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore",
|
||||||
|
message=r"The jitted function .* includes a pmap. Using jit-of-pmap can lead to inefficient data movement"
|
||||||
|
)
|
||||||
compiled_step = jax.jit(self.step).lower(state).compile()
|
compiled_step = jax.jit(self.step).lower(state).compile()
|
||||||
|
|
||||||
if self.show_problem_details:
|
if self.show_problem_details:
|
||||||
@@ -110,7 +136,6 @@ class Pipeline(StatefulBaseClass):
|
|||||||
.compile()
|
.compile()
|
||||||
)
|
)
|
||||||
|
|
||||||
# compiled_step = self.step
|
|
||||||
print(
|
print(
|
||||||
f"compile finished, cost time: {time.time() - tic:.6f}s",
|
f"compile finished, cost time: {time.time() - tic:.6f}s",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -79,9 +79,14 @@ class BraxEnv(RLEnv):
|
|||||||
|
|
||||||
print("Total reward: ", reward)
|
print("Total reward: ", reward)
|
||||||
|
|
||||||
|
try:
|
||||||
imgs = image.render_array(
|
imgs = image.render_array(
|
||||||
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
|
sys=self.env.sys, trajectory=state_histories, height=height, width=width, camera="track"
|
||||||
)
|
)
|
||||||
|
except ValueError:
|
||||||
|
imgs = image.render_array(
|
||||||
|
sys=self.env.sys, trajectory=state_histories, height=height, width=width
|
||||||
|
)
|
||||||
|
|
||||||
if output_type == "rgb_array":
|
if output_type == "rgb_array":
|
||||||
imgs = np.array(imgs)
|
imgs = np.array(imgs)
|
||||||
|
|||||||
Reference in New Issue
Block a user