new architecture

This commit is contained in:
wls2002
2024-01-27 00:52:39 +08:00
parent 4efe9a53c1
commit aac41a089d
65 changed files with 1651 additions and 1783 deletions

View File

@@ -0,0 +1 @@
from .base import BaseProblem

44
problem/base.py Normal file
View File

@@ -0,0 +1,44 @@
from typing import Callable
from config import ProblemConfig
from core.state import State
class BaseProblem:
jitable = None
def __init__(self):
pass
def setup(self, randkey, state: State = State()):
"""initialize the state of the problem"""
raise NotImplementedError
def evaluate(self, randkey, state: State, act_func: Callable, params):
"""evaluate one individual"""
raise NotImplementedError
@property
def input_shape(self):
"""
The input shape for the problem to evaluate
In RL problem, it is the observation space
In function fitting problem, it is the input shape of the function
"""
raise NotImplementedError
@property
def output_shape(self):
"""
The output shape for the problem to evaluate
In RL problem, it is the action space
In function fitting problem, it is the output shape of the function
"""
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
"""
show how a genome perform in this problem
"""
raise NotImplementedError

View File

@@ -1,3 +1,3 @@
from .func_fit import FuncFit, FuncFitConfig
from .func_fit import FuncFit
from .xor import XOR
from .xor3d import XOR3d

View File

@@ -1,42 +1,35 @@
from typing import Callable
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from config import ProblemConfig
from core import Problem, State
from .. import BaseProblem
@dataclass(frozen=True)
class FuncFitConfig(ProblemConfig):
error_method: str = 'mse'
def __post_init__(self):
assert self.error_method in {'mse', 'rmse', 'mae', 'mape'}
class FuncFit(Problem):
class FuncFit(BaseProblem):
jitable = True
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
self.config = config
super().__init__(config)
def __init__(self,
error_method: str = 'mse'
):
super().__init__()
def evaluate(self, randkey, state: State, act_func: Callable, params):
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
self.error_method = error_method
def evaluate(self, randkey, state, act_func, params):
predict = act_func(state, self.inputs, params)
if self.config.error_method == 'mse':
if self.error_method == 'mse':
loss = jnp.mean((predict - self.targets) ** 2)
elif self.config.error_method == 'rmse':
elif self.error_method == 'rmse':
loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2))
elif self.config.error_method == 'mae':
elif self.error_method == 'mae':
loss = jnp.mean(jnp.abs(predict - self.targets))
elif self.config.error_method == 'mape':
elif self.error_method == 'mape':
loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets))
else:
@@ -44,7 +37,7 @@ class FuncFit(Problem):
return -loss
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
def show(self, randkey, state, act_func, params, *args, **kwargs):
predict = act_func(state, self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params)

View File

@@ -1,13 +1,12 @@
import numpy as np
from .func_fit import FuncFit, FuncFitConfig
from .func_fit import FuncFit
class XOR(FuncFit):
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
self.config = config
super().__init__(config)
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method)
@property
def inputs(self):

View File

@@ -1,13 +1,12 @@
import numpy as np
from .func_fit import FuncFit, FuncFitConfig
from .func_fit import FuncFit
class XOR3d(FuncFit):
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
self.config = config
super().__init__(config)
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method)
@property
def inputs(self):
@@ -37,8 +36,8 @@ class XOR3d(FuncFit):
@property
def input_shape(self):
return (8, 3)
return 8, 3
@property
def output_shape(self):
return (8, 1)
return 8, 1

View File

@@ -1,28 +1,13 @@
from dataclasses import dataclass
from typing import Callable
import jax.numpy as jnp
from brax import envs
from core import State
from .rl_jit import RLEnv, RLEnvConfig
@dataclass(frozen=True)
class BraxConfig(RLEnvConfig):
env_name: str = "ant"
backend: str = "generalized"
def __post_init__(self):
# TODO: Check if env_name is registered
# assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered"
pass
from .rl_jit import RLEnv
class BraxEnv(RLEnv):
def __init__(self, config: BraxConfig = BraxConfig()):
super().__init__(config)
self.config = config
self.env = envs.create(env_name=config.env_name, backend=config.backend)
def __init__(self, env_name: str = "ant", backend: str = "generalized"):
super().__init__()
self.env = envs.create(env_name=env_name, backend=backend)
def env_step(self, randkey, env_state, action):
state = self.env.step(env_state, action)
@@ -40,9 +25,7 @@ class BraxEnv(RLEnv):
def output_shape(self):
return (self.env.action_size,)
def show(self, randkey, state: State, act_func: Callable, params, save_path=None, height=512, width=512,
duration=0.1, *args,
**kwargs):
def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs):
import jax
import imageio
@@ -56,8 +39,7 @@ class BraxEnv(RLEnv):
def step(key, env_state, obs):
key, _ = jax.random.split(key)
net_out = act_func(state, obs, params)
action = self.config.output_transform(net_out)
action = act_func(state, obs, params)
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
return key, env_state, obs, r, done
@@ -72,7 +54,6 @@ class BraxEnv(RLEnv):
def create_gif(image_list, gif_name, duration):
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
for image in image_list:
# 确保图像的数据类型正确
formatted_image = np.array(image, dtype=np.uint8)
writer.append_data(formatted_image)

View File

@@ -1,26 +1,15 @@
from dataclasses import dataclass
from typing import Callable
import gymnax
from core import State
from .rl_jit import RLEnv, RLEnvConfig
from .rl_jit import RLEnv
@dataclass(frozen=True)
class GymNaxConfig(RLEnvConfig):
env_name: str = "CartPole-v1"
def __post_init__(self):
assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered"
class GymNaxEnv(RLEnv):
def __init__(self, config: GymNaxConfig = GymNaxConfig()):
super().__init__(config)
self.config = config
self.env, self.env_params = gymnax.make(config.env_name)
def __init__(self, env_name):
super().__init__()
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
self.env, self.env_params = gymnax.make(env_name)
def env_step(self, randkey, env_state, action):
return self.env.step(randkey, env_state, action, self.env_params)
@@ -36,5 +25,5 @@ class GymNaxEnv(RLEnv):
def output_shape(self):
return self.env.action_space(self.env_params).shape
def show(self, randkey, state: State, act_func: Callable, params):
def show(self, randkey, state, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")

View File

@@ -1,28 +1,18 @@
from dataclasses import dataclass
from typing import Callable
from functools import partial
import jax
from config import ProblemConfig
from .. import BaseProblem
from core import Problem, State
@dataclass(frozen=True)
class RLEnvConfig(ProblemConfig):
output_transform: Callable = lambda x: x
class RLEnv(Problem):
class RLEnv(BaseProblem):
jitable = True
def __init__(self, config: RLEnvConfig = RLEnvConfig()):
super().__init__(config)
self.config = config
# TODO: move output transform to algorithm
def __init__(self):
super().__init__()
def evaluate(self, randkey, state: State, act_func: Callable, params):
def evaluate(self, randkey, state, act_func, params):
rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset)
@@ -31,8 +21,7 @@ class RLEnv(Problem):
return ~done
def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward
net_out = act_func(state, obs, params)
action = self.config.output_transform(net_out)
action = act_func(state, obs, params)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng)
return next_obs, next_env_state, next_rng, done, tr + reward
@@ -67,5 +56,5 @@ class RLEnv(Problem):
def output_shape(self):
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
def show(self, randkey, state, act_func, params, *args, **kwargs):
raise NotImplementedError