update problem and pipeline
This commit is contained in:
@@ -2,7 +2,7 @@ from tensorneat.pipeline import Pipeline
|
||||
from tensorneat.algorithm.neat import NEAT
|
||||
from tensorneat.genome import DefaultGenome, BiasNode
|
||||
|
||||
from tensorneat.problem.rl_env import BraxEnv
|
||||
from tensorneat.problem.rl import BraxEnv
|
||||
from tensorneat.common import Act, Agg
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
@@ -22,7 +22,7 @@ if __name__ == "__main__":
|
||||
low_bounds=[-1, -1],
|
||||
upper_bounds=[1, 1],
|
||||
method="sample",
|
||||
num_samples=1000,
|
||||
num_samples=100,
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
@@ -42,7 +42,7 @@ if __name__ == "__main__":
|
||||
),
|
||||
),
|
||||
problem=custom_problem,
|
||||
generation_limit=100,
|
||||
generation_limit=50,
|
||||
fitness_target=-1e-4,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
@@ -7,8 +7,6 @@ import numpy as np
|
||||
|
||||
from tensorneat.algorithm import BaseAlgorithm
|
||||
from tensorneat.problem import BaseProblem
|
||||
from tensorneat.problem.rl_env import RLEnv
|
||||
from tensorneat.problem.func_fit import FuncFit
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
@@ -20,10 +18,8 @@ class Pipeline(StatefulBaseClass):
|
||||
seed: int = 42,
|
||||
fitness_target: float = 1,
|
||||
generation_limit: int = 1000,
|
||||
pre_update: bool = False,
|
||||
update_batch_size: int = 10000,
|
||||
save_dir=None,
|
||||
is_save: bool = False,
|
||||
save_dir=None,
|
||||
):
|
||||
assert problem.jitable, "Currently, problem must be jitable"
|
||||
|
||||
@@ -36,7 +32,6 @@ class Pipeline(StatefulBaseClass):
|
||||
|
||||
np.random.seed(self.seed)
|
||||
|
||||
# TODO: make each algorithm's input_num and output_num
|
||||
assert (
|
||||
algorithm.num_inputs == self.problem.input_shape[-1]
|
||||
), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
|
||||
@@ -44,22 +39,6 @@ class Pipeline(StatefulBaseClass):
|
||||
self.best_genome = None
|
||||
self.best_fitness = float("-inf")
|
||||
self.generation_timestamp = None
|
||||
self.pre_update = pre_update
|
||||
self.update_batch_size = update_batch_size
|
||||
if pre_update:
|
||||
if isinstance(problem, RLEnv):
|
||||
assert problem.record_episode, "record_episode must be True"
|
||||
self.fetch_data = lambda episode: episode["obs"]
|
||||
elif isinstance(problem, FuncFit):
|
||||
assert problem.return_data, "return_data must be True"
|
||||
self.fetch_data = lambda data: data
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
if isinstance(problem, RLEnv):
|
||||
assert not problem.record_episode, "record_episode must be False"
|
||||
elif isinstance(problem, FuncFit):
|
||||
assert not problem.return_data, "return_data must be False"
|
||||
self.is_save = is_save
|
||||
|
||||
if is_save:
|
||||
@@ -79,14 +58,6 @@ class Pipeline(StatefulBaseClass):
|
||||
print("initializing")
|
||||
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
||||
|
||||
if self.pre_update:
|
||||
# initial with mean = 0 and std = 1
|
||||
state = state.register(
|
||||
data=jax.random.normal(
|
||||
state.randkey, (self.update_batch_size, self.algorithm.num_inputs)
|
||||
)
|
||||
)
|
||||
|
||||
state = self.algorithm.setup(state)
|
||||
state = self.problem.setup(state)
|
||||
|
||||
@@ -112,49 +83,9 @@ class Pipeline(StatefulBaseClass):
|
||||
state, pop
|
||||
)
|
||||
|
||||
if self.pre_update:
|
||||
# update the population
|
||||
_, pop_transformed = jax.vmap(
|
||||
self.algorithm.update_by_batch, in_axes=(None, None, 0)
|
||||
)(state, state.data, pop_transformed)
|
||||
|
||||
# raw_data: (Pop, Batch, num_inputs)
|
||||
fitnesses, raw_data = jax.vmap(
|
||||
self.problem.evaluate, in_axes=(None, 0, None, 0)
|
||||
)(state, keys, self.algorithm.forward, pop_transformed)
|
||||
|
||||
# update population
|
||||
pop_nodes, pop_conns = jax.vmap(self.algorithm.restore, in_axes=(None, 0))(
|
||||
state, pop_transformed
|
||||
)
|
||||
state = state.update(pop_nodes=pop_nodes, pop_conns=pop_conns)
|
||||
|
||||
# update data for next generation
|
||||
data = self.fetch_data(raw_data)
|
||||
assert (
|
||||
data.ndim == 3
|
||||
and data.shape[0] == self.pop_size
|
||||
and data.shape[2] == self.algorithm.num_inputs
|
||||
)
|
||||
# reshape to (Pop * Batch, num_inputs)
|
||||
data = data.reshape(
|
||||
data.shape[0] * data.shape[1], self.algorithm.num_inputs
|
||||
)
|
||||
# shuffle
|
||||
data = jax.random.permutation(randkey_, data, axis=0)
|
||||
# cutoff or expand
|
||||
if data.shape[0] >= self.update_batch_size:
|
||||
data = data[: self.update_batch_size] # cutoff
|
||||
else:
|
||||
data = (
|
||||
jnp.full(state.data.shape, jnp.nan).at[: data.shape[0]].set(data)
|
||||
) # expand
|
||||
state = state.update(data=data)
|
||||
|
||||
else:
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||
state, keys, self.algorithm.forward, pop_transformed
|
||||
)
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||
state, keys, self.algorithm.forward, pop_transformed
|
||||
)
|
||||
|
||||
# replace nan with -inf
|
||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from .base import BaseProblem
|
||||
from .rl import *
|
||||
from .func_fit import *
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .func_fit import FuncFit
|
||||
from .xor import XOR
|
||||
from .xor3d import XOR3d
|
||||
from .custom import CustomFuncFit
|
||||
from .func_fit import FuncFit
|
||||
@@ -1,6 +1,4 @@
|
||||
from typing import Callable, Union, List, Tuple, Sequence
|
||||
|
||||
import jax
|
||||
from typing import Callable, Union, List, Tuple
|
||||
from jax import vmap, Array, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..base import BaseProblem
|
||||
from tensorneat.common import State
|
||||
from .. import BaseProblem
|
||||
|
||||
|
||||
class FuncFit(BaseProblem):
|
||||
jitable = True
|
||||
|
||||
def __init__(self, error_method: str = "mse", return_data: bool = False):
|
||||
def __init__(self, error_method: str = "mse"):
|
||||
super().__init__()
|
||||
|
||||
assert error_method in {"mse", "rmse", "mae", "mape"}
|
||||
self.error_method = error_method
|
||||
self.return_data = return_data
|
||||
|
||||
def setup(self, state: State = State()):
|
||||
return state
|
||||
@@ -39,21 +38,16 @@ class FuncFit(BaseProblem):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.return_data:
|
||||
return -loss, self.inputs
|
||||
else:
|
||||
return -loss
|
||||
return -loss
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
state, params, self.inputs
|
||||
)
|
||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||
if self.return_data:
|
||||
loss, _ = self.evaluate(state, randkey, act_func, params)
|
||||
else:
|
||||
loss = self.evaluate(state, randkey, act_func, params)
|
||||
loss = -loss
|
||||
fitness = self.evaluate(state, randkey, act_func, params)
|
||||
|
||||
loss = -fitness
|
||||
|
||||
msg = ""
|
||||
for i in range(inputs.shape[0]):
|
||||
|
||||
3
tensorneat/problem/rl/__init__.py
Normal file
3
tensorneat/problem/rl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .gymnax import GymNaxEnv
|
||||
from .brax import BraxEnv
|
||||
from .rl_jit import RLEnv
|
||||
@@ -6,7 +6,7 @@ from .rl_jit import RLEnv
|
||||
class GymNaxEnv(RLEnv):
|
||||
def __init__(self, env_name, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
|
||||
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax."
|
||||
self.env, self.env_params = gymnax.make(env_name)
|
||||
|
||||
def env_step(self, randkey, env_state, action):
|
||||
@@ -24,4 +24,4 @@ class GymNaxEnv(RLEnv):
|
||||
return self.env.action_space(self.env_params).shape
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
|
||||
raise NotImplementedError
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from ..base import BaseProblem
|
||||
from tensorneat.common import State
|
||||
from .. import BaseProblem
|
||||
|
||||
|
||||
class RLEnv(BaseProblem):
|
||||
@@ -15,7 +15,6 @@ class RLEnv(BaseProblem):
|
||||
self,
|
||||
max_step=1000,
|
||||
repeat_times=1,
|
||||
record_episode=False,
|
||||
action_policy: Callable = None,
|
||||
obs_normalization: bool = False,
|
||||
sample_policy: Callable = None,
|
||||
@@ -34,7 +33,6 @@ class RLEnv(BaseProblem):
|
||||
|
||||
super().__init__()
|
||||
self.max_step = max_step
|
||||
self.record_episode = record_episode
|
||||
self.repeat_times = repeat_times
|
||||
self.action_policy = action_policy
|
||||
|
||||
@@ -57,11 +55,11 @@ class RLEnv(BaseProblem):
|
||||
) # ignore act_func
|
||||
|
||||
def sample(rk):
|
||||
return self.evaluate_once(
|
||||
return self._evaluate_once(
|
||||
state, rk, dummy_act_func, None, dummy_sample_func, True
|
||||
)
|
||||
|
||||
rewards, episodes = jax.jit(jax.vmap(sample))(keys)
|
||||
rewards, episodes = jax.jit(vmap(sample))(keys)
|
||||
|
||||
obs = jax.device_get(episodes["obs"]) # shape: (sample_episodes, max_step, *input_shape)
|
||||
obs = obs.reshape(
|
||||
@@ -88,47 +86,21 @@ class RLEnv(BaseProblem):
|
||||
|
||||
def evaluate(self, state: State, randkey, act_func: Callable, params):
|
||||
keys = jax.random.split(randkey, self.repeat_times)
|
||||
if self.record_episode:
|
||||
rewards, episodes = jax.vmap(
|
||||
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
|
||||
)(
|
||||
state,
|
||||
keys,
|
||||
act_func,
|
||||
params,
|
||||
self.action_policy,
|
||||
True,
|
||||
self.obs_normalization,
|
||||
)
|
||||
rewards = vmap(
|
||||
self._evaluate_once, in_axes=(None, 0, None, None, None, None, None)
|
||||
)(
|
||||
state,
|
||||
keys,
|
||||
act_func,
|
||||
params,
|
||||
self.action_policy,
|
||||
False,
|
||||
self.obs_normalization,
|
||||
)
|
||||
|
||||
episodes["obs"] = episodes["obs"].reshape(
|
||||
self.max_step * self.repeat_times, *self.input_shape
|
||||
)
|
||||
episodes["action"] = episodes["action"].reshape(
|
||||
self.max_step * self.repeat_times, *self.output_shape
|
||||
)
|
||||
episodes["reward"] = episodes["reward"].reshape(
|
||||
self.max_step * self.repeat_times,
|
||||
)
|
||||
return rewards.mean()
|
||||
|
||||
return rewards.mean(), episodes
|
||||
|
||||
else:
|
||||
rewards = jax.vmap(
|
||||
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
|
||||
)(
|
||||
state,
|
||||
keys,
|
||||
act_func,
|
||||
params,
|
||||
self.action_policy,
|
||||
False,
|
||||
self.obs_normalization,
|
||||
)
|
||||
|
||||
return rewards.mean()
|
||||
|
||||
def evaluate_once(
|
||||
def _evaluate_once(
|
||||
self,
|
||||
state,
|
||||
randkey,
|
||||
@@ -1,3 +0,0 @@
|
||||
from .gymnax_env import GymNaxEnv
|
||||
from .brax_env import BraxEnv
|
||||
from .rl_jit import RLEnv
|
||||
@@ -1,62 +0,0 @@
|
||||
import jax, jax.numpy as jnp
|
||||
import jumanji
|
||||
|
||||
from tensorneat.common import State
|
||||
from ..rl_jit import RLEnv
|
||||
|
||||
|
||||
class Jumanji_2048(RLEnv):
|
||||
def __init__(
|
||||
self, guarantee_invalid_action=True, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.guarantee_invalid_action = guarantee_invalid_action
|
||||
self.env = jumanji.make("Game2048-v1")
|
||||
|
||||
def env_step(self, randkey, env_state, action):
|
||||
action_mask = env_state["action_mask"]
|
||||
|
||||
###################################################################
|
||||
|
||||
# action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
|
||||
|
||||
###################################################################
|
||||
if self.guarantee_invalid_action:
|
||||
score_with_mask = jnp.where(action_mask, action, -jnp.inf)
|
||||
action = jnp.argmax(score_with_mask)
|
||||
else:
|
||||
action = jnp.argmax(action)
|
||||
|
||||
done = ~action_mask[action]
|
||||
|
||||
env_state, timestep = self.env.step(env_state, action)
|
||||
reward = timestep["reward"]
|
||||
|
||||
board, action_mask = timestep["observation"]
|
||||
extras = timestep["extras"]
|
||||
|
||||
done = done | (jnp.sum(action_mask) == 0) # all actions of invalid
|
||||
|
||||
return board.reshape(-1), env_state, reward, done, extras
|
||||
|
||||
def env_reset(self, randkey):
|
||||
env_state, timestep = self.env.reset(randkey)
|
||||
step_type = timestep["step_type"]
|
||||
reward = timestep["reward"]
|
||||
discount = timestep["discount"]
|
||||
observation = timestep["observation"]
|
||||
extras = timestep["extras"]
|
||||
board, action_mask = observation
|
||||
|
||||
return board.reshape(-1), env_state
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16,)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4,)
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
|
||||
Reference in New Issue
Block a user