add License and pyproject.toml

This commit is contained in:
root
2024-07-11 23:56:06 +08:00
parent e2869c7562
commit 5fdf7b29bc
60 changed files with 71 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from .base import BaseProblem
from .rl import *
from .func_fit import *

View File

@@ -0,0 +1,35 @@
from typing import Callable
from tensorneat.common import State, StatefulBaseClass
class BaseProblem(StatefulBaseClass):
    """Abstract interface shared by every problem in this package.

    Each member below raises ``NotImplementedError``; concrete problems are
    expected to override them.
    """

    # Left as None here; subclasses overwrite it.
    # NOTE(review): presumably signals whether `evaluate` can be traced by
    # jax.jit -- confirm against the code that reads this flag.
    jitable = None

    @property
    def input_shape(self):
        """
        Shape of the input the problem feeds to an individual.

        For an RL problem this is the observation space; for a function
        fitting problem it is the input shape of the function.
        """
        raise NotImplementedError

    @property
    def output_shape(self):
        """
        Shape of the output the problem expects from an individual.

        For an RL problem this is the action space; for a function fitting
        problem it is the output shape of the function.
        """
        raise NotImplementedError

    def evaluate(self, state: State, randkey, act_func: Callable, params):
        """Evaluate one individual."""
        raise NotImplementedError

    def show(
        self, state: State, randkey, act_func: Callable, params, *args, **kwargs
    ):
        """Show how a genome performs on this problem."""
        raise NotImplementedError

View File

@@ -0,0 +1,4 @@
from .xor import XOR
from .xor3d import XOR3d
from .custom import CustomFuncFit
from .func_fit import FuncFit

View File

@@ -0,0 +1,117 @@
from typing import Callable, Union, List, Tuple
from jax import vmap, Array, numpy as jnp
import numpy as np
from .func_fit import FuncFit
class CustomFuncFit(FuncFit):
    """Function-fitting problem whose dataset is produced by a user callable.

    Inputs are drawn inside the box ``[low_bounds, upper_bounds]`` either by
    uniform random sampling (``method="sample"``) or on a regular grid
    (``method="grid"``); targets are obtained by mapping ``func`` over them.
    """

    def __init__(
        self,
        func: Callable,
        low_bounds: Union[List, Tuple, Array],
        upper_bounds: Union[List, Tuple, Array],
        method: str = "sample",
        num_samples: int = 100,
        step_size: Array = None,
        *args,
        **kwargs,
    ):
        """
        Args:
            func: callable applied to a single input vector; it is mapped over
                the generated inputs with ``jax.vmap``.
            low_bounds: per-dimension lower bounds (list, tuple or array).
            upper_bounds: per-dimension upper bounds, same shape as
                ``low_bounds``.
            method: ``"sample"`` or ``"grid"``.
            num_samples: number of rows drawn when ``method == "sample"``.
            step_size: per-dimension grid spacing, required when
                ``method == "grid"`` (list, tuple or array).
            *args, **kwargs: forwarded to ``FuncFit.__init__``.

        Raises:
            ValueError: if ``func(low_bounds)`` raises.
            AssertionError: if the bounds' shapes differ, ``method`` is
                unknown, or the method-specific arguments are invalid.
        """
        if isinstance(low_bounds, (list, tuple)):
            low_bounds = np.array(low_bounds, dtype=np.float32)
        if isinstance(upper_bounds, (list, tuple)):
            upper_bounds = np.array(upper_bounds, dtype=np.float32)
        # Accept plain sequences for step_size as well, mirroring the bounds;
        # otherwise the `.shape` check in generate_dataset fails on a list.
        if isinstance(step_size, (list, tuple)):
            step_size = np.array(step_size, dtype=np.float32)

        # Smoke-test the callable on one point so a broken `func` is reported
        # here rather than from inside vmap; keep the original cause chained.
        try:
            func(low_bounds)
        except Exception as e:
            raise ValueError(f"func(low_bounds) raise an exception: {e}") from e

        assert low_bounds.shape == upper_bounds.shape
        assert method in {"sample", "grid"}

        self.func = func
        self.low_bounds = low_bounds
        self.upper_bounds = upper_bounds
        self.method = method
        self.num_samples = num_samples
        self.step_size = step_size

        # The dataset is built before the parent initialiser runs.
        self.generate_dataset()
        super().__init__(*args, **kwargs)

    def generate_dataset(self):
        """Build ``data_inputs`` / ``data_outputs`` according to ``self.method``.

        Raises:
            AssertionError: on invalid ``num_samples`` or ``step_size``.
            ValueError: if ``self.method`` is neither "sample" nor "grid".
        """
        num_dims = self.low_bounds.shape[0]

        if self.method == "sample":
            assert (
                self.num_samples > 0
            ), f"num_samples must be positive, got {self.num_samples}"

            inputs = np.zeros((self.num_samples, num_dims), dtype=np.float32)
            # One column per dimension, drawn from NumPy's global RNG.
            for i in range(num_dims):
                inputs[:, i] = np.random.uniform(
                    low=self.low_bounds[i],
                    high=self.upper_bounds[i],
                    size=(self.num_samples,),
                )
        elif self.method == "grid":
            assert (
                self.step_size is not None
            ), "step_size must be provided when method is 'grid'"
            assert (
                self.step_size.shape == self.low_bounds.shape
            ), "step_size must have the same shape as low_bounds"
            assert np.all(self.step_size > 0), "step_size must be positive"

            # Seed the product with a dummy column; it is sliced off below.
            inputs = np.zeros((1, 1))
            for i in range(num_dims):
                # np.arange excludes the upper bound itself.
                new_col = np.arange(
                    self.low_bounds[i], self.upper_bounds[i], self.step_size[i]
                )
                inputs = cartesian_product(inputs, new_col[:, None])
            inputs = inputs[:, 1:]
        else:
            raise ValueError(f"Unknown method: {self.method}")

        outputs = vmap(self.func)(inputs)

        self.data_inputs = jnp.array(inputs)
        self.data_outputs = jnp.array(outputs)

    @property
    def inputs(self):
        """Generated input rows."""
        return self.data_inputs

    @property
    def targets(self):
        """``func`` evaluated on every generated input row."""
        return self.data_outputs

    @property
    def input_shape(self):
        """Shape of the full input array."""
        return self.data_inputs.shape

    @property
    def output_shape(self):
        """Shape of the full target array."""
        return self.data_outputs.shape
