84 lines
2.6 KiB
Python
84 lines
2.6 KiB
Python
from typing import Callable
|
|
|
|
import gym
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
|
|
from evox import Problem, State
|
|
|
|
|
|
class Gym(Problem):
|
|
def __init__(
|
|
self,
|
|
pop_size: int,
|
|
policy: Callable,
|
|
env_name: str = "CartPole-v1",
|
|
env_options: dict = None,
|
|
batch_policy: bool = True,
|
|
):
|
|
self.pop_size = pop_size
|
|
self.env_name = env_name
|
|
self.policy = policy
|
|
self.env_options = env_options or {}
|
|
self.batch_policy = batch_policy
|
|
assert batch_policy, "Only batch policy is supported for now"
|
|
|
|
self.envs = [gym.make(env_name, **self.env_options) for _ in range(self.pop_size)]
|
|
|
|
super().__init__()
|
|
|
|
def setup(self, key):
|
|
return State(key=key)
|
|
|
|
def evaluate(self, state, pop):
|
|
key = state.key
|
|
# key, subkey = jax.random.split(state.key)
|
|
|
|
# generate a list of seeds for gym
|
|
# seeds = jax.random.randint(
|
|
# subkey, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max
|
|
# )
|
|
|
|
# currently use fixed seed for debugging
|
|
seeds = jax.random.randint(
|
|
key, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max
|
|
)
|
|
|
|
seeds = seeds.tolist() # seed must be a python int, not numpy array
|
|
|
|
fitnesses = self.__rollout(seeds, pop)
|
|
print("fitnesses info: ")
|
|
print(f"max: {np.max(fitnesses)}, min: {np.min(fitnesses)}, mean: {np.mean(fitnesses)}, std: {np.std(fitnesses)}")
|
|
|
|
# evox uses negative fitness for minimization
|
|
return -fitnesses, State(key=key)
|
|
|
|
def __rollout(self, seeds, pop):
|
|
observations, infos = zip(
|
|
*[env.reset(seed=seed) for env, seed in zip(self.envs, seeds)]
|
|
)
|
|
terminates, truncates = np.zeros((2, self.pop_size), dtype=bool)
|
|
fitnesses, rewards = np.zeros((2, self.pop_size))
|
|
|
|
while not np.all(terminates | truncates):
|
|
observations = np.asarray(observations)
|
|
actions = self.policy(pop, observations)
|
|
actions = jax.device_get(actions)
|
|
|
|
for i, (action, terminate, truncate, env) in enumerate(zip(actions, terminates, truncates, self.envs)):
|
|
if terminate | truncate:
|
|
observation = np.zeros(env.observation_space.shape)
|
|
reward = 0
|
|
else:
|
|
observation, reward, terminate, truncate, info = env.step(action)
|
|
|
|
observations[i] = observation
|
|
rewards[i] = reward
|
|
terminates[i] = terminate
|
|
truncates[i] = truncate
|
|
|
|
fitnesses += rewards
|
|
|
|
return fitnesses
|