Merge pull request #1 from WLS2002/full-jit

Strange bug! Add {"new_step_api": True} in gym environments.
This commit is contained in:
WLS2002
2023-07-05 15:40:11 +08:00
committed by GitHub
5 changed files with 18 additions and 3 deletions

View File

@@ -55,9 +55,7 @@ class Gym(Problem):
return -fitnesses, State(key=key)
def __rollout(self, seeds, pop):
observations, infos = zip(
*[env.reset(seed=seed) for env, seed in zip(self.envs, seeds)]
)
observations = [env.reset(seed=seed) for env, seed in zip(self.envs, seeds)]
terminates, truncates = np.zeros((2, self.pop_size), dtype=bool)
fitnesses, rewards = np.zeros((2, self.pop_size))

View File

@@ -39,6 +39,7 @@ if __name__ == '__main__':
problem = Gym(
policy=jit(vmap(neat_forward)),
env_name="Acrobot-v1",
env_options={"new_step_api": True},
pop_size=100,
)

View File

@@ -39,6 +39,7 @@ if __name__ == '__main__':
problem = Gym(
policy=jit(vmap(neat_forward)),
env_name="CartPole-v1",
env_options={"new_step_api": True},
pop_size=40,
)

View File

@@ -0,0 +1,14 @@
import gym
env = gym.make("CartPole-v1", new_step_api=True)
print(env.reset())
obs = env.reset()
print(obs)
while True:
action = env.action_space.sample()
obs, reward, terminate, truncate, info = env.step(action)
print(obs, info)
if terminate | truncate:
break

View File

@@ -39,6 +39,7 @@ if __name__ == '__main__':
problem = Gym(
policy=jit(vmap(neat_forward)),
env_name="MountainCarContinuous-v0",
env_options={"new_step_api": True},
pop_size=100,
)