def cartesian_product(arr1, arr2):
    """Pair every row of ``arr1`` with every row of ``arr2``.

    Rows of ``arr1`` vary slowest: the result lists row 0 of ``arr1`` joined
    with each row of ``arr2``, then row 1 of ``arr1``, and so on.

    Args:
        arr1: array with ``ndim`` 1 or 2; a 1-D array is treated as a column.
        arr2: array with the same ``ndim`` as ``arr1``.

    Returns:
        2-D array of shape ``(len(arr1) * len(arr2), cols1 + cols2)``.

    Raises:
        AssertionError: if the arrays' ``ndim`` differ or exceed 2.
    """
    assert (
        arr1.ndim == arr2.ndim
    ), "arr1 and arr2 must have the same number of dimensions"
    assert arr1.ndim <= 2, "arr1 and arr2 must have at most 2 dimensions"

    # The ndim check lets 1-D arrays through, but the tile/concatenate below
    # need 2-D operands; promote 1-D inputs to single-column matrices.
    if arr1.ndim == 1:
        arr1 = arr1.reshape(-1, 1)
        arr2 = arr2.reshape(-1, 1)

    len1 = arr1.shape[0]
    len2 = arr2.shape[0]

    repeated_arr1 = np.repeat(arr1, len2, axis=0)  # each row len2 times
    tiled_arr2 = np.tile(arr2, (len1, 1))  # whole block len1 times

    return np.concatenate((repeated_arr1, tiled_arr2), axis=1)

View File

