update problem and pipeline

This commit is contained in:
root
2024-07-11 19:34:12 +08:00
parent be6a67d7e2
commit cef27b56bb
14 changed files with 40 additions and 205 deletions

View File

@@ -1 +1,3 @@
from .base import BaseProblem
from .rl import *
from .func_fit import *

View File

@@ -1,4 +1,4 @@
from .func_fit import FuncFit
from .xor import XOR
from .xor3d import XOR3d
from .custom import CustomFuncFit
from .custom import CustomFuncFit
from .func_fit import FuncFit

View File

@@ -1,6 +1,4 @@
from typing import Callable, Union, List, Tuple, Sequence
import jax
from typing import Callable, Union, List, Tuple
from jax import vmap, Array, numpy as jnp
import numpy as np

View File

@@ -1,19 +1,18 @@
import jax
import jax.numpy as jnp
from ..base import BaseProblem
from tensorneat.common import State
from .. import BaseProblem
class FuncFit(BaseProblem):
jitable = True
def __init__(self, error_method: str = "mse", return_data: bool = False):
def __init__(self, error_method: str = "mse"):
super().__init__()
assert error_method in {"mse", "rmse", "mae", "mape"}
self.error_method = error_method
self.return_data = return_data
def setup(self, state: State = State()):
return state
@@ -39,21 +38,16 @@ class FuncFit(BaseProblem):
else:
raise NotImplementedError
if self.return_data:
return -loss, self.inputs
else:
return -loss
return -loss
def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
if self.return_data:
loss, _ = self.evaluate(state, randkey, act_func, params)
else:
loss = self.evaluate(state, randkey, act_func, params)
loss = -loss
fitness = self.evaluate(state, randkey, act_func, params)
loss = -fitness
msg = ""
for i in range(inputs.shape[0]):

View File

@@ -0,0 +1,3 @@
from .gymnax import GymNaxEnv
from .brax import BraxEnv
from .rl_jit import RLEnv

View File

@@ -6,7 +6,7 @@ from .rl_jit import RLEnv
class GymNaxEnv(RLEnv):
def __init__(self, env_name, *args, **kwargs):
super().__init__(*args, **kwargs)
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax."
self.env, self.env_params = gymnax.make(env_name)
def env_step(self, randkey, env_state, action):
@@ -24,4 +24,4 @@ class GymNaxEnv(RLEnv):
return self.env.action_space(self.env_params).shape
def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
raise NotImplementedError

View File

