Using Evox to deal with RL tasks! With distributed Gym environment!
Three simple tasks in Gym[classical] are tested.
This commit is contained in:
83
evox_adaptor/gym_no_distribution.py
Normal file
83
evox_adaptor/gym_no_distribution.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Callable
|
||||
|
||||
import gym
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from evox import Problem, State
|
||||
|
||||
|
||||
class Gym(Problem):
|
||||
def __init__(
|
||||
self,
|
||||
pop_size: int,
|
||||
policy: Callable,
|
||||
env_name: str = "CartPole-v1",
|
||||
env_options: dict = None,
|
||||
batch_policy: bool = True,
|
||||
):
|
||||
self.pop_size = pop_size
|
||||
self.env_name = env_name
|
||||
self.policy = policy
|
||||
self.env_options = env_options or {}
|
||||
self.batch_policy = batch_policy
|
||||
assert batch_policy, "Only batch policy is supported for now"
|
||||
|
||||
self.envs = [gym.make(env_name, **self.env_options) for _ in range(self.pop_size)]
|
||||
|
||||
super().__init__()
|
||||
|
||||
def setup(self, key):
|
||||
return State(key=key)
|
||||
|
||||
def evaluate(self, state, pop):
|
||||
key = state.key
|
||||
# key, subkey = jax.random.split(state.key)
|
||||
|
||||
# generate a list of seeds for gym
|
||||
# seeds = jax.random.randint(
|
||||
# subkey, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max
|
||||
# )
|
||||
|
||||
# currently use fixed seed for debugging
|
||||
seeds = jax.random.randint(
|
||||
key, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max
|
||||
)
|
||||
|
||||
seeds = seeds.tolist() # seed must be a python int, not numpy array
|
||||
|
||||
fitnesses = self.__rollout(seeds, pop)
|
||||
print("fitnesses info: ")
|
||||
print(f"max: {np.max(fitnesses)}, min: {np.min(fitnesses)}, mean: {np.mean(fitnesses)}, std: {np.std(fitnesses)}")
|
||||
|
||||
# evox uses negative fitness for minimization
|
||||
return -fitnesses, State(key=key)
|
||||
|
||||
def __rollout(self, seeds, pop):
|
||||
observations, infos = zip(
|
||||
*[env.reset(seed=seed) for env, seed in zip(self.envs, seeds)]
|
||||
)
|
||||
terminates, truncates = np.zeros((2, self.pop_size), dtype=bool)
|
||||
fitnesses, rewards = np.zeros((2, self.pop_size))
|
||||
|
||||
while not np.all(terminates | truncates):
|
||||
observations = np.asarray(observations)
|
||||
actions = self.policy(pop, observations)
|
||||
actions = jax.device_get(actions)
|
||||
|
||||
for i, (action, terminate, truncate, env) in enumerate(zip(actions, terminates, truncates, self.envs)):
|
||||
if terminate | truncate:
|
||||
observation = np.zeros(env.observation_space.shape)
|
||||
reward = 0
|
||||
else:
|
||||
observation, reward, terminate, truncate, info = env.step(action)
|
||||
|
||||
observations[i] = observation
|
||||
rewards[i] = reward
|
||||
terminates[i] = terminate
|
||||
truncates[i] = truncate
|
||||
|
||||
fitnesses += rewards
|
||||
|
||||
return fitnesses
|
||||
Reference in New Issue
Block a user