tensorneat-mend/evox_adaptor/gym_no_distribution.py

from typing import Callable
import gym
import jax
import jax.numpy as jnp
import numpy as np
from evox import Problem, State

class Gym(Problem):
    """EvoX Problem that evaluates a population of policies on Gym environments.

    One Gym environment is created per individual; at every step the batched
    policy produces one action per environment.
    """

    def __init__(
        self,
        pop_size: int,
        policy: Callable,
        env_name: str = "CartPole-v1",
        env_options: dict = None,
        batch_policy: bool = True,
    ):
        self.pop_size = pop_size
        self.env_name = env_name
        self.policy = policy
        self.env_options = env_options or {}
        self.batch_policy = batch_policy
        assert batch_policy, "Only batch policy is supported for now"
        # one environment instance per individual in the population
        self.envs = [
            gym.make(env_name, **self.env_options) for _ in range(self.pop_size)
        ]
        super().__init__()

    def setup(self, key):
        return State(key=key)
    def evaluate(self, state, pop):
        # NOTE: the key is deliberately not split here, so the same seeds are
        # drawn on every call (kept fixed for debugging). For fresh seeds per
        # generation, split the key and return the new one instead:
        #   key, subkey = jax.random.split(state.key)
        key = state.key
        seeds = jax.random.randint(
            key, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max
        )
        # Gym expects plain Python ints as seeds, not JAX/numpy scalars
        seeds = seeds.tolist()

        fitnesses = self.__rollout(seeds, pop)

        print("fitnesses info:")
        print(
            f"max: {np.max(fitnesses)}, min: {np.min(fitnesses)}, "
            f"mean: {np.mean(fitnesses)}, std: {np.std(fitnesses)}"
        )

        # EvoX minimizes fitness, so negate the episode returns (to be maximized)
        return -fitnesses, State(key=key)
    def __rollout(self, seeds, pop):
        # gym's reset() returns (observation, info); keep only the observation
        observations = [env.reset(seed=seed)[0] for env, seed in zip(self.envs, seeds)]
        terminates, truncates = np.zeros((2, self.pop_size), dtype=bool)
        fitnesses, rewards = np.zeros((2, self.pop_size))

        # step all environments in lockstep until every episode has finished
        while not np.all(terminates | truncates):
            observations = np.asarray(observations)
            actions = self.policy(pop, observations)
            actions = jax.device_get(actions)
            for i, (action, terminate, truncate, env) in enumerate(
                zip(actions, terminates, truncates, self.envs)
            ):
                if terminate or truncate:
                    # episode already finished: feed a dummy observation, no reward
                    observation = np.zeros(env.observation_space.shape)
                    reward = 0
                else:
                    observation, reward, terminate, truncate, info = env.step(action)
                observations[i] = observation
                rewards[i] = reward
                terminates[i] = terminate
                truncates[i] = truncate
            fitnesses += rewards
        return fitnesses
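

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; the policy below is
# an illustrative assumption). It wires Gym to a hypothetical batched linear
# policy for CartPole-v1 and calls setup()/evaluate() directly, assuming each
# individual in `pop` is a flat weight vector of shape (4,).
if __name__ == "__main__":
    def linear_policy(pop, observations):
        # pop: (pop_size, 4) weight vectors, observations: (pop_size, 4) observations
        logits = jnp.sum(pop * observations, axis=1)
        # CartPole-v1 has a Discrete(2) action space: 0 = push left, 1 = push right
        return (logits > 0).astype(jnp.int32)

    key = jax.random.PRNGKey(0)
    problem = Gym(pop_size=8, policy=linear_policy, env_name="CartPole-v1")
    state = problem.setup(key)
    pop = jax.random.normal(key, (8, 4))
    neg_fitness, state = problem.evaluate(state, pop)
    print("negated episode returns:", neg_fitness)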