change repo structure; modify readme
This commit is contained in:
1
tensorneat/problem/__init__.py
Normal file
1
tensorneat/problem/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .base import BaseProblem
|
||||
39
tensorneat/problem/base.py
Normal file
39
tensorneat/problem/base.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Callable
|
||||
|
||||
from utils import State
|
||||
|
||||
|
||||
class BaseProblem:
|
||||
jitable = None
|
||||
|
||||
def setup(self, randkey, state: State = State()):
|
||||
"""initialize the state of the problem"""
|
||||
pass
|
||||
|
||||
def evaluate(self, randkey, state: State, act_func: Callable, params):
|
||||
"""evaluate one individual"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
"""
|
||||
The input shape for the problem to evaluate
|
||||
In RL problem, it is the observation space
|
||||
In function fitting problem, it is the input shape of the function
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
"""
|
||||
The output shape for the problem to evaluate
|
||||
In RL problem, it is the action space
|
||||
In function fitting problem, it is the output shape of the function
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
|
||||
"""
|
||||
show how a genome perform in this problem
|
||||
"""
|
||||
raise NotImplementedError
|
||||
3
tensorneat/problem/func_fit/__init__.py
Normal file
3
tensorneat/problem/func_fit/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .func_fit import FuncFit
|
||||
from .xor import XOR
|
||||
from .xor3d import XOR3d
|
||||
67
tensorneat/problem/func_fit/func_fit.py
Normal file
67
tensorneat/problem/func_fit/func_fit.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from utils import State
|
||||
from .. import BaseProblem
|
||||
|
||||
|
||||
class FuncFit(BaseProblem):
|
||||
jitable = True
|
||||
|
||||
def __init__(self,
|
||||
error_method: str = 'mse'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
|
||||
self.error_method = error_method
|
||||
|
||||
def setup(self, randkey, state: State = State()):
|
||||
return state
|
||||
|
||||
def evaluate(self, randkey, state, act_func, params):
|
||||
|
||||
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
|
||||
|
||||
if self.error_method == 'mse':
|
||||
loss = jnp.mean((predict - self.targets) ** 2)
|
||||
|
||||
elif self.error_method == 'rmse':
|
||||
loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2))
|
||||
|
||||
elif self.error_method == 'mae':
|
||||
loss = jnp.mean(jnp.abs(predict - self.targets))
|
||||
|
||||
elif self.error_method == 'mape':
|
||||
loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets))
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return -loss
|
||||
|
||||
def show(self, randkey, state, act_func, params, *args, **kwargs):
|
||||
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
|
||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||
loss = -self.evaluate(randkey, state, act_func, params)
|
||||
msg = ""
|
||||
for i in range(inputs.shape[0]):
|
||||
msg += f"input: {inputs[i]}, target: {target[i]}, predict: {predict[i]}\n"
|
||||
msg += f"loss: {loss}\n"
|
||||
print(msg)
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
raise NotImplementedError
|
||||
35
tensorneat/problem/func_fit/xor.py
Normal file
35
tensorneat/problem/func_fit/xor.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import numpy as np
|
||||
|
||||
from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR(FuncFit):
|
||||
|
||||
def __init__(self, error_method: str = 'mse'):
|
||||
super().__init__(error_method)
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return np.array([
|
||||
[0, 0],
|
||||
[0, 1],
|
||||
[1, 0],
|
||||
[1, 1]
|
||||
])
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return np.array([
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[0]
|
||||
])
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return 4, 2
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return 4, 1
|
||||
43
tensorneat/problem/func_fit/xor3d.py
Normal file
43
tensorneat/problem/func_fit/xor3d.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import numpy as np
|
||||
|
||||
from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR3d(FuncFit):
|
||||
|
||||
def __init__(self, error_method: str = 'mse'):
|
||||
super().__init__(error_method)
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return np.array([
|
||||
[0, 0, 0],
|
||||
[0, 0, 1],
|
||||
[0, 1, 0],
|
||||
[0, 1, 1],
|
||||
[1, 0, 0],
|
||||
[1, 0, 1],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1],
|
||||
])
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return np.array([
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[0],
|
||||
[1],
|
||||
[0],
|
||||
[0],
|
||||
[1]
|
||||
])
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return 8, 3
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return 8, 1
|
||||
2
tensorneat/problem/rl_env/__init__.py
Normal file
2
tensorneat/problem/rl_env/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .gymnax_env import GymNaxEnv
|
||||
from .brax_env import BraxEnv
|
||||
64
tensorneat/problem/rl_env/brax_env.py
Normal file
64
tensorneat/problem/rl_env/brax_env.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import jax.numpy as jnp
|
||||
from brax import envs
|
||||
|
||||
from .rl_jit import RLEnv
|
||||
|
||||
|
||||
class BraxEnv(RLEnv):
|
||||
def __init__(self, env_name: str = "ant", backend: str = "generalized"):
|
||||
super().__init__()
|
||||
self.env = envs.create(env_name=env_name, backend=backend)
|
||||
|
||||
def env_step(self, randkey, env_state, action):
|
||||
state = self.env.step(env_state, action)
|
||||
return state.obs, state, state.reward, state.done.astype(jnp.bool_), state.info
|
||||
|
||||
def env_reset(self, randkey):
|
||||
init_state = self.env.reset(randkey)
|
||||
return init_state.obs, init_state
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (self.env.observation_size,)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (self.env.action_size,)
|
||||
|
||||
def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs):
|
||||
|
||||
import jax
|
||||
import imageio
|
||||
import numpy as np
|
||||
from brax.io import image
|
||||
from tqdm import tqdm
|
||||
|
||||
obs, env_state = self.reset(randkey)
|
||||
reward, done = 0.0, False
|
||||
state_histories = []
|
||||
|
||||
def step(key, env_state, obs):
|
||||
key, _ = jax.random.split(key)
|
||||
action = act_func(state, obs, params)
|
||||
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
|
||||
return key, env_state, obs, r, done
|
||||
|
||||
while not done:
|
||||
state_histories.append(env_state.pipeline_state)
|
||||
key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs)
|
||||
reward += r
|
||||
|
||||
imgs = [image.render_array(sys=self.env.sys, state=s, width=width, height=height) for s in
|
||||
tqdm(state_histories, desc="Rendering")]
|
||||
|
||||
def create_gif(image_list, gif_name, duration):
|
||||
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
|
||||
for image in image_list:
|
||||
formatted_image = np.array(image, dtype=np.uint8)
|
||||
writer.append_data(formatted_image)
|
||||
|
||||
create_gif(imgs, save_path, duration=0.1)
|
||||
print("Gif saved to: ", save_path)
|
||||
print("Total reward: ", reward)
|
||||
|
||||
|
||||
28
tensorneat/problem/rl_env/gymnax_env.py
Normal file
28
tensorneat/problem/rl_env/gymnax_env.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import gymnax
|
||||
|
||||
from .rl_jit import RLEnv
|
||||
|
||||
|
||||
class GymNaxEnv(RLEnv):
|
||||
|
||||
def __init__(self, env_name):
|
||||
super().__init__()
|
||||
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
|
||||
self.env, self.env_params = gymnax.make(env_name)
|
||||
|
||||
def env_step(self, randkey, env_state, action):
|
||||
return self.env.step(randkey, env_state, action, self.env_params)
|
||||
|
||||
def env_reset(self, randkey):
|
||||
return self.env.reset(randkey, self.env_params)
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return self.env.observation_space(self.env_params).shape
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return self.env.action_space(self.env_params).shape
|
||||
|
||||
def show(self, randkey, state, act_func, params, *args, **kwargs):
|
||||
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
|
||||
61
tensorneat/problem/rl_env/rl_jit.py
Normal file
61
tensorneat/problem/rl_env/rl_jit.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
|
||||
from .. import BaseProblem
|
||||
|
||||
|
||||
class RLEnv(BaseProblem):
|
||||
jitable = True
|
||||
|
||||
# TODO: move output transform to algorithm
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def evaluate(self, randkey, state, act_func, params):
|
||||
rng_reset, rng_episode = jax.random.split(randkey)
|
||||
init_obs, init_env_state = self.reset(rng_reset)
|
||||
|
||||
def cond_func(carry):
|
||||
_, _, _, done, _ = carry
|
||||
return ~done
|
||||
|
||||
def body_func(carry):
|
||||
obs, env_state, rng, _, tr = carry # total reward
|
||||
action = act_func(obs, params)
|
||||
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||
next_rng, _ = jax.random.split(rng)
|
||||
return next_obs, next_env_state, next_rng, done, tr + reward
|
||||
|
||||
_, _, _, _, total_reward = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(init_obs, init_env_state, rng_episode, False, 0.0)
|
||||
)
|
||||
|
||||
return total_reward
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def step(self, randkey, env_state, action):
|
||||
return self.env_step(randkey, env_state, action)
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def reset(self, randkey):
|
||||
return self.env_reset(randkey)
|
||||
|
||||
def env_step(self, randkey, env_state, action):
|
||||
raise NotImplementedError
|
||||
|
||||
def env_reset(self, randkey):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def show(self, randkey, state, act_func, params, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
Reference in New Issue
Block a user