Files
tensorneat-mend/tensorneat/examples/brax/show_test.py
2024-06-16 22:28:28 +08:00

20 lines
507 B
Python

import jax
from problem.rl_env import BraxEnv
def random_policy(randkey, forward_func, obs, *, action_dim=6):
    """Sample an action uniformly at random, ignoring the network and observation.

    Passed as ``action_policy`` to ``BraxEnv`` in this example so the rollout
    is driven by noise instead of a learned controller.

    Args:
        randkey: JAX PRNG key used for sampling.
        forward_func: Unused. Presumably accepted only to match the
            action-policy call signature ``BraxEnv`` expects — confirm.
        obs: Unused; the current observation.
        action_dim: Length of the sampled action vector. Defaults to 6, the
            value this example uses with walker2d (presumably that env's
            action size — verify when switching environments).

    Returns:
        A ``(action_dim,)`` float array with entries in ``[-1, 1)``.
    """
    # The action size used to be hard-coded to 6; keeping 6 as the default
    # preserves old behavior while letting other environments reuse this.
    return jax.random.uniform(randkey, (action_dim,), minval=-1, maxval=1)
if __name__ == "__main__":
    # Demo: run walker2d for up to 1000 steps with a random action policy,
    # then hand the rollout to BraxEnv.show, which saves under the given path.
    env = BraxEnv(
        env_name="walker2d",
        max_step=1000,
        action_policy=random_policy,
    )
    env_state = env.setup()
    key = jax.random.key(0)

    env.show(
        env_state,
        key,
        # NOTE(review): identity "network" that just echoes obs. It looks like
        # a placeholder, since random_policy ignores forward_func — confirm.
        act_func=lambda state, params, obs: obs,
        params=None,
        save_path="walker2d_random_policy",
    )