finish all refactoring

This commit is contained in:
wls2002
2024-02-21 15:41:08 +08:00
parent aac41a089d
commit 6970e6a6d5
44 changed files with 856 additions and 825 deletions

View File

@@ -1,19 +1,14 @@
from typing import Callable
from config import ProblemConfig
from core.state import State
from utils import State
class BaseProblem:
jitable = None
def __init__(self):
pass
def setup(self, randkey, state: State = State()):
"""initialize the state of the problem"""
raise NotImplementedError
pass
def evaluate(self, randkey, state: State, act_func: Callable, params):
"""evaluate one individual"""

View File

@@ -1,24 +1,27 @@
import jax
import jax.numpy as jnp
from utils import State
from .. import BaseProblem
class FuncFit(BaseProblem):
class FuncFit(BaseProblem):
jitable = True
def __init__(self,
error_method: str = 'mse'
):
error_method: str = 'mse'
):
super().__init__()
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
self.error_method = error_method
def setup(self, randkey, state: State = State()):
return state
def evaluate(self, randkey, state, act_func, params):
predict = act_func(state, self.inputs, params)
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
if self.error_method == 'mse':
loss = jnp.mean((predict - self.targets) ** 2)
@@ -38,7 +41,7 @@ class FuncFit(BaseProblem):
return -loss
def show(self, randkey, state, act_func, params, *args, **kwargs):
predict = act_func(state, self.inputs, params)
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params)
msg = ""

View File

@@ -1,2 +1,2 @@
from .gymnax_env import GymNaxEnv, GymNaxConfig
from .brax_env import BraxEnv, BraxConfig
from .gymnax_env import GymNaxEnv
from .brax_env import BraxEnv

View File

@@ -3,7 +3,6 @@ import gymnax
from .rl_jit import RLEnv
class GymNaxEnv(RLEnv):
def __init__(self, env_name):

View File

@@ -4,8 +4,8 @@ import jax
from .. import BaseProblem
class RLEnv(BaseProblem):
class RLEnv(BaseProblem):
jitable = True
# TODO: move output transform to algorithm
@@ -19,9 +19,10 @@ class RLEnv(BaseProblem):
def cond_func(carry):
_, _, _, done, _ = carry
return ~done
def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward
action = act_func(state, obs, params)
action = act_func(obs, params)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng)
return next_obs, next_env_state, next_rng, done, tr + reward