Using Evox to deal with RL tasks! With distributed Gym environment!

Three simple tasks in Gym[classical] are tested.
This commit is contained in:
wls2002
2023-07-04 15:44:08 +08:00
parent c4d34e877b
commit 7bf46575f4
18 changed files with 547 additions and 43 deletions

2
evox_adaptor/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .neat import NEAT
from .gym_no_distribution import Gym

View File

@@ -0,0 +1,83 @@
from typing import Callable
import gym
import jax
import jax.numpy as jnp
import numpy as np
from evox import Problem, State
class Gym(Problem):
def __init__(
self,
pop_size: int,
policy: Callable,
env_name: str = "CartPole-v1",
env_options: dict = None,
batch_policy: bool = True,
):
self.pop_size = pop_size
self.env_name = env_name
self.policy = policy
self.env_options = env_options or {}
self.batch_policy = batch_policy
assert batch_policy, "Only batch policy is supported for now"
self.envs = [gym.make(env_name, **self.env_options) for _ in range(self.pop_size)]
super().__init__()
def setup(self, key):
return State(key=key)
def evaluate(self, state, pop):
key = state.key
# key, subkey = jax.random.split(state.key)
# generate a list of seeds for gym
# seeds = jax.random.randint(
# subkey, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max
# )
# currently use fixed seed for debugging
seeds = jax.random.randint(
key, (self.pop_size,), 0, jnp.iinfo(jnp.int32).max
)
seeds = seeds.tolist() # seed must be a python int, not numpy array
fitnesses = self.__rollout(seeds, pop)
print("fitnesses info: ")
print(f"max: {np.max(fitnesses)}, min: {np.min(fitnesses)}, mean: {np.mean(fitnesses)}, std: {np.std(fitnesses)}")
# evox uses negative fitness for minimization
return -fitnesses, State(key=key)
def __rollout(self, seeds, pop):
observations, infos = zip(
*[env.reset(seed=seed) for env, seed in zip(self.envs, seeds)]
)
terminates, truncates = np.zeros((2, self.pop_size), dtype=bool)
fitnesses, rewards = np.zeros((2, self.pop_size))
while not np.all(terminates | truncates):
observations = np.asarray(observations)
actions = self.policy(pop, observations)
actions = jax.device_get(actions)
for i, (action, terminate, truncate, env) in enumerate(zip(actions, terminates, truncates, self.envs)):
if terminate | truncate:
observation = np.zeros(env.observation_space.shape)
reward = 0
else:
observation, reward, terminate, truncate, info = env.step(action)
observations[i] = observation
rewards[i] = reward
terminates[i] = terminate
truncates[i] = truncate
fitnesses += rewards
return fitnesses

91
evox_adaptor/neat.py Normal file
View File

@@ -0,0 +1,91 @@
import jax.numpy as jnp
import evox
from algorithms import neat
from configs import Configer
@evox.jit_class
class NEAT(evox.Algorithm):
def __init__(self, config):
self.config = config # global config
self.jit_config = Configer.create_jit_config(config)
(
self.randkey,
self.pop_nodes,
self.pop_cons,
self.species_info,
self.idx2species,
self.center_nodes,
self.center_cons,
self.generation,
self.next_node_key,
self.next_species_key,
) = neat.initialize(config)
super().__init__()
def setup(self, key):
return evox.State(
randkey=self.randkey,
pop_nodes=self.pop_nodes,
pop_cons=self.pop_cons,
species_info=self.species_info,
idx2species=self.idx2species,
center_nodes=self.center_nodes,
center_cons=self.center_cons,
generation=self.generation,
next_node_key=self.next_node_key,
next_species_key=self.next_species_key,
jit_config=self.jit_config
)
def ask(self, state):
flatten_pop_nodes = state.pop_nodes.flatten()
flatten_pop_cons = state.pop_cons.flatten()
pop = jnp.concatenate([flatten_pop_nodes, flatten_pop_cons])
return pop, state
def tell(self, state, fitness):
# evox is a minimization framework, so we need to negate the fitness
fitness = -fitness
(
randkey,
pop_nodes,
pop_cons,
species_info,
idx2species,
center_nodes,
center_cons,
generation,
next_node_key,
next_species_key
) = neat.tell(
fitness,
state.randkey,
state.pop_nodes,
state.pop_cons,
state.species_info,
state.idx2species,
state.center_nodes,
state.center_cons,
state.generation,
state.next_node_key,
state.next_species_key,
state.jit_config
)
return evox.State(
randkey=randkey,
pop_nodes=pop_nodes,
pop_cons=pop_cons,
species_info=species_info,
idx2species=idx2species,
center_nodes=center_nodes,
center_cons=center_cons,
generation=generation,
next_node_key=next_node_key,
next_species_key=next_species_key,
jit_config=state.jit_config
)