add License and pyproject.toml
This commit is contained in:
3
src/tensorneat/problem/__init__.py
Normal file
3
src/tensorneat/problem/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseProblem
|
||||
from .rl import *
|
||||
from .func_fit import *
|
||||
35
src/tensorneat/problem/base.py
Normal file
35
src/tensorneat/problem/base.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Callable
|
||||
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
class BaseProblem(StatefulBaseClass):
    """Abstract base class for every problem (task) a genome can be scored on.

    Subclasses implement ``evaluate`` to return a scalar fitness, and declare
    their I/O shapes via ``input_shape`` / ``output_shape``.
    """

    # Set by subclasses: True when `evaluate` is safe to wrap in `jax.jit`.
    jitable = None

    def evaluate(self, state: State, randkey, act_func: Callable, params):
        """Score a single individual and return its fitness."""
        raise NotImplementedError

    @property
    def input_shape(self):
        """Shape of the inputs this problem feeds to a network.

        For RL problems this is the observation space; for function-fitting
        problems it is the input shape of the function.
        """
        raise NotImplementedError

    @property
    def output_shape(self):
        """Shape of the outputs this problem expects from a network.

        For RL problems this is the action space; for function-fitting
        problems it is the output shape of the function.
        """
        raise NotImplementedError

    def show(self, state: State, randkey, act_func: Callable, params, *args, **kwargs):
        """Visualize or print how one genome performs on this problem."""
        raise NotImplementedError
|
||||
4
src/tensorneat/problem/func_fit/__init__.py
Normal file
4
src/tensorneat/problem/func_fit/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .xor import XOR
|
||||
from .xor3d import XOR3d
|
||||
from .custom import CustomFuncFit
|
||||
from .func_fit import FuncFit
|
||||
117
src/tensorneat/problem/func_fit/custom.py
Normal file
117
src/tensorneat/problem/func_fit/custom.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import Callable, Union, List, Tuple
|
||||
from jax import vmap, Array, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .func_fit import FuncFit
|
||||
|
||||
|
||||
class CustomFuncFit(FuncFit):
    """Function-fitting problem for an arbitrary user-supplied function.

    A dataset is generated from ``func`` over the box
    ``[low_bounds, upper_bounds]``, either by uniform random sampling
    (``method="sample"``) or on a regular lattice (``method="grid"``).
    """

    def __init__(
        self,
        func: Callable,
        low_bounds: Union[List, Tuple, Array],
        upper_bounds: Union[List, Tuple, Array],
        method: str = "sample",
        num_samples: int = 100,
        step_size: Array = None,
        *args,
        **kwargs,
    ):
        """
        Args:
            func: callable mapping an input vector to an output vector/scalar.
            low_bounds: per-dimension lower bounds of the input box.
            upper_bounds: per-dimension upper bounds; same shape as low_bounds.
            method: "sample" (uniform random) or "grid" (regular lattice).
            num_samples: number of random samples when method == "sample".
            step_size: per-dimension grid step when method == "grid".

        Raises:
            ValueError: if func cannot be called on low_bounds.
        """
        # Normalize plain Python sequences to float32 numpy arrays.
        if isinstance(low_bounds, (list, tuple)):
            low_bounds = np.array(low_bounds, dtype=np.float32)
        if isinstance(upper_bounds, (list, tuple)):
            upper_bounds = np.array(upper_bounds, dtype=np.float32)

        # Validate that the bounds agree before probing func with them.
        assert low_bounds.shape == upper_bounds.shape

        # Smoke-test that func accepts an input of this shape; chain the
        # original exception (`from e`) so the root cause is preserved.
        try:
            func(low_bounds)
        except Exception as e:
            raise ValueError(f"func(low_bounds) raise an exception: {e}") from e

        assert method in {"sample", "grid"}

        self.func = func
        self.low_bounds = low_bounds
        self.upper_bounds = upper_bounds

        self.method = method
        self.num_samples = num_samples
        self.step_size = step_size

        self.generate_dataset()

        super().__init__(*args, **kwargs)

    def generate_dataset(self):
        """Build (data_inputs, data_outputs) according to the chosen method."""
        if self.method == "sample":
            assert (
                self.num_samples > 0
            ), f"num_samples must be positive, got {self.num_samples}"

            # Draw one uniform column per input dimension.
            inputs = np.zeros(
                (self.num_samples, self.low_bounds.shape[0]), dtype=np.float32
            )
            for i in range(self.low_bounds.shape[0]):
                inputs[:, i] = np.random.uniform(
                    low=self.low_bounds[i],
                    high=self.upper_bounds[i],
                    size=(self.num_samples,),
                )
        elif self.method == "grid":
            assert (
                self.step_size is not None
            ), "step_size must be provided when method is 'grid'"
            assert (
                self.step_size.shape == self.low_bounds.shape
            ), "step_size must have the same shape as low_bounds"
            assert np.all(self.step_size > 0), "step_size must be positive"

            # Build the lattice one dimension at a time via cartesian
            # products; the seed column of zeros is stripped afterwards.
            inputs = np.zeros((1, 1))
            for i in range(self.low_bounds.shape[0]):
                new_col = np.arange(
                    self.low_bounds[i], self.upper_bounds[i], self.step_size[i]
                )
                inputs = cartesian_product(inputs, new_col[:, None])
            inputs = inputs[:, 1:]
        else:
            raise ValueError(f"Unknown method: {self.method}")

        outputs = vmap(self.func)(inputs)

        self.data_inputs = jnp.array(inputs)
        self.data_outputs = jnp.array(outputs)

    @property
    def inputs(self):
        return self.data_inputs

    @property
    def targets(self):
        return self.data_outputs

    @property
    def input_shape(self):
        return self.data_inputs.shape

    @property
    def output_shape(self):
        return self.data_outputs.shape
|
||||
|
||||
|
||||
def cartesian_product(arr1, arr2):
    """Return the row-wise Cartesian product of two arrays.

    Every row of ``arr1`` is paired with every row of ``arr2``; paired rows
    are concatenated along the column axis. Result has
    ``arr1.shape[0] * arr2.shape[0]`` rows.
    """
    assert (
        arr1.ndim == arr2.ndim
    ), "arr1 and arr2 must have the same number of dimensions"
    assert arr1.ndim <= 2, "arr1 and arr2 must have at most 2 dimensions"

    n_left, n_right = arr1.shape[0], arr2.shape[0]

    # Each left row appears n_right times in a run; the whole right block
    # repeats n_left times — together enumerating every (left, right) pair.
    left_part = np.repeat(arr1, n_right, axis=0)
    right_part = np.tile(arr2, (n_left, 1))

    return np.concatenate((left_part, right_part), axis=1)
|
||||
72
src/tensorneat/problem/func_fit/func_fit.py
Normal file
72
src/tensorneat/problem/func_fit/func_fit.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..base import BaseProblem
|
||||
from tensorneat.common import State
|
||||
|
||||
|
||||
class FuncFit(BaseProblem):
    """Base class for function-fitting problems.

    Fitness is the negated regression error between the network's predictions
    on ``inputs`` and the ground-truth ``targets``.
    """

    # Pure array computation — safe to jit.
    jitable = True

    def __init__(self, error_method: str = "mse"):
        """
        Args:
            error_method: one of "mse", "rmse", "mae", "mape".
        """
        super().__init__()

        assert error_method in {"mse", "rmse", "mae", "mape"}
        self.error_method = error_method

    def setup(self, state: State = State()):
        """No extra state is needed for function fitting."""
        return state

    def evaluate(self, state, randkey, act_func, params):
        """Return the negated error of the network over the whole dataset."""
        # Batch the activation function over all dataset rows at once.
        batch_act = jax.vmap(act_func, in_axes=(None, None, 0))
        predict = batch_act(state, params, self.inputs)

        err = predict - self.targets
        if self.error_method == "mse":
            loss = jnp.mean(err**2)
        elif self.error_method == "rmse":
            loss = jnp.sqrt(jnp.mean(err**2))
        elif self.error_method == "mae":
            loss = jnp.mean(jnp.abs(err))
        elif self.error_method == "mape":
            # NOTE: divides by targets — yields inf/nan for zero targets.
            loss = jnp.mean(jnp.abs(err / self.targets))
        else:
            raise NotImplementedError

        # Higher fitness is better, so negate the loss.
        return -loss

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Print every (input, target, prediction) triple and the total loss."""
        batch_act = jax.vmap(act_func, in_axes=(None, None, 0))
        predict = batch_act(state, params, self.inputs)
        inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
        fitness = self.evaluate(state, randkey, act_func, params)

        loss = -fitness

        lines = [
            f"input: {i}, target: {t}, predict: {p}"
            for i, t, p in zip(inputs, target, predict)
        ]
        lines.append(f"loss: {loss}\n")
        print("\n".join(lines))

    @property
    def inputs(self):
        raise NotImplementedError

    @property
    def targets(self):
        raise NotImplementedError

    @property
    def input_shape(self):
        raise NotImplementedError

    @property
    def output_shape(self):
        raise NotImplementedError
|
||||
27
src/tensorneat/problem/func_fit/xor.py
Normal file
27
src/tensorneat/problem/func_fit/xor.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
|
||||
from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR(FuncFit):
    """Classic 2-input XOR truth-table fitting problem (4 samples)."""

    @property
    def inputs(self):
        # All four binary input combinations, in ascending binary order.
        table = [[0, 0], [0, 1], [1, 0], [1, 1]]
        return np.array(table, dtype=np.float32)

    @property
    def targets(self):
        # XOR: output is 1 exactly when the two inputs differ.
        return np.array([[0], [1], [1], [0]], dtype=np.float32)

    @property
    def input_shape(self):
        return (4, 2)

    @property
    def output_shape(self):
        return (4, 1)
|
||||
36
src/tensorneat/problem/func_fit/xor3d.py
Normal file
36
src/tensorneat/problem/func_fit/xor3d.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import numpy as np
|
||||
|
||||
from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR3d(FuncFit):
    """3-input XOR (odd-parity) truth-table fitting problem (8 samples)."""

    @property
    def inputs(self):
        # All eight binary combinations of three inputs, ascending binary order.
        table = [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ]
        return np.array(table, dtype=np.float32)

    @property
    def targets(self):
        # Odd parity: output is 1 when an odd number of inputs are 1.
        return np.array(
            [[0], [1], [1], [0], [1], [0], [0], [1]],
            dtype=np.float32,
        )

    @property
    def input_shape(self):
        return (8, 3)

    @property
    def output_shape(self):
        return (8, 1)
|
||||
3
src/tensorneat/problem/rl/__init__.py
Normal file
3
src/tensorneat/problem/rl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .gymnax import GymNaxEnv
|
||||
from .brax import BraxEnv
|
||||
from .rl_jit import RLEnv
|
||||
83
src/tensorneat/problem/rl/brax.py
Normal file
83
src/tensorneat/problem/rl/brax.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import jax.numpy as jnp
|
||||
from brax import envs
|
||||
|
||||
from .rl_jit import RLEnv
|
||||
|
||||
|
||||
class BraxEnv(RLEnv):
    """RL problem backed by a Brax physics environment."""

    def __init__(
        self, env_name: str = "ant", backend: str = "generalized", *args, **kwargs
    ):
        """
        Args:
            env_name: Brax environment name (e.g. "ant", "halfcheetah").
            backend: Brax physics backend (e.g. "generalized", "spring").
        """
        super().__init__(*args, **kwargs)
        self.env_name = env_name
        self.env = envs.create(env_name=env_name, backend=backend)

    def env_step(self, randkey, env_state, action):
        # Brax steps are deterministic given the state; randkey is unused here.
        state = self.env.step(env_state, action)
        return state.obs, state, state.reward, state.done.astype(jnp.bool_), state.info

    def env_reset(self, randkey):
        init_state = self.env.reset(randkey)
        return init_state.obs, init_state

    @property
    def input_shape(self):
        return (self.env.observation_size,)

    @property
    def output_shape(self):
        return (self.env.action_size,)

    def show(
        self,
        state,
        randkey,
        act_func,
        params,
        save_path=None,
        height=480,
        width=480,
        *args,
        **kwargs,
    ):
        """Roll out one episode, render it, and save it as a gif.

        Args:
            save_path: output gif path; defaults to "<env_name>.gif".
            height, width: render resolution in pixels.
            Extra *args/**kwargs are forwarded to imageio.mimsave.
        """
        # Local imports: rendering deps are only needed for visualization.
        import jax
        import imageio
        from brax.io import image

        obs, env_state = self.reset(randkey)
        reward, done = 0.0, False
        state_histories = [env_state.pipeline_state]

        def step(key, env_state, obs):
            # BUG FIX: split and thread the PRNG key through the rollout.
            # The original passed the initial randkey into every step and
            # discarded the returned key, so the key never advanced.
            key, step_key = jax.random.split(key)

            if self.action_policy is not None:
                forward_func = lambda obs: act_func(state, params, obs)
                action = self.action_policy(step_key, forward_func, obs)
            else:
                action = act_func(state, params, obs)

            obs, env_state, r, done, _ = self.step(step_key, env_state, action)
            return key, env_state, obs, r, done

        jit_step = jax.jit(step)

        key = randkey
        for _ in range(self.max_step):
            key, env_state, obs, r, done = jit_step(key, env_state, obs)
            state_histories.append(env_state.pipeline_state)
            reward += r
            if done:
                break

        imgs = image.render_array(
            sys=self.env.sys, trajectory=state_histories, height=height, width=width
        )

        if save_path is None:
            save_path = f"{self.env_name}.gif"

        imageio.mimsave(save_path, imgs, *args, **kwargs)

        print("Gif saved to: ", save_path)
        print("Total reward: ", reward)
|
||||
27
src/tensorneat/problem/rl/gymnax.py
Normal file
27
src/tensorneat/problem/rl/gymnax.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import gymnax
|
||||
|
||||
from .rl_jit import RLEnv
|
||||
|
||||
|
||||
class GymNaxEnv(RLEnv):
    """RL problem backed by a gymnax (jit-compatible gym) environment."""

    def __init__(self, env_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fail fast on unknown environment names before constructing anything.
        assert (
            env_name in gymnax.registered_envs
        ), f"Env {env_name} not registered in gymnax."
        self.env, self.env_params = gymnax.make(env_name)

    def env_step(self, randkey, env_state, action):
        # gymnax returns (obs, env_state, reward, done, info).
        return self.env.step(randkey, env_state, action, self.env_params)

    def env_reset(self, randkey):
        # Returns (obs, env_state).
        return self.env.reset(randkey, self.env_params)

    @property
    def input_shape(self):
        return self.env.observation_space(self.env_params).shape

    @property
    def output_shape(self):
        return self.env.action_space(self.env_params).shape

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Rendering is not implemented for gymnax environments."""
        raise NotImplementedError
|
||||
209
src/tensorneat/problem/rl/rl_jit.py
Normal file
209
src/tensorneat/problem/rl/rl_jit.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from ..base import BaseProblem
|
||||
from tensorneat.common import State
|
||||
|
||||
|
||||
class RLEnv(BaseProblem):
    """Base class for jit-compatible reinforcement-learning problems.

    Subclasses implement `env_step` / `env_reset` (and the shape properties);
    this class provides the jitted episode rollout and optional observation
    normalization via sampled statistics.
    """

    # Rollout uses jax.lax.while_loop, so evaluate() is jit-safe.
    jitable = True

    def __init__(
        self,
        max_step=1000,
        repeat_times=1,
        action_policy: Callable = None,
        obs_normalization: bool = False,
        sample_policy: Callable = None,
        sample_episodes: int = 0,
    ):
        """
        action_policy take three args:
            randkey, forward_func, obs
            randkey is a random key for jax.random
            forward_func is a function which receive obs and return action forward_func(obs) - > action
            obs is the observation of the environment

        sample_policy take two args:
            randkey, obs -> action
        """

        super().__init__()
        self.max_step = max_step
        self.repeat_times = repeat_times
        self.action_policy = action_policy

        # sample_policy/sample_episodes are only stored (and required) when
        # observation normalization is enabled.
        if obs_normalization:
            assert sample_policy is not None, "sample_policy must be provided"
            assert sample_episodes > 0, "sample_size must be greater than 0"
            self.sample_policy = sample_policy
            self.sample_episodes = sample_episodes
        self.obs_normalization = obs_normalization

    # NOTE(review): mutable default `State()` — assumed safe because State
    # appears to be used immutably (state.register returns a new state);
    # confirm against tensorneat.common.State.
    def setup(self, state=State()):
        """Optionally sample episodes to estimate obs mean/std for normalization.

        Returns the (possibly extended) state; registers problem_obs_mean and
        problem_obs_std when obs_normalization is enabled.
        """
        if self.obs_normalization:
            print("Sampling episodes for normalization")
            keys = jax.random.split(state.randkey, self.sample_episodes)
            dummy_act_func = (
                lambda s, p, o: o
            )  # receive state, params, obs and return the original obs
            dummy_sample_func = lambda rk, act_func, obs: self.sample_policy(
                rk, obs
            )  # ignore act_func

            # Run one recorded episode per key using the sampling policy.
            def sample(rk):
                return self._evaluate_once(
                    state, rk, dummy_act_func, None, dummy_sample_func, True
                )

            rewards, episodes = jax.jit(vmap(sample))(keys)

            obs = jax.device_get(episodes["obs"])  # shape: (sample_episodes, max_step, *input_shape)
            obs = obs.reshape(
                -1, *self.input_shape
            )  # shape: (sample_episodes * max_step, *input_shape)

            # Unvisited steps were pre-filled with NaN; keep only real steps.
            obs_axis = tuple(range(obs.ndim))
            valid_data_flag = np.all(~jnp.isnan(obs), axis=obs_axis[1:])
            obs = obs[valid_data_flag]

            obs_mean = np.mean(obs, axis=0)
            obs_std = np.std(obs, axis=0)

            state = state.register(
                problem_obs_mean=obs_mean,
                problem_obs_std=obs_std,
            )

            print("Sampling episodes for normalization finished.")
            print("valid data count: ", obs.shape[0])
            print("obs_mean: ", obs_mean)
            print("obs_std: ", obs_std)
        return state

    def evaluate(self, state: State, randkey, act_func: Callable, params):
        """Average episode reward over `repeat_times` independent rollouts."""
        keys = jax.random.split(randkey, self.repeat_times)
        # Vectorize over the episode keys only (axis 0 of `keys`).
        rewards = vmap(
            self._evaluate_once, in_axes=(None, 0, None, None, None, None, None)
        )(
            state,
            keys,
            act_func,
            params,
            self.action_policy,
            False,
            self.obs_normalization,
        )

        return rewards.mean()

    def _evaluate_once(
        self,
        state,
        randkey,
        act_func,
        params,
        action_policy,
        record_episode,
        normalize_obs=False,
    ):
        """Run a single episode; returns total reward (and the episode arrays
        when record_episode is True)."""
        rng_reset, rng_episode = jax.random.split(randkey)
        init_obs, init_env_state = self.reset(rng_reset)

        if record_episode:
            # Pre-allocate NaN-filled arrays; NaN marks steps never taken,
            # which setup() later filters out.
            obs_array = jnp.full((self.max_step, *self.input_shape), jnp.nan)
            action_array = jnp.full((self.max_step, *self.output_shape), jnp.nan)
            reward_array = jnp.full((self.max_step,), jnp.nan)
            episode = {
                "obs": obs_array,
                "action": action_array,
                "reward": reward_array,
            }
        else:
            episode = None

        # Loop until the env signals done or max_step is reached.
        def cond_func(carry):
            _, _, _, done, _, count, _, rk = carry
            return ~done & (count < self.max_step)

        def body_func(carry):
            (
                obs,
                env_state,
                rng,
                done,
                tr,
                count,
                epis,
                rk,
            ) = carry  # tr -> total reward; rk -> randkey

            # Python-level branches are fine: record_episode/normalize_obs/
            # action_policy are static at trace time.
            if normalize_obs:
                obs = norm_obs(state, obs)

            if action_policy is not None:
                forward_func = lambda obs: act_func(state, params, obs)
                action = action_policy(rk, forward_func, obs)
            else:
                action = act_func(state, params, obs)
            next_obs, next_env_state, reward, done, _ = self.step(
                rng, env_state, action
            )
            next_rng, _ = jax.random.split(rng)

            if record_episode:
                epis["obs"] = epis["obs"].at[count].set(obs)
                epis["action"] = epis["action"].at[count].set(action)
                epis["reward"] = epis["reward"].at[count].set(reward)

            return (
                next_obs,
                next_env_state,
                next_rng,
                done,
                tr + reward,
                count + 1,
                epis,
                jax.random.split(rk)[0],  # advance the action-policy key
            )

        _, _, _, _, total_reward, _, episode, _ = jax.lax.while_loop(
            cond_func,
            body_func,
            (init_obs, init_env_state, rng_episode, False, 0.0, 0, episode, randkey),
        )

        if record_episode:
            return total_reward, episode
        else:
            return total_reward

    def step(self, randkey, env_state, action):
        """Advance the environment one step; delegates to env_step."""
        return self.env_step(randkey, env_state, action)

    def reset(self, randkey):
        """Reset the environment; delegates to env_reset."""
        return self.env_reset(randkey)

    def env_step(self, randkey, env_state, action):
        # Subclass hook: (randkey, env_state, action) -> (obs, env_state, reward, done, info).
        raise NotImplementedError

    def env_reset(self, randkey):
        # Subclass hook: randkey -> (obs, env_state).
        raise NotImplementedError

    @property
    def input_shape(self):
        raise NotImplementedError

    @property
    def output_shape(self):
        raise NotImplementedError

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        raise NotImplementedError
|
||||
|
||||
|
||||
def norm_obs(state, obs):
    """Standardize an observation using the mean/std stored in the state.

    The small epsilon guards against division by zero for constant features.
    """
    centered = obs - state.problem_obs_mean
    return centered / (state.problem_obs_std + 1e-6)
|
||||
Reference in New Issue
Block a user