From 2cc72bb188e56f3611e8423ab1d302a2a0c32fb6 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 5 Jul 2023 15:39:22 +0800 Subject: [PATCH] Strange bug! Add {"new_step_api": True} in gym environments. --- evox_adaptor/gym_no_distribution.py | 4 +--- examples/evox_/acrobot.py | 1 + examples/evox_/cartpole.py | 1 + examples/evox_/gym_env_test.py | 14 ++++++++++++++ examples/evox_/mountain_car.py | 1 + 5 files changed, 18 insertions(+), 3 deletions(-) create mode 100644 examples/evox_/gym_env_test.py diff --git a/evox_adaptor/gym_no_distribution.py b/evox_adaptor/gym_no_distribution.py index ad7365f..4a30b0c 100644 --- a/evox_adaptor/gym_no_distribution.py +++ b/evox_adaptor/gym_no_distribution.py @@ -55,9 +55,7 @@ class Gym(Problem): return -fitnesses, State(key=key) def __rollout(self, seeds, pop): - observations, infos = zip( - *[env.reset(seed=seed) for env, seed in zip(self.envs, seeds)] - ) + observations = [env.reset(seed=seed) for env, seed in zip(self.envs, seeds)] terminates, truncates = np.zeros((2, self.pop_size), dtype=bool) fitnesses, rewards = np.zeros((2, self.pop_size)) diff --git a/examples/evox_/acrobot.py b/examples/evox_/acrobot.py index f96dd22..d0b7fa8 100644 --- a/examples/evox_/acrobot.py +++ b/examples/evox_/acrobot.py @@ -39,6 +39,7 @@ if __name__ == '__main__': problem = Gym( policy=jit(vmap(neat_forward)), env_name="Acrobot-v1", + env_options={"new_step_api": True}, pop_size=100, ) diff --git a/examples/evox_/cartpole.py b/examples/evox_/cartpole.py index cec596b..54c73ec 100644 --- a/examples/evox_/cartpole.py +++ b/examples/evox_/cartpole.py @@ -39,6 +39,7 @@ if __name__ == '__main__': problem = Gym( policy=jit(vmap(neat_forward)), env_name="CartPole-v1", + env_options={"new_step_api": True}, pop_size=40, ) diff --git a/examples/evox_/gym_env_test.py b/examples/evox_/gym_env_test.py new file mode 100644 index 0000000..ef5cefa --- /dev/null +++ b/examples/evox_/gym_env_test.py @@ -0,0 +1,14 @@ +import gym + +env = gym.make("CartPole-v1", new_step_api=True) +print(env.reset()) +obs = env.reset() + +print(obs) +while True: + action = env.action_space.sample() + obs, reward, terminate, truncate, info = env.step(action) + print(obs, info) + if terminate | truncate: + break + diff --git a/examples/evox_/mountain_car.py b/examples/evox_/mountain_car.py index 9fcd66f..9d8bea1 100644 --- a/examples/evox_/mountain_car.py +++ b/examples/evox_/mountain_car.py @@ -39,6 +39,7 @@ if __name__ == '__main__': problem = Gym( policy=jit(vmap(neat_forward)), env_name="MountainCarContinuous-v0", + env_options={"new_step_api": True}, pop_size=100, )