@@ -1,11 +1,11 @@
from typing import Callable
import jax
import jax.numpy as jnp
from jax import vmap, numpy as jnp
import numpy as np
from ..base import BaseProblem
from tensorneat.common import State
from .. import BaseProblem
class RLEnv(BaseProblem):
@@ -15,7 +15,6 @@ class RLEnv(BaseProblem):
self,
max_step=1000,
repeat_times=1,
record_episode=False,
action_policy: Callable = None,
obs_normalization: bool = False,
sample_policy: Callable = None,
@@ -34,7 +33,6 @@ class RLEnv(BaseProblem):
super().__init__()
self.max_step = max_step
self.record_episode = record_episode
self.repeat_times = repeat_times
self.action_policy = action_policy
@@ -57,11 +55,11 @@ class RLEnv(BaseProblem):
) # ignore act_func
def sample(rk):
return self.evaluate_once(
return self._evaluate_once(
state, rk, dummy_act_func, None, dummy_sample_func, True
)
rewards, episodes = jax.jit(jax.vmap(sample))(keys)
rewards, episodes = jax.jit(vmap(sample))(keys)
obs = jax.device_get(episodes["obs"]) # shape: (sample_episodes, max_step, *input_shape)
obs = obs.reshape(
@@ -88,47 +86,21 @@ class RLEnv(BaseProblem):
def evaluate(self, state: State, randkey, act_func: Callable, params):
keys = jax.random.split(randkey, self.repeat_times)
if self.record_episode:
rewards, episodes = jax.vmap(
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
)(
state,
keys,
act_func,
params,
self.action_policy,
True,
self.obs_normalization,
)
rewards = vmap(
self._evaluate_once, in_axes=(None, 0, None, None, None, None, None)
)(
state,
keys,
act_func,
params,
self.action_policy,
False,
self.obs_normalization,
)
episodes["obs"] = episodes["obs"].reshape(
self.max_step * self.repeat_times, *self.input_shape
)
episodes["action"] = episodes["action"].reshape(
self.max_step * self.repeat_times, *self.output_shape
)
episodes["reward"] = episodes["reward"].reshape(
self.max_step * self.repeat_times,
)
return rewards.mean()
return rewards.mean(), episodes
else:
rewards = jax.vmap(
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
)(
state,
keys,
act_func,
params,
self.action_policy,
False,
self.obs_normalization,
)
return rewards.mean()
def evaluate_once(
def _evaluate_once(
self,
state,
randkey,

View File

@@ -1,3 +0,0 @@
from .gymnax_env import GymNaxEnv
from .brax_env import BraxEnv
from .rl_jit import RLEnv

View File

@@ -1,62 +0,0 @@
import jax, jax.numpy as jnp
import jumanji
from tensorneat.common import State
from ..rl_jit import RLEnv
class Jumanji_2048(RLEnv):
    """RL problem wrapper around jumanji's ``Game2048-v1`` environment.

    The network emits one score per direction (4 outputs); the wrapper turns
    those scores into a discrete action via argmax, optionally masking out
    moves the environment reports as invalid.
    """

    def __init__(
        self, guarantee_invalid_action=True, *args, **kwargs
    ):
        # guarantee_invalid_action: when True, invalid moves are masked to
        # -inf before the argmax, so the chosen action is always a legal one
        # (as long as any legal move exists).
        super().__init__(*args, **kwargs)
        self.guarantee_invalid_action = guarantee_invalid_action
        self.env = jumanji.make("Game2048-v1")

    def env_step(self, randkey, env_state, action):
        """Advance the env one step.

        ``action`` is the raw per-direction score vector from the network,
        not a discrete action index; it is converted via argmax below.
        Returns ``(obs, env_state, reward, done, extras)`` with the 4x4
        board flattened to shape (16,).
        """
        # Mask of currently-legal moves, read from the *pre-step* state.
        # NOTE(review): env_state is subscripted like a mapping here and
        # timestep below as well — confirm this matches the jumanji version
        # in use (jumanji's State/TimeStep are usually attribute-accessed).
        action_mask = env_state["action_mask"]
        ###################################################################
        # action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
        ###################################################################
        if self.guarantee_invalid_action:
            # Suppress illegal moves so argmax can only pick a legal one.
            score_with_mask = jnp.where(action_mask, action, -jnp.inf)
            action = jnp.argmax(score_with_mask)
        else:
            action = jnp.argmax(action)
        # Episode ends if the chosen move was illegal per the pre-step mask
        # (only reachable when guarantee_invalid_action is False, or when
        # no legal move existed at all).
        done = ~action_mask[action]
        env_state, timestep = self.env.step(env_state, action)
        reward = timestep["reward"]
        board, action_mask = timestep["observation"]
        extras = timestep["extras"]
        done = done | (jnp.sum(action_mask) == 0)  # all actions are invalid
        return board.reshape(-1), env_state, reward, done, extras

    def env_reset(self, randkey):
        """Reset the env; return the flattened board and the new env state."""
        env_state, timestep = self.env.reset(randkey)
        # step_type / reward / discount / extras are unpacked but unused;
        # kept here presumably to document the timestep layout.
        step_type = timestep["step_type"]
        reward = timestep["reward"]
        discount = timestep["discount"]
        observation = timestep["observation"]
        extras = timestep["extras"]
        board, action_mask = observation
        return board.reshape(-1), env_state

    @property
    def input_shape(self):
        # Flattened 4x4 board.
        return (16,)

    @property
    def output_shape(self):
        # One score per move direction.
        return (4,)

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        # NOTE(review): this message mentions GymNax — looks copy-pasted
        # from GymNaxEnv.show; confirm the intended message for jumanji.
        raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")