From cd92f411dcfbcdc9f818e6cc8ea9fba168c50d35 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Thu, 30 May 2024 17:05:56 +0800 Subject: [PATCH] add args record_episode in rl tasks, with related test "test_record_episode.ipynb"; add args return_data in func_fit tasks. --- tensorneat/examples/func_fit/xor.py | 8 +- tensorneat/problem/func_fit/func_fit.py | 8 +- tensorneat/problem/func_fit/xor.py | 2 - tensorneat/problem/func_fit/xor3d.py | 3 - tensorneat/problem/rl_env/brax_env.py | 4 +- tensorneat/problem/rl_env/gymnax_env.py | 4 +- tensorneat/problem/rl_env/rl_jit.py | 47 ++- tensorneat/test/test_record_episode.ipynb | 458 ++++++++++++++++++++++ 8 files changed, 512 insertions(+), 22 deletions(-) create mode 100644 tensorneat/test/test_record_episode.ipynb diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index b628343..3ced8b6 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -20,14 +20,14 @@ if __name__ == "__main__": output_transform=Act.sigmoid, # the activation function for output node mutation=DefaultMutation( node_add=0.05, - conn_add=0.2, + conn_add=0.05, node_delete=0, conn_delete=0, ), ), - pop_size=10000, - species_size=10, - compatibility_threshold=3.5, + pop_size=100, + species_size=20, + compatibility_threshold=2, survival_threshold=0.01, # magic ), ), diff --git a/tensorneat/problem/func_fit/func_fit.py b/tensorneat/problem/func_fit/func_fit.py index ff55342..31b5003 100644 --- a/tensorneat/problem/func_fit/func_fit.py +++ b/tensorneat/problem/func_fit/func_fit.py @@ -8,11 +8,12 @@ from .. 
import BaseProblem class FuncFit(BaseProblem): jitable = True - def __init__(self, error_method: str = "mse"): + def __init__(self, error_method: str = "mse", return_data: bool = False): super().__init__() assert error_method in {"mse", "rmse", "mae", "mape"} self.error_method = error_method + self.return_data = return_data def setup(self, state: State = State()): return state @@ -38,7 +39,10 @@ class FuncFit(BaseProblem): else: raise NotImplementedError - return -loss + if self.return_data: + return -loss, self.inputs + else: + return -loss def show(self, state, randkey, act_func, params, *args, **kwargs): predict = jax.vmap(act_func, in_axes=(None, 0, None))( diff --git a/tensorneat/problem/func_fit/xor.py b/tensorneat/problem/func_fit/xor.py index c9544b9..c798b85 100644 --- a/tensorneat/problem/func_fit/xor.py +++ b/tensorneat/problem/func_fit/xor.py @@ -4,8 +4,6 @@ from .func_fit import FuncFit class XOR(FuncFit): - def __init__(self, error_method: str = "mse"): - super().__init__(error_method) @property def inputs(self): diff --git a/tensorneat/problem/func_fit/xor3d.py b/tensorneat/problem/func_fit/xor3d.py index 7c9877d..94807a0 100644 --- a/tensorneat/problem/func_fit/xor3d.py +++ b/tensorneat/problem/func_fit/xor3d.py @@ -4,9 +4,6 @@ from .func_fit import FuncFit class XOR3d(FuncFit): - def __init__(self, error_method: str = "mse"): - super().__init__(error_method) - @property def inputs(self): return np.array( diff --git a/tensorneat/problem/rl_env/brax_env.py b/tensorneat/problem/rl_env/brax_env.py index dcac0b4..8f79c81 100644 --- a/tensorneat/problem/rl_env/brax_env.py +++ b/tensorneat/problem/rl_env/brax_env.py @@ -5,8 +5,8 @@ from .rl_jit import RLEnv class BraxEnv(RLEnv): - def __init__(self, max_step=1000, env_name: str = "ant", backend: str = "generalized"): - super().__init__(max_step) + def __init__(self, max_step=1000, record_episode=False, env_name: str = "ant", backend: str = "generalized"): + super().__init__(max_step, record_episode) 
self.env = envs.create(env_name=env_name, backend=backend) def env_step(self, randkey, env_state, action): diff --git a/tensorneat/problem/rl_env/gymnax_env.py b/tensorneat/problem/rl_env/gymnax_env.py index e32814c..e0ef323 100644 --- a/tensorneat/problem/rl_env/gymnax_env.py +++ b/tensorneat/problem/rl_env/gymnax_env.py @@ -4,8 +4,8 @@ from .rl_jit import RLEnv class GymNaxEnv(RLEnv): - def __init__(self, env_name, max_step=1000): - super().__init__(max_step) + def __init__(self, env_name, max_step=1000, record_episode=False): + super().__init__(max_step, record_episode) assert env_name in gymnax.registered_envs, f"Env {env_name} not registered" self.env, self.env_params = gymnax.make(env_name) diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index 73ba04e..06e020c 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl_env/rl_jit.py @@ -1,6 +1,7 @@ from functools import partial import jax +import jax.numpy as jnp from .. import BaseProblem @@ -8,32 +9,64 @@ from .. 
import BaseProblem class RLEnv(BaseProblem): jitable = True - def __init__(self, max_step=1000): + def __init__(self, max_step=1000, record_episode=False): super().__init__() self.max_step = max_step + self.record_episode = record_episode def evaluate(self, state, randkey, act_func, params): rng_reset, rng_episode = jax.random.split(randkey) init_obs, init_env_state = self.reset(rng_reset) + if self.record_episode: + obs_array = jnp.full((self.max_step, *self.input_shape), jnp.nan) + action_array = jnp.full((self.max_step, *self.output_shape), jnp.nan) + reward_array = jnp.full((self.max_step,), jnp.nan) + episode = { + "obs": obs_array, + "action": action_array, + "reward": reward_array, + } + else: + episode = None + def cond_func(carry): - _, _, _, done, _, count = carry + _, _, _, done, _, count, _ = carry return ~done & (count < self.max_step) def body_func(carry): - obs, env_state, rng, done, tr, count = carry # tr -> total reward + obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward action = act_func(state, obs, params) next_obs, next_env_state, reward, done, _ = self.step( rng, env_state, action ) next_rng, _ = jax.random.split(rng) - return next_obs, next_env_state, next_rng, done, tr + reward, count + 1 - _, _, _, _, total_reward, _ = jax.lax.while_loop( - cond_func, body_func, (init_obs, init_env_state, rng_episode, False, 0.0, 0) + if self.record_episode: + epis["obs"] = epis["obs"].at[count].set(obs) + epis["action"] = epis["action"].at[count].set(action) + epis["reward"] = epis["reward"].at[count].set(reward) + + return ( + next_obs, + next_env_state, + next_rng, + done, + tr + reward, + count + 1, + epis, + ) + + _, _, _, _, total_reward, _, episode = jax.lax.while_loop( + cond_func, + body_func, + (init_obs, init_env_state, rng_episode, False, 0.0, 0, episode), ) - return total_reward + if self.record_episode: + return total_reward, episode + else: + return total_reward # @partial(jax.jit, static_argnums=(0,)) def step(self, 
randkey, env_state, action): diff --git a/tensorneat/test/test_record_episode.ipynb b/tensorneat/test/test_record_episode.ipynb new file mode 100644 index 0000000..08a300b --- /dev/null +++ b/tensorneat/test/test_record_episode.ipynb @@ -0,0 +1,458 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-05-30T08:53:04.429593300Z", + "start_time": "2024-05-30T08:53:02.326728600Z" + } + }, + "outputs": [], + "source": [ + "import jax, jax.numpy as jnp\n", + "from utils import State\n", + "from problem.rl_env import BraxEnv\n", + "\n", + "\n", + "def random_policy(state: State, obs, randkey):\n", + " return jax.random.uniform(randkey, (8,)) * 2 - 1" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "data": { + "text/plain": "Array(24.975231, dtype=float32)" + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# single evaluation without recording episode\n", + "randkey = jax.random.key(0)\n", + "env_key, policy_key = jax.random.split(randkey)\n", + "problem = BraxEnv(env_name=\"ant\", max_step=100)\n", + "state = problem.setup()\n", + "evaluate_using_random_policy_without_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n", + " policy_key)\n", + "score = jax.jit(evaluate_using_random_policy_without_record)(state, env_key, policy_key)\n", + "score" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:53:18.928839600Z", + "start_time": "2024-05-30T08:53:04.435561800Z" + } + }, + "id": "e62882e782d7e54e" + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [ + { + "data": { + "text/plain": "Array([ -3.274895 , -6.016205 , -6.9032974, 9.187286 ,\n -120.19688 , 12.389805 , -4.6393256, -50.27197 ,\n 9.650737 , -73.77956 ], dtype=float32)" + }, + "execution_count": 3, + "metadata": {}, 
+ "output_type": "execute_result" + } + ], + "source": [ + "# batch evaluation without recording episode\n", + "batch = 10\n", + "env_keys = jax.random.split(env_key, batch)\n", + "policy_keys = jax.random.split(policy_key, batch)\n", + "\n", + "score = jax.jit(\n", + " jax.vmap(\n", + " evaluate_using_random_policy_without_record, \n", + " in_axes=(None, 0, 0)\n", + " ))(\n", + " state, env_keys, policy_keys\n", + " )\n", + "score" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:53:29.458960600Z", + "start_time": "2024-05-30T08:53:18.928839600Z" + } + }, + "id": "d01997be61038ea2" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "data": { + "text/plain": "(Array(18.354952, dtype=float32), (100, 27), (100, 8), (100,))" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# single evaluation with recording episode\n", + "randkey = jax.random.key(0)\n", + "env_key, policy_key = jax.random.split(randkey)\n", + "problem = BraxEnv(env_name=\"ant\", max_step=100, record_episode=True)\n", + "evaluate_using_random_policy_with_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n", + " policy_key)\n", + "score, episode = jax.jit(evaluate_using_random_policy_with_record)(state, env_key, policy_key)\n", + "score, episode[\"obs\"].shape, episode[\"action\"].shape, episode[\"reward\"].shape" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:53:40.372461Z", + "start_time": "2024-05-30T08:53:29.455962200Z" + } + }, + "id": "ac6f72e21dd12ee8" + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [ + { + "data": { + "text/plain": "(Array(18.354952, dtype=float32), (10, 100, 27), (10, 100, 8), (10, 100))" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# batch evaluation with recording episode\n", +
"batch = 10\n", + "env_keys = jax.random.split(env_key, batch)\n", + "policy_keys = jax.random.split(policy_key, batch)\n", + "\n", + "scores, episodes = jax.jit(\n", + " jax.vmap(\n", + " evaluate_using_random_policy_with_record, \n", + " in_axes=(None, 0, 0)\n", + " ))(\n", + " state, env_keys, policy_keys\n", + " )\n", + "score, episodes[\"obs\"].shape, episodes[\"action\"].shape, episodes[\"reward\"].shape" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:53:51.261470500Z", + "start_time": "2024-05-30T08:53:40.368462Z" + } + }, + "id": "1c55341b054ee2e8" + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "Array(18.354952, dtype=float32)" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluate_using_random_policy_with_record = jax.jit(evaluate_using_random_policy_with_record)\n", + "evaluate_using_random_policy_without_record = jax.jit(evaluate_using_random_policy_without_record)\n", + "evaluate_using_random_policy_with_record(state, env_key, policy_key)\n", + "evaluate_using_random_policy_without_record(state, env_key, policy_key)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:53:55.402886Z", + "start_time": "2024-05-30T08:53:51.255470600Z" + } + }, + "id": "274ca4fd0d0b8663" + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [], + "source": [ + "for _ in range(20):\n", + " evaluate_using_random_policy_with_record(state, env_key, policy_key)\n", + "# 47s384ms" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:54:42.782425800Z", + "start_time": "2024-05-30T08:53:55.397887700Z" + } + }, + "id": "fdb34361d19cb78d" + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [], + "source": [ + "for _ in range(20):\n", + " evaluate_using_random_policy_without_record(state, env_key, policy_key)\n", + "# 
48s559ms" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:55:31.344699500Z", + "start_time": "2024-05-30T08:54:42.785428500Z" + } + }, + "id": "9afdf6923051c9f1" + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "data": { + "text/plain": "Array(9., dtype=float32, weak_type=True)" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# single evaluation without recording episode\n", + "from problem.rl_env import GymNaxEnv\n", + "\n", + "def random_policy(state: State, obs, randkey):\n", + " return jax.random.uniform(randkey, ()) \n", + "\n", + "randkey = jax.random.key(0)\n", + "env_key, policy_key = jax.random.split(randkey)\n", + "problem = GymNaxEnv(env_name=\"CartPole-v1\", max_step=500)\n", + "state = problem.setup()\n", + "evaluate_using_random_policy_without_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n", + " policy_key)\n", + "score = jax.jit(evaluate_using_random_policy_without_record)(state, env_key, policy_key)\n", + "score" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:58:46.652406400Z", + "start_time": "2024-05-30T08:58:45.606288800Z" + } + }, + "id": "1de25fb23f519284" + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [ + { + "data": { + "text/plain": "Array([13., 19., 11., 12., 14., 21., 13., 11., 11., 28.], dtype=float32, weak_type=True)" + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# batch evaluation without recording episode\n", + "batch = 10\n", + "env_keys = jax.random.split(env_key, batch)\n", + "policy_keys = jax.random.split(policy_key, batch)\n", + "\n", + "score = jax.jit(\n", + " jax.vmap(\n", + " evaluate_using_random_policy_without_record, \n", + " in_axes=(None, 0, 0)\n", + " ))(\n", + " state, env_keys, policy_keys\n", + " )\n", + "score" + 
], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:58:58.323528300Z", + "start_time": "2024-05-30T08:58:57.272024400Z" + } + }, + "id": "99e745dce6f2872d" + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [ + { + "data": { + "text/plain": "(Array(9., dtype=float32, weak_type=True), (500, 4), (500,), (500,))" + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# single evaluation with recording episode\n", + "randkey = jax.random.key(0)\n", + "env_key, policy_key = jax.random.split(randkey)\n", + "problem = GymNaxEnv(env_name=\"CartPole-v1\", max_step=500, record_episode=True)\n", + "evaluate_using_random_policy_with_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,\n", + " policy_key)\n", + "score, episode = jax.jit(evaluate_using_random_policy_with_record)(state, env_key, policy_key)\n", + "score, episode[\"obs\"].shape, episode[\"action\"].shape, episode[\"reward\"].shape" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:59:18.830495600Z", + "start_time": "2024-05-30T08:59:17.568087200Z" + } + }, + "id": "257e340ebf24c10d" + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [ + { + "data": { + "text/plain": "(Array(9., dtype=float32, weak_type=True), (10, 500, 4), (10, 500), (10, 500))" + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# batch evaluation with recording episode\n", + "batch = 10\n", + "env_keys = jax.random.split(env_key, batch)\n", + "policy_keys = jax.random.split(policy_key, batch)\n", + "\n", + "scores, episodes = jax.jit(\n", + " jax.vmap(\n", + " evaluate_using_random_policy_with_record, \n", + " in_axes=(None, 0, 0)\n", + " ))(\n", + " state, env_keys, policy_keys\n", + " )\n", + "score, episodes[\"obs\"].shape, episodes[\"action\"].shape, episodes[\"reward\"].shape" +
], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:59:34.182539200Z", + "start_time": "2024-05-30T08:59:32.956339600Z" + } + }, + "id": "9ba8dc68085cd0fc" + }, + { + "cell_type": "code", + "execution_count": 16, + "outputs": [ + { + "data": { + "text/plain": "Array(9., dtype=float32, weak_type=True)" + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluate_using_random_policy_with_record = jax.jit(evaluate_using_random_policy_with_record)\n", + "evaluate_using_random_policy_without_record = jax.jit(evaluate_using_random_policy_without_record)\n", + "evaluate_using_random_policy_with_record(state, env_key, policy_key)\n", + "evaluate_using_random_policy_without_record(state, env_key, policy_key)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T08:59:46.472504900Z", + "start_time": "2024-05-30T08:59:46.419192900Z" + } + }, + "id": "ea01b6663a7ca076" + }, + { + "cell_type": "code", + "execution_count": 19, + "outputs": [], + "source": [ + "for _ in range(20):\n", + " evaluate_using_random_policy_with_record(state, env_key, policy_key)\n", + "# 48ms" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T09:00:18.905094200Z", + "start_time": "2024-05-30T09:00:18.809970900Z" + } + }, + "id": "989c39c8e20779d0" + }, + { + "cell_type": "code", + "execution_count": 20, + "outputs": [], + "source": [ + "for _ in range(20):\n", + " evaluate_using_random_policy_without_record(state, env_key, policy_key)\n", + "# 43ms" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T09:00:19.240415900Z", + "start_time": "2024-05-30T09:00:19.190416700Z" + } + }, + "id": "bab4782fe674f2d5" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 
+ }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}