@@ -0,0 +1,72 @@
import jax
import jax.numpy as jnp
from ..base import BaseProblem
from tensorneat.common import State
class FuncFit(BaseProblem):
    """Base class for function-fitting problems.

    Subclasses supply ``inputs``, ``targets`` and the two shape properties;
    this class maps an individual over ``inputs`` and scores the predictions
    against ``targets`` with the configured error metric.
    """

    jitable = True

    def __init__(self, error_method: str = "mse"):
        """Store the error metric; one of "mse", "rmse", "mae", "mape"."""
        super().__init__()

        assert error_method in {"mse", "rmse", "mae", "mape"}
        self.error_method = error_method

    def setup(self, state: State = State()):
        """Return ``state`` unchanged."""
        return state

    def evaluate(self, state, randkey, act_func, params):
        """Return the negated loss of ``act_func`` over the whole dataset."""
        batched_forward = jax.vmap(act_func, in_axes=(None, None, 0))
        predict = batched_forward(state, params, self.inputs)

        # Each reducer receives the prediction error and the targets.
        reducers = {
            "mse": lambda err, tgt: jnp.mean(err**2),
            "rmse": lambda err, tgt: jnp.sqrt(jnp.mean(err**2)),
            "mae": lambda err, tgt: jnp.mean(jnp.abs(err)),
            "mape": lambda err, tgt: jnp.mean(jnp.abs(err / tgt)),
        }
        if self.error_method not in reducers:
            raise NotImplementedError

        loss = reducers[self.error_method](predict - self.targets, self.targets)
        return -loss

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Print input / target / prediction for every sample, then the loss."""
        batched_forward = jax.vmap(act_func, in_axes=(None, None, 0))
        predict = batched_forward(state, params, self.inputs)

        # Pull everything to host memory before formatting.
        inputs, target, predict = jax.device_get(
            [self.inputs, self.targets, predict]
        )
        loss = -self.evaluate(state, randkey, act_func, params)

        lines = [
            f"input: {inputs[i]}, target: {target[i]}, predict: {predict[i]}"
            for i in range(inputs.shape[0])
        ]
        lines.append(f"loss: {loss}")
        print("\n".join(lines) + "\n")

    @property
    def inputs(self):
        """Dataset inputs; provided by subclasses."""
        raise NotImplementedError

    @property
    def targets(self):
        """Dataset targets; provided by subclasses."""
        raise NotImplementedError

    @property
    def input_shape(self):
        """Shape of the inputs; provided by subclasses."""
        raise NotImplementedError

    @property
    def output_shape(self):
        """Shape of the targets; provided by subclasses."""
        raise NotImplementedError

View File

@@ -0,0 +1,27 @@
import numpy as np
from .func_fit import FuncFit
class XOR(FuncFit):
    """Two-input XOR truth table as a function-fitting dataset."""

    @property
    def inputs(self):
        """The four binary input pairs, one per row, as float32."""
        truth_table_rows = [[0, 0], [0, 1], [1, 0], [1, 1]]
        return np.array(truth_table_rows, dtype=np.float32)

    @property
    def targets(self):
        """XOR of each input row, as a float32 column."""
        xor_column = [[0], [1], [1], [0]]
        return np.array(xor_column, dtype=np.float32)

    @property
    def input_shape(self):
        """Shape of ``inputs``: 4 samples of 2 features."""
        return (4, 2)

    @property
    def output_shape(self):
        """Shape of ``targets``: 4 samples of 1 value."""
        return (4, 1)

View File

@@ -0,0 +1,36 @@
import numpy as np
from .func_fit import FuncFit
class XOR3d(FuncFit):
    """Three-input XOR (parity) truth table as a function-fitting dataset."""

    @property
    def inputs(self):
        """All eight binary input triples, one per row, as float32."""
        truth_table_rows = [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ]
        return np.array(truth_table_rows, dtype=np.float32)

    @property
    def targets(self):
        """Parity of each input row, as a float32 column."""
        parity_column = [[0], [1], [1], [0], [1], [0], [0], [1]]
        return np.array(parity_column, dtype=np.float32)

    @property
    def input_shape(self):
        """Shape of ``inputs``: 8 samples of 3 features."""
        return (8, 3)

    @property
    def output_shape(self):
        """Shape of ``targets``: 8 samples of 1 value."""
        return (8, 1)

View File

@@ -0,0 +1,3 @@
from .gymnax import GymNaxEnv
from .brax import BraxEnv
from .rl_jit import RLEnv

View File

@@ -0,0 +1,83 @@
import jax.numpy as jnp
from brax import envs
from .rl_jit import RLEnv
class BraxEnv(RLEnv):
    """RL problem backed by a Brax environment."""

    def __init__(
        self, env_name: str = "ant", backend: str = "generalized", *args, **kwargs
    ):
        """Create the Brax environment ``env_name`` on ``backend``.

        Remaining positional/keyword arguments are forwarded to ``RLEnv``.
        """
        super().__init__(*args, **kwargs)
        self.env_name = env_name
        self.env = envs.create(env_name=env_name, backend=backend)

    def env_step(self, randkey, env_state, action):
        """Advance the environment by one step; ``randkey`` is not used here.

        Returns ``(obs, state, reward, done, info)`` with ``done`` cast to bool.
        """
        state = self.env.step(env_state, action)
        return state.obs, state, state.reward, state.done.astype(jnp.bool_), state.info

    def env_reset(self, randkey):
        """Reset the environment with ``randkey``; returns ``(obs, state)``."""
        init_state = self.env.reset(randkey)
        return init_state.obs, init_state

    @property
    def input_shape(self):
        """One-element tuple holding the environment's observation size."""
        return (self.env.observation_size,)

    @property
    def output_shape(self):
        """One-element tuple holding the environment's action size."""
        return (self.env.action_size,)

    def show(
        self,
        state,
        randkey,
        act_func,
        params,
        save_path=None,
        height=480,
        width=480,
        *args,
        **kwargs,
    ):
        """Roll out one episode, render it and save the frames as a GIF.

        Args:
            state: algorithm state passed to ``act_func``; when observation
                normalisation is enabled it must carry ``problem_obs_mean``
                and ``problem_obs_std``.
            randkey: JAX PRNG key used for the reset and for the rollout.
            act_func: callable ``(state, params, obs) -> action``.
            params: parameters of the individual to display.
            save_path: output file; defaults to ``"<env_name>.gif"``.
            height, width: size of the rendered frames.
            *args, **kwargs: forwarded to ``imageio.mimsave``.
        """
        import jax
        import imageio
        from brax.io import image

        obs, env_state = self.reset(randkey)
        reward, done = 0.0, False
        state_histories = [env_state.pipeline_state]

        def step(key, env_state, obs):
            # Derive fresh sub-keys every step; the carried key is only split,
            # never consumed, so no key is used twice.
            next_key, policy_key, step_key = jax.random.split(key, 3)

            # Feed the policy the same normalised observation that
            # RLEnv._evaluate_once uses when obs_normalization is on.
            policy_obs = obs
            if self.obs_normalization:
                policy_obs = (obs - state.problem_obs_mean) / (
                    state.problem_obs_std + 1e-6
                )

            if self.action_policy is not None:
                forward_func = lambda o: act_func(state, params, o)
                action = self.action_policy(policy_key, forward_func, policy_obs)
            else:
                action = act_func(state, params, policy_obs)

            next_obs, next_env_state, r, done, _ = self.step(
                step_key, env_state, action
            )
            return next_key, next_env_state, next_obs, r, done

        jit_step = jax.jit(step)

        # Thread the key through the loop so every step sees a new one
        # (previously the original randkey was passed on each iteration).
        key = randkey
        for _ in range(self.max_step):
            key, env_state, obs, r, done = jit_step(key, env_state, obs)
            state_histories.append(env_state.pipeline_state)
            reward += r
            if done:
                break

        imgs = image.render_array(
            sys=self.env.sys, trajectory=state_histories, height=height, width=width
        )

        if save_path is None:
            save_path = f"{self.env_name}.gif"

        imageio.mimsave(save_path, imgs, *args, **kwargs)

        print("Gif saved to: ", save_path)
        print("Total reward: ", reward)

View File

@@ -0,0 +1,27 @@
import gymnax
from .rl_jit import RLEnv
class GymNaxEnv(RLEnv):
    """RL problem backed by a gymnax environment."""

    def __init__(self, env_name, *args, **kwargs):
        """Create the gymnax environment ``env_name``.

        Remaining positional/keyword arguments are forwarded to ``RLEnv``.
        """
        super().__init__(*args, **kwargs)
        assert (
            env_name in gymnax.registered_envs
        ), f"Env {env_name} not registered in gymnax."
        env, env_params = gymnax.make(env_name)
        self.env = env
        self.env_params = env_params

    def env_reset(self, randkey):
        """Reset the environment with the stored parameters."""
        return self.env.reset(randkey, self.env_params)

    def env_step(self, randkey, env_state, action):
        """Step the environment with the stored parameters."""
        return self.env.step(randkey, env_state, action, self.env_params)

    @property
    def input_shape(self):
        """Shape of the environment's observation space."""
        observation_space = self.env.observation_space(self.env_params)
        return observation_space.shape

    @property
    def output_shape(self):
        """Shape of the environment's action space."""
        action_space = self.env.action_space(self.env_params)
        return action_space.shape

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Not implemented for gymnax environments."""
        raise NotImplementedError

View File

@@ -0,0 +1,209 @@
from typing import Callable
import jax
from jax import vmap, numpy as jnp
import numpy as np
from ..base import BaseProblem
from tensorneat.common import State
class RLEnv(BaseProblem):
    """Base class for reinforcement-learning problems scored by episode reward.

    Subclasses implement ``env_step``, ``env_reset``, ``input_shape`` and
    ``output_shape``; this class provides the rollout loop, optional
    observation normalisation and averaging over repeated episodes.
    """

    jitable = True

    def __init__(
        self,
        max_step=1000,
        repeat_times=1,
        action_policy: Callable = None,
        obs_normalization: bool = False,
        sample_policy: Callable = None,
        sample_episodes: int = 0,
    ):
        """
        Args:
            max_step: maximum number of steps in one episode.
            repeat_times: number of episodes averaged by ``evaluate``.
            action_policy: optional callable taking three arguments,
                ``(randkey, forward_func, obs)``, and returning an action.
                ``randkey`` is a key for ``jax.random``; ``forward_func`` maps
                an observation to the individual's action
                (``forward_func(obs) -> action``); ``obs`` is the observation.
            obs_normalization: when True, ``setup`` samples episodes to
                estimate the observation mean/std, and ``evaluate``
                normalises observations with them.
            sample_policy: callable ``(randkey, obs) -> action`` used to
                sample those episodes; required when ``obs_normalization``.
            sample_episodes: number of sampled episodes; must be positive
                when ``obs_normalization``.

        Raises:
            AssertionError: if ``obs_normalization`` is set without a
                ``sample_policy`` or with ``sample_episodes <= 0``.
        """
        super().__init__()
        self.max_step = max_step
        self.repeat_times = repeat_times
        self.action_policy = action_policy

        if obs_normalization:
            assert sample_policy is not None, "sample_policy must be provided"
            # The message now names the real parameter (it said "sample_size").
            assert sample_episodes > 0, "sample_episodes must be greater than 0"
        self.sample_policy = sample_policy
        self.sample_episodes = sample_episodes
        self.obs_normalization = obs_normalization

    def setup(self, state=State()):
        """Register observation statistics on ``state`` when normalisation is on.

        Returns the (possibly extended) state; without normalisation the
        state is returned unchanged.
        """
        if self.obs_normalization:
            print("Sampling episodes for normalization")
            keys = jax.random.split(state.randkey, self.sample_episodes)

            # Receives state, params, obs and returns the original obs.
            dummy_act_func = lambda s, p, o: o

            # Adapts sample_policy to the action_policy signature; the
            # forward function is ignored.
            dummy_sample_func = lambda rk, act_func, obs: self.sample_policy(
                rk, obs
            )

            def sample(rk):
                return self._evaluate_once(
                    state, rk, dummy_act_func, None, dummy_sample_func, True
                )

            _, episodes = jax.jit(vmap(sample))(keys)

            # shape: (sample_episodes, max_step, *input_shape)
            obs = jax.device_get(episodes["obs"])
            # shape: (sample_episodes * max_step, *input_shape)
            obs = obs.reshape(-1, *self.input_shape)

            # Steps after an episode terminated keep their NaN fill value;
            # keep only rows that are entirely NaN-free. `obs` lives on the
            # host at this point, so plain NumPy is sufficient.
            obs_axis = tuple(range(obs.ndim))
            valid_data_flag = np.all(~np.isnan(obs), axis=obs_axis[1:])
            obs = obs[valid_data_flag]

            obs_mean = np.mean(obs, axis=0)
            obs_std = np.std(obs, axis=0)

            state = state.register(
                problem_obs_mean=obs_mean,
                problem_obs_std=obs_std,
            )

            print("Sampling episodes for normalization finished.")
            print("valid data count: ", obs.shape[0])
            print("obs_mean: ", obs_mean)
            print("obs_std: ", obs_std)

        return state

    def evaluate(self, state: State, randkey, act_func: Callable, params):
        """Return the mean total reward over ``repeat_times`` episodes."""
        keys = jax.random.split(randkey, self.repeat_times)
        # Only the per-episode key is mapped; everything else is shared.
        rewards = vmap(
            self._evaluate_once, in_axes=(None, 0, None, None, None, None, None)
        )(
            state,
            keys,
            act_func,
            params,
            self.action_policy,
            False,
            self.obs_normalization,
        )
        return rewards.mean()

    def _evaluate_once(
        self,
        state,
        randkey,
        act_func,
        params,
        action_policy,
        record_episode,
        normalize_obs=False,
    ):
        """Run one episode and return its total reward.

        Args:
            state: algorithm state handed to ``act_func`` (and to
                ``norm_obs`` when ``normalize_obs``).
            randkey: JAX PRNG key for this episode.
            act_func: callable ``(state, params, obs) -> action``.
            params: parameters of the individual.
            action_policy: optional ``(randkey, forward_func, obs) -> action``
                wrapper around the individual's forward pass.
            record_episode: when True, also return a dict with the per-step
                ``"obs"``, ``"action"`` and ``"reward"`` arrays; entries for
                steps that were never reached stay NaN.
            normalize_obs: when True, observations are passed through
                ``norm_obs`` before reaching the policy.

        Returns:
            ``total_reward``, or ``(total_reward, episode)`` when
            ``record_episode`` is True.
        """
        rng_reset, rng_episode = jax.random.split(randkey)
        # Give the action policy its own key stream instead of re-using the
        # already-split parent key.
        rng_episode, rng_policy = jax.random.split(rng_episode)
        init_obs, init_env_state = self.reset(rng_reset)

        if record_episode:
            episode = {
                "obs": jnp.full((self.max_step, *self.input_shape), jnp.nan),
                "action": jnp.full((self.max_step, *self.output_shape), jnp.nan),
                "reward": jnp.full((self.max_step,), jnp.nan),
            }
        else:
            episode = None

        def cond_func(carry):
            done, count = carry[3], carry[5]
            return ~done & (count < self.max_step)

        def body_func(carry):
            # tr -> total reward; rng -> env-step key stream;
            # rk -> action-policy key stream
            obs, env_state, rng, done, tr, count, epis, rk = carry

            if normalize_obs:
                obs = norm_obs(state, obs)

            # Carried keys are only ever split; the sub-keys are consumed.
            next_rng, step_key = jax.random.split(rng)
            next_rk, policy_key = jax.random.split(rk)

            if action_policy is not None:
                forward_func = lambda obs: act_func(state, params, obs)
                action = action_policy(policy_key, forward_func, obs)
            else:
                action = act_func(state, params, obs)

            next_obs, next_env_state, reward, done, _ = self.step(
                step_key, env_state, action
            )

            if record_episode:
                # Build a new dict rather than mutating the loop carry.
                epis = {
                    "obs": epis["obs"].at[count].set(obs),
                    "action": epis["action"].at[count].set(action),
                    "reward": epis["reward"].at[count].set(reward),
                }

            return (
                next_obs,
                next_env_state,
                next_rng,
                done,
                tr + reward,
                count + 1,
                epis,
                next_rk,
            )

        _, _, _, _, total_reward, _, episode, _ = jax.lax.while_loop(
            cond_func,
            body_func,
            (init_obs, init_env_state, rng_episode, False, 0.0, 0, episode, rng_policy),
        )

        if record_episode:
            return total_reward, episode
        return total_reward

    def step(self, randkey, env_state, action):
        """Delegate to ``env_step``."""
        return self.env_step(randkey, env_state, action)

    def reset(self, randkey):
        """Delegate to ``env_reset``."""
        return self.env_reset(randkey)

    def env_step(self, randkey, env_state, action):
        """Advance the environment one step; provided by subclasses.

        The rollout loop unpacks the result as
        ``(obs, env_state, reward, done, info)``.
        """
        raise NotImplementedError

    def env_reset(self, randkey):
        """Reset the environment; provided by subclasses.

        The rollout loop unpacks the result as ``(obs, env_state)``.
        """
        raise NotImplementedError

    @property
    def input_shape(self):
        """Shape of a single observation; provided by subclasses."""
        raise NotImplementedError

    @property
    def output_shape(self):
        """Shape of a single action; provided by subclasses."""
        raise NotImplementedError

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Visualise an individual; provided by subclasses."""
        raise NotImplementedError
def norm_obs(state, obs):
    """Standardise ``obs`` with the statistics stored on ``state``.

    Reads ``state.problem_obs_mean`` and ``state.problem_obs_std``; a small
    constant is added to the std so the division never hits zero.
    """
    centered = obs - state.problem_obs_mean
    scale = state.problem_obs_std + 1e-6
    return centered / scale