Use EvoX to solve RL tasks, with a distributed Gym environment!
Three simple tasks from Gym[classical] have been tested.
This commit is contained in:
2
evox_adaptor/__init__.py
Normal file
2
evox_adaptor/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .neat import NEAT
|
||||
from .gym_no_distribution import Gym
|
||||
83
evox_adaptor/gym_no_distribution.py
Normal file
83
evox_adaptor/gym_no_distribution.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Callable
|
||||
|
||||
import gym
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from evox import Problem, State
|
||||
|
||||
|
||||
class Gym(Problem):
    """EvoX problem that scores a population by running Gym episodes.

    One Gym environment is created per population member. During evaluation
    all environments are reset with seeds drawn from the state's PRNG key and
    then stepped together until every episode has terminated or been
    truncated. The per-environment sum of rewards is the raw fitness.
    """

    def __init__(
        self,
        pop_size: int,
        policy: Callable,
        env_name: str = "CartPole-v1",
        env_options: dict = None,
        batch_policy: bool = True,
    ):
        """Create the environments and store the evaluation settings.

        Args:
            pop_size: Number of environments to create (one per individual).
            policy: Callable invoked as ``policy(pop, observations)``; its
                result is iterated to obtain one action per environment.
            env_name: Identifier passed to ``gym.make``.
            env_options: Extra keyword arguments for ``gym.make``; a falsy
                value is treated as an empty dict.
            batch_policy: Must be truthy; anything else fails the assertion.
        """
        self.policy = policy
        self.pop_size = pop_size
        self.batch_policy = batch_policy
        self.env_name = env_name
        self.env_options = env_options if env_options else {}

        assert batch_policy, "Only batch policy is supported for now"

        self.envs = []
        for _ in range(self.pop_size):
            self.envs.append(gym.make(env_name, **self.env_options))

        super().__init__()

    def setup(self, key):
        """Return the initial problem state, which only holds the PRNG key."""
        initial_state = State(key=key)
        return initial_state

    def evaluate(self, state, pop):
        """Roll out ``pop`` in every environment and return negated fitness.

        Returns:
            A tuple ``(-fitnesses, State(key=key))`` where ``fitnesses`` is
            the per-environment reward sum produced by the rollout.
        """
        # NOTE: the key is used as-is and handed back unchanged (it is not
        # split), so the same seeds are drawn on every call. The original
        # author kept the seed fixed on purpose for debugging.
        key = state.key
        max_seed = jnp.iinfo(jnp.int32).max
        seed_array = jax.random.randint(key, (self.pop_size,), 0, max_seed)

        # gym expects plain Python ints as seeds, not array scalars
        seeds = seed_array.tolist()

        fitnesses = self.__rollout(seeds, pop)
        print("fitnesses info: ")
        print(f"max: {np.max(fitnesses)}, min: {np.min(fitnesses)}, mean: {np.mean(fitnesses)}, std: {np.std(fitnesses)}")

        # evox uses negative fitness for minimization
        negated = -fitnesses
        return negated, State(key=key)

    def __rollout(self, seeds, pop):
        """Run one episode per environment and return the summed rewards.

        Environments that have already finished receive a zero observation
        and a zero reward on subsequent steps and are not stepped again.
        """
        reset_results = [
            env.reset(seed=seed) for env, seed in zip(self.envs, seeds)
        ]
        observations, infos = zip(*reset_results)
        observations = np.asarray(observations)

        terminates = np.zeros(self.pop_size, dtype=bool)
        truncates = np.zeros(self.pop_size, dtype=bool)
        fitnesses = np.zeros(self.pop_size)
        rewards = np.zeros(self.pop_size)

        while np.any(~(terminates | truncates)):
            actions = jax.device_get(self.policy(pop, observations))

            for idx, (action, env) in enumerate(zip(actions, self.envs)):
                if terminates[idx] or truncates[idx]:
                    # finished episode: feed zeros, leave the flags untouched
                    observations[idx] = np.zeros(env.observation_space.shape)
                    rewards[idx] = 0
                    continue

                step_obs, step_reward, terminated, truncated, _ = env.step(action)
                observations[idx] = step_obs
                rewards[idx] = step_reward
                terminates[idx] = terminated
                truncates[idx] = truncated

            fitnesses += rewards

        return fitnesses
|
||||
91
evox_adaptor/neat.py
Normal file
91
evox_adaptor/neat.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
import evox
|
||||
from algorithms import neat
|
||||
from configs import Configer
|
||||
|
||||
|
||||
@evox.jit_class
class NEAT(evox.Algorithm):
    """EvoX ``Algorithm`` adaptor around the functional ``neat`` module.

    The values returned by ``neat.initialize`` are kept on the instance and
    copied into the EvoX state by ``setup``; ``tell`` forwards the state
    fields to ``neat.tell`` and repacks its results into a new state.
    """

    def __init__(self, config):
        """Build the jit config and the initial NEAT values from ``config``."""
        self.config = config  # global config
        self.jit_config = Configer.create_jit_config(config)

        initial_values = neat.initialize(config)
        (
            self.randkey,
            self.pop_nodes,
            self.pop_cons,
            self.species_info,
            self.idx2species,
            self.center_nodes,
            self.center_cons,
            self.generation,
            self.next_node_key,
            self.next_species_key,
        ) = initial_values

        super().__init__()

    def setup(self, key):
        """Return the initial state built from the values stored in ``__init__``.

        The ``key`` argument is not used; the stored ``randkey`` is used instead.
        """
        fields = {
            "randkey": self.randkey,
            "pop_nodes": self.pop_nodes,
            "pop_cons": self.pop_cons,
            "species_info": self.species_info,
            "idx2species": self.idx2species,
            "center_nodes": self.center_nodes,
            "center_cons": self.center_cons,
            "generation": self.generation,
            "next_node_key": self.next_node_key,
            "next_species_key": self.next_species_key,
            "jit_config": self.jit_config,
        }
        return evox.State(**fields)

    def ask(self, state):
        """Return the population as one flat vector (nodes, then connections)."""
        pop = jnp.concatenate(
            [state.pop_nodes.flatten(), state.pop_cons.flatten()]
        )
        return pop, state

    def tell(self, state, fitness):
        """Advance NEAT with ``fitness`` and return the resulting state.

        ``jit_config`` is carried over from the incoming state unchanged.
        """
        # evox is a minimization framework, so we need to negate the fitness
        negated_fitness = -fitness

        step_outputs = neat.tell(
            negated_fitness,
            state.randkey,
            state.pop_nodes,
            state.pop_cons,
            state.species_info,
            state.idx2species,
            state.center_nodes,
            state.center_cons,
            state.generation,
            state.next_node_key,
            state.next_species_key,
            state.jit_config,
        )
        (
            new_randkey,
            new_pop_nodes,
            new_pop_cons,
            new_species_info,
            new_idx2species,
            new_center_nodes,
            new_center_cons,
            new_generation,
            new_next_node_key,
            new_next_species_key,
        ) = step_outputs

        return evox.State(
            randkey=new_randkey,
            pop_nodes=new_pop_nodes,
            pop_cons=new_pop_cons,
            species_info=new_species_info,
            idx2species=new_idx2species,
            center_nodes=new_center_nodes,
            center_cons=new_center_cons,
            generation=new_generation,
            next_node_key=new_next_node_key,
            next_species_key=new_next_species_key,
            jit_config=state.jit_config,
        )
|
||||
Reference in New Issue
Block a user