change repo structure; modify readme

This commit is contained in:
wls2002
2024-03-26 21:58:27 +08:00
parent 6970e6a6d5
commit 47dbcbea80
69 changed files with 74 additions and 60 deletions

View File

@@ -0,0 +1 @@
from .base import BaseProblem

View File

@@ -0,0 +1,39 @@
from typing import Callable
from utils import State
class BaseProblem:
jitable = None
def setup(self, randkey, state: State = State()):
"""initialize the state of the problem"""
pass
def evaluate(self, randkey, state: State, act_func: Callable, params):
"""evaluate one individual"""
raise NotImplementedError
@property
def input_shape(self):
"""
The input shape for the problem to evaluate
In RL problem, it is the observation space
In function fitting problem, it is the input shape of the function
"""
raise NotImplementedError
@property
def output_shape(self):
"""
The output shape for the problem to evaluate
In RL problem, it is the action space
In function fitting problem, it is the output shape of the function
"""
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
"""
show how a genome perform in this problem
"""
raise NotImplementedError

View File

@@ -0,0 +1,3 @@
from .func_fit import FuncFit
from .xor import XOR
from .xor3d import XOR3d

View File

@@ -0,0 +1,67 @@
import jax
import jax.numpy as jnp
from utils import State
from .. import BaseProblem
class FuncFit(BaseProblem):
jitable = True
def __init__(self,
error_method: str = 'mse'
):
super().__init__()
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
self.error_method = error_method
def setup(self, randkey, state: State = State()):
return state
def evaluate(self, randkey, state, act_func, params):
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
if self.error_method == 'mse':
loss = jnp.mean((predict - self.targets) ** 2)
elif self.error_method == 'rmse':
loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2))
elif self.error_method == 'mae':
loss = jnp.mean(jnp.abs(predict - self.targets))
elif self.error_method == 'mape':
loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets))
else:
raise NotImplementedError
return -loss
def show(self, randkey, state, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params)
msg = ""
for i in range(inputs.shape[0]):
msg += f"input: {inputs[i]}, target: {target[i]}, predict: {predict[i]}\n"
msg += f"loss: {loss}\n"
print(msg)
@property
def inputs(self):
raise NotImplementedError
@property
def targets(self):
raise NotImplementedError
@property
def input_shape(self):
raise NotImplementedError
@property
def output_shape(self):
raise NotImplementedError

View File

@@ -0,0 +1,35 @@
import numpy as np
from .func_fit import FuncFit
class XOR(FuncFit):
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method)
@property
def inputs(self):
return np.array([
[0, 0],
[0, 1],
[1, 0],
[1, 1]
])
@property
def targets(self):
return np.array([
[0],
[1],
[1],
[0]
])
@property
def input_shape(self):
return 4, 2
@property
def output_shape(self):
return 4, 1

View File

@@ -0,0 +1,43 @@
import numpy as np
from .func_fit import FuncFit
class XOR3d(FuncFit):
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method)
@property
def inputs(self):
return np.array([
[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[0, 1, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
[1, 1, 1],
])
@property
def targets(self):
return np.array([
[0],
[1],
[1],
[0],
[1],
[0],
[0],
[1]
])
@property
def input_shape(self):
return 8, 3
@property
def output_shape(self):
return 8, 1

View File

@@ -0,0 +1,2 @@
from .gymnax_env import GymNaxEnv
from .brax_env import BraxEnv

View File

@@ -0,0 +1,64 @@
import jax.numpy as jnp
from brax import envs
from .rl_jit import RLEnv
class BraxEnv(RLEnv):
def __init__(self, env_name: str = "ant", backend: str = "generalized"):
super().__init__()
self.env = envs.create(env_name=env_name, backend=backend)
def env_step(self, randkey, env_state, action):
state = self.env.step(env_state, action)
return state.obs, state, state.reward, state.done.astype(jnp.bool_), state.info
def env_reset(self, randkey):
init_state = self.env.reset(randkey)
return init_state.obs, init_state
@property
def input_shape(self):
return (self.env.observation_size,)
@property
def output_shape(self):
return (self.env.action_size,)
def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs):
import jax
import imageio
import numpy as np
from brax.io import image
from tqdm import tqdm
obs, env_state = self.reset(randkey)
reward, done = 0.0, False
state_histories = []
def step(key, env_state, obs):
key, _ = jax.random.split(key)
action = act_func(state, obs, params)
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
return key, env_state, obs, r, done
while not done:
state_histories.append(env_state.pipeline_state)
key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs)
reward += r
imgs = [image.render_array(sys=self.env.sys, state=s, width=width, height=height) for s in
tqdm(state_histories, desc="Rendering")]
def create_gif(image_list, gif_name, duration):
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
for image in image_list:
formatted_image = np.array(image, dtype=np.uint8)
writer.append_data(formatted_image)
create_gif(imgs, save_path, duration=0.1)
print("Gif saved to: ", save_path)
print("Total reward: ", reward)

View File

@@ -0,0 +1,28 @@
import gymnax
from .rl_jit import RLEnv
class GymNaxEnv(RLEnv):
def __init__(self, env_name):
super().__init__()
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
self.env, self.env_params = gymnax.make(env_name)
def env_step(self, randkey, env_state, action):
return self.env.step(randkey, env_state, action, self.env_params)
def env_reset(self, randkey):
return self.env.reset(randkey, self.env_params)
@property
def input_shape(self):
return self.env.observation_space(self.env_params).shape
@property
def output_shape(self):
return self.env.action_space(self.env_params).shape
def show(self, randkey, state, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")

View File

@@ -0,0 +1,61 @@
from functools import partial
import jax
from .. import BaseProblem
class RLEnv(BaseProblem):
jitable = True
# TODO: move output transform to algorithm
def __init__(self):
super().__init__()
def evaluate(self, randkey, state, act_func, params):
rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset)
def cond_func(carry):
_, _, _, done, _ = carry
return ~done
def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward
action = act_func(obs, params)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng)
return next_obs, next_env_state, next_rng, done, tr + reward
_, _, _, _, total_reward = jax.lax.while_loop(
cond_func,
body_func,
(init_obs, init_env_state, rng_episode, False, 0.0)
)
return total_reward
@partial(jax.jit, static_argnums=(0,))
def step(self, randkey, env_state, action):
return self.env_step(randkey, env_state, action)
@partial(jax.jit, static_argnums=(0,))
def reset(self, randkey):
return self.env_reset(randkey)
def env_step(self, randkey, env_state, action):
raise NotImplementedError
def env_reset(self, randkey):
raise NotImplementedError
@property
def input_shape(self):
raise NotImplementedError
@property
def output_shape(self):
raise NotImplementedError
def show(self, randkey, state, act_func, params, *args, **kwargs):
raise NotImplementedError