From cef27b56bbb4789084f67c04790c45f1d50d3d9b Mon Sep 17 00:00:00 2001 From: root Date: Thu, 11 Jul 2024 19:34:12 +0800 Subject: [PATCH] update problem and pipeline --- examples/brax/walker2d.py | 2 +- examples/func_fit/custom_func_fit.py | 4 +- tensorneat/pipeline.py | 77 +------------------ tensorneat/problem/__init__.py | 2 + tensorneat/problem/func_fit/__init__.py | 4 +- tensorneat/problem/func_fit/custom.py | 4 +- tensorneat/problem/func_fit/func_fit.py | 18 ++--- tensorneat/problem/rl/__init__.py | 3 + .../{rl_env/brax_env.py => rl/brax.py} | 0 .../{rl_env/gymnax_env.py => rl/gymnax.py} | 4 +- tensorneat/problem/{rl_env => rl}/rl_jit.py | 62 ++++----------- tensorneat/problem/rl_env/__init__.py | 3 - tensorneat/problem/rl_env/jumanji/__init__.py | 0 .../problem/rl_env/jumanji/jumanji_2048.py | 62 --------------- 14 files changed, 40 insertions(+), 205 deletions(-) create mode 100644 tensorneat/problem/rl/__init__.py rename tensorneat/problem/{rl_env/brax_env.py => rl/brax.py} (100%) rename tensorneat/problem/{rl_env/gymnax_env.py => rl/gymnax.py} (87%) rename tensorneat/problem/{rl_env => rl}/rl_jit.py (80%) delete mode 100644 tensorneat/problem/rl_env/__init__.py delete mode 100644 tensorneat/problem/rl_env/jumanji/__init__.py delete mode 100644 tensorneat/problem/rl_env/jumanji/jumanji_2048.py diff --git a/examples/brax/walker2d.py b/examples/brax/walker2d.py index dc8680c..3f03168 100644 --- a/examples/brax/walker2d.py +++ b/examples/brax/walker2d.py @@ -2,7 +2,7 @@ from tensorneat.pipeline import Pipeline from tensorneat.algorithm.neat import NEAT from tensorneat.genome import DefaultGenome, BiasNode -from tensorneat.problem.rl_env import BraxEnv +from tensorneat.problem.rl import BraxEnv from tensorneat.common import Act, Agg import jax, jax.numpy as jnp diff --git a/examples/func_fit/custom_func_fit.py b/examples/func_fit/custom_func_fit.py index 1015fc3..d43101c 100644 --- a/examples/func_fit/custom_func_fit.py +++ b/examples/func_fit/custom_func_fit.py @@ -22,7 +22,7 @@ if __name__ == "__main__": low_bounds=[-1, -1], upper_bounds=[1, 1], method="sample", - num_samples=1000, + num_samples=100, ) pipeline = Pipeline( @@ -42,7 +42,7 @@ if __name__ == "__main__": ), ), problem=custom_problem, - generation_limit=100, + generation_limit=50, fitness_target=-1e-4, seed=42, ) diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index 7f13b62..0de5a71 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -7,8 +7,6 @@ import numpy as np from tensorneat.algorithm import BaseAlgorithm from tensorneat.problem import BaseProblem -from tensorneat.problem.rl_env import RLEnv -from tensorneat.problem.func_fit import FuncFit from tensorneat.common import State, StatefulBaseClass @@ -20,10 +18,8 @@ class Pipeline(StatefulBaseClass): seed: int = 42, fitness_target: float = 1, generation_limit: int = 1000, - pre_update: bool = False, - update_batch_size: int = 10000, - save_dir=None, is_save: bool = False, + save_dir=None, ): assert problem.jitable, "Currently, problem must be jitable" @@ -36,7 +32,6 @@ class Pipeline(StatefulBaseClass): np.random.seed(self.seed) - # TODO: make each algorithm's input_num and output_num assert ( algorithm.num_inputs == self.problem.input_shape[-1] ), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}" @@ -44,22 +39,6 @@ class Pipeline(StatefulBaseClass): self.best_genome = None self.best_fitness = float("-inf") self.generation_timestamp = None - self.pre_update = pre_update - self.update_batch_size = update_batch_size - if pre_update: - if isinstance(problem, RLEnv): - assert problem.record_episode, "record_episode must be True" - self.fetch_data = lambda episode: episode["obs"] - elif isinstance(problem, FuncFit): - assert problem.return_data, "return_data must be True" - self.fetch_data = lambda data: data - else: - raise NotImplementedError - else: - if isinstance(problem, RLEnv): - assert not problem.record_episode, "record_episode must be False" - elif isinstance(problem, FuncFit): - assert not problem.return_data, "return_data must be False" self.is_save = is_save if is_save: @@ -79,14 +58,6 @@ class Pipeline(StatefulBaseClass): print("initializing") state = state.register(randkey=jax.random.PRNGKey(self.seed)) - if self.pre_update: - # initial with mean = 0 and std = 1 - state = state.register( - data=jax.random.normal( - state.randkey, (self.update_batch_size, self.algorithm.num_inputs) - ) - ) - state = self.algorithm.setup(state) state = self.problem.setup(state) @@ -112,49 +83,9 @@ class Pipeline(StatefulBaseClass): state, pop ) - if self.pre_update: - # update the population - _, pop_transformed = jax.vmap( - self.algorithm.update_by_batch, in_axes=(None, None, 0) - )(state, state.data, pop_transformed) - - # raw_data: (Pop, Batch, num_inputs) - fitnesses, raw_data = jax.vmap( - self.problem.evaluate, in_axes=(None, 0, None, 0) - )(state, keys, self.algorithm.forward, pop_transformed) - - # update population - pop_nodes, pop_conns = jax.vmap(self.algorithm.restore, in_axes=(None, 0))( - state, pop_transformed - ) - state = state.update(pop_nodes=pop_nodes, pop_conns=pop_conns) - - # update data for next generation - data = self.fetch_data(raw_data) - assert ( - data.ndim == 3 - and data.shape[0] == self.pop_size - and data.shape[2] == self.algorithm.num_inputs - ) - # reshape to (Pop * Batch, num_inputs) - data = data.reshape( - data.shape[0] * data.shape[1], self.algorithm.num_inputs - ) - # shuffle - data = jax.random.permutation(randkey_, data, axis=0) - # cutoff or expand - if data.shape[0] >= self.update_batch_size: - data = data[: self.update_batch_size] # cutoff - else: - data = ( - jnp.full(state.data.shape, jnp.nan).at[: data.shape[0]].set(data) - ) # expand - state = state.update(data=data) - - else: - fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))( - state, keys, self.algorithm.forward, pop_transformed - ) + fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))( + state, keys, self.algorithm.forward, pop_transformed + ) # replace nan with -inf fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses) diff --git a/tensorneat/problem/__init__.py b/tensorneat/problem/__init__.py index 3c2cb07..da97c06 100644 --- a/tensorneat/problem/__init__.py +++ b/tensorneat/problem/__init__.py @@ -1 +1,3 @@ from .base import BaseProblem +from .rl import * +from .func_fit import * diff --git a/tensorneat/problem/func_fit/__init__.py b/tensorneat/problem/func_fit/__init__.py index bf633e0..b2693c3 100644 --- a/tensorneat/problem/func_fit/__init__.py +++ b/tensorneat/problem/func_fit/__init__.py @@ -1,4 +1,4 @@ -from .func_fit import FuncFit from .xor import XOR from .xor3d import XOR3d -from .custom import CustomFuncFit \ No newline at end of file +from .custom import CustomFuncFit +from .func_fit import FuncFit \ No newline at end of file diff --git a/tensorneat/problem/func_fit/custom.py b/tensorneat/problem/func_fit/custom.py index 3a21365..d272ff4 100644 --- a/tensorneat/problem/func_fit/custom.py +++ b/tensorneat/problem/func_fit/custom.py @@ -1,6 +1,4 @@ -from typing import Callable, Union, List, Tuple, Sequence - -import jax +from typing import Callable, Union, List, Tuple from jax import vmap, Array, numpy as jnp import numpy as np diff --git a/tensorneat/problem/func_fit/func_fit.py b/tensorneat/problem/func_fit/func_fit.py index 67ff86a..70730e6 100644 --- a/tensorneat/problem/func_fit/func_fit.py +++ b/tensorneat/problem/func_fit/func_fit.py @@ -1,19 +1,18 @@ import jax import jax.numpy as jnp +from ..base import BaseProblem from tensorneat.common import State -from .. import BaseProblem class FuncFit(BaseProblem): jitable = True - def __init__(self, error_method: str = "mse", return_data: bool = False): + def __init__(self, error_method: str = "mse"): super().__init__() assert error_method in {"mse", "rmse", "mae", "mape"} self.error_method = error_method - self.return_data = return_data def setup(self, state: State = State()): return state @@ -39,21 +38,16 @@ class FuncFit(BaseProblem): else: raise NotImplementedError - if self.return_data: - return -loss, self.inputs - else: - return -loss + return -loss def show(self, state, randkey, act_func, params, *args, **kwargs): predict = jax.vmap(act_func, in_axes=(None, None, 0))( state, params, self.inputs ) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) - if self.return_data: - loss, _ = self.evaluate(state, randkey, act_func, params) - else: - loss = self.evaluate(state, randkey, act_func, params) - loss = -loss + fitness = self.evaluate(state, randkey, act_func, params) + + loss = -fitness msg = "" for i in range(inputs.shape[0]): diff --git a/tensorneat/problem/rl/__init__.py b/tensorneat/problem/rl/__init__.py new file mode 100644 index 0000000..60f754e --- /dev/null +++ b/tensorneat/problem/rl/__init__.py @@ -0,0 +1,3 @@ +from .gymnax import GymNaxEnv +from .brax import BraxEnv +from .rl_jit import RLEnv \ No newline at end of file diff --git a/tensorneat/problem/rl_env/brax_env.py b/tensorneat/problem/rl/brax.py similarity index 100% rename from tensorneat/problem/rl_env/brax_env.py rename to tensorneat/problem/rl/brax.py diff --git a/tensorneat/problem/rl_env/gymnax_env.py b/tensorneat/problem/rl/gymnax.py similarity index 87% rename from tensorneat/problem/rl_env/gymnax_env.py rename to tensorneat/problem/rl/gymnax.py index da15122..4f17e25 100644 --- a/tensorneat/problem/rl_env/gymnax_env.py +++ b/tensorneat/problem/rl/gymnax.py @@ -6,7 +6,7 @@ from .rl_jit import RLEnv class GymNaxEnv(RLEnv): def __init__(self, env_name, *args, **kwargs): super().__init__(*args, **kwargs) - assert env_name in gymnax.registered_envs, f"Env {env_name} not registered" + assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax." self.env, self.env_params = gymnax.make(env_name) def env_step(self, randkey, env_state, action): @@ -24,4 +24,4 @@ class GymNaxEnv(RLEnv): return self.env.action_space(self.env_params).shape def show(self, state, randkey, act_func, params, *args, **kwargs): - raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).") + raise NotImplementedError diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl/rl_jit.py similarity index 80% rename from tensorneat/problem/rl_env/rl_jit.py rename to tensorneat/problem/rl/rl_jit.py index 00439c7..b40b30f 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl/rl_jit.py @@ -1,11 +1,11 @@ from typing import Callable import jax -import jax.numpy as jnp +from jax import vmap, numpy as jnp import numpy as np +from ..base import BaseProblem from tensorneat.common import State -from .. import BaseProblem class RLEnv(BaseProblem): @@ -15,7 +15,6 @@ class RLEnv(BaseProblem): self, max_step=1000, repeat_times=1, - record_episode=False, action_policy: Callable = None, obs_normalization: bool = False, sample_policy: Callable = None, @@ -34,7 +33,6 @@ class RLEnv(BaseProblem): super().__init__() self.max_step = max_step - self.record_episode = record_episode self.repeat_times = repeat_times self.action_policy = action_policy @@ -57,11 +55,11 @@ class RLEnv(BaseProblem): ) # ignore act_func def sample(rk): - return self.evaluate_once( + return self._evaluate_once( state, rk, dummy_act_func, None, dummy_sample_func, True ) - rewards, episodes = jax.jit(jax.vmap(sample))(keys) + rewards, episodes = jax.jit(vmap(sample))(keys) obs = jax.device_get(episodes["obs"]) # shape: (sample_episodes, max_step, *input_shape) obs = obs.reshape( @@ -88,47 +86,21 @@ class RLEnv(BaseProblem): def evaluate(self, state: State, randkey, act_func: Callable, params): keys = jax.random.split(randkey, self.repeat_times) - if self.record_episode: - rewards, episodes = jax.vmap( - self.evaluate_once, in_axes=(None, 0, None, None, None, None, None) - )( - state, - keys, - act_func, - params, - self.action_policy, - True, - self.obs_normalization, - ) + rewards = vmap( + self._evaluate_once, in_axes=(None, 0, None, None, None, None, None) + )( + state, + keys, + act_func, + params, + self.action_policy, + False, + self.obs_normalization, + ) - episodes["obs"] = episodes["obs"].reshape( - self.max_step * self.repeat_times, *self.input_shape - ) - episodes["action"] = episodes["action"].reshape( - self.max_step * self.repeat_times, *self.output_shape - ) - episodes["reward"] = episodes["reward"].reshape( - self.max_step * self.repeat_times, - ) + return rewards.mean() - return rewards.mean(), episodes - - else: - rewards = jax.vmap( - self.evaluate_once, in_axes=(None, 0, None, None, None, None, None) - )( - state, - keys, - act_func, - params, - self.action_policy, - False, - self.obs_normalization, - ) - - return rewards.mean() - - def evaluate_once( + def _evaluate_once( self, state, randkey, diff --git a/tensorneat/problem/rl_env/__init__.py b/tensorneat/problem/rl_env/__init__.py deleted file mode 100644 index d473897..0000000 --- a/tensorneat/problem/rl_env/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .gymnax_env import GymNaxEnv -from .brax_env import BraxEnv -from .rl_jit import RLEnv diff --git a/tensorneat/problem/rl_env/jumanji/__init__.py b/tensorneat/problem/rl_env/jumanji/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tensorneat/problem/rl_env/jumanji/jumanji_2048.py b/tensorneat/problem/rl_env/jumanji/jumanji_2048.py deleted file mode 100644 index 97d3e85..0000000 --- a/tensorneat/problem/rl_env/jumanji/jumanji_2048.py +++ /dev/null @@ -1,62 +0,0 @@ -import jax, jax.numpy as jnp -import jumanji - -from tensorneat.common import State -from ..rl_jit import RLEnv - - -class Jumanji_2048(RLEnv): - def __init__( - self, guarantee_invalid_action=True, *args, **kwargs - ): - super().__init__(*args, **kwargs) - self.guarantee_invalid_action = guarantee_invalid_action - self.env = jumanji.make("Game2048-v1") - - def env_step(self, randkey, env_state, action): - action_mask = env_state["action_mask"] - - ################################################################### - - # action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)]) - - ################################################################### - if self.guarantee_invalid_action: - score_with_mask = jnp.where(action_mask, action, -jnp.inf) - action = jnp.argmax(score_with_mask) - else: - action = jnp.argmax(action) - - done = ~action_mask[action] - - env_state, timestep = self.env.step(env_state, action) - reward = timestep["reward"] - - board, action_mask = timestep["observation"] - extras = timestep["extras"] - - done = done | (jnp.sum(action_mask) == 0) # all actions of invalid - - return board.reshape(-1), env_state, reward, done, extras - - def env_reset(self, randkey): - env_state, timestep = self.env.reset(randkey) - step_type = timestep["step_type"] - reward = timestep["reward"] - discount = timestep["discount"] - observation = timestep["observation"] - extras = timestep["extras"] - board, action_mask = observation - - return board.reshape(-1), env_state - - @property - def input_shape(self): - return (16,) - - @property - def output_shape(self): - return (4,) - - def show(self, state, randkey, act_func, params, *args, **kwargs): - raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")