From fb2ae5d2faca1acf3afd0958c12ece5bc0a2e423 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sun, 16 Jun 2024 21:47:53 +0800 Subject: [PATCH] add save function in pipeline --- .../algorithm/neat/gene/node/__init__.py | 1 + tensorneat/algorithm/neat/genome/__init__.py | 1 + tensorneat/algorithm/neat/genome/default.py | 7 +- tensorneat/examples/brax/walker.py | 29 +++- tensorneat/pipeline.py | 55 +++++-- tensorneat/problem/rl_env/rl_jit.py | 2 +- tensorneat/tmp.ipynb | 142 ------------------ tensorneat/utils/activation/act_jnp.py | 3 +- tensorneat/utils/activation/act_sympy.py | 8 +- tensorneat/utils/stateful_class.py | 10 ++ 10 files changed, 94 insertions(+), 164 deletions(-) delete mode 100644 tensorneat/tmp.ipynb diff --git a/tensorneat/algorithm/neat/gene/node/__init__.py b/tensorneat/algorithm/neat/gene/node/__init__.py index b88d714..3c10f2c 100644 --- a/tensorneat/algorithm/neat/gene/node/__init__.py +++ b/tensorneat/algorithm/neat/gene/node/__init__.py @@ -1,2 +1,3 @@ from .base import BaseNodeGene from .default import DefaultNodeGene +from .default_without_response import NodeGeneWithoutResponse diff --git a/tensorneat/algorithm/neat/genome/__init__.py b/tensorneat/algorithm/neat/genome/__init__.py index 0a636dd..74bd7ba 100644 --- a/tensorneat/algorithm/neat/genome/__init__.py +++ b/tensorneat/algorithm/neat/genome/__init__.py @@ -1,3 +1,4 @@ from .base import BaseGenome from .default import DefaultGenome from .recurrent import RecurrentGenome +from .advance import AdvanceInitialize \ No newline at end of file diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index d742b62..137ac95 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -206,14 +206,15 @@ class DefaultGenome(BaseGenome): input_idx = self.get_input_idx() output_idx = self.get_output_idx() order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"])) + hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx] symbols = {} for i in network["nodes"]: if i in input_idx: - symbols[i] = sp.Symbol(f"i{i}") + symbols[i] = sp.Symbol(f"i{i - min(input_idx)}") elif i in output_idx: - symbols[i] = sp.Symbol(f"o{i}") + symbols[i] = sp.Symbol(f"o{i - min(output_idx)}") else: # hidden - symbols[i] = sp.Symbol(f"h{i}") + symbols[i] = sp.Symbol(f"h{i - min(hidden_idx)}") nodes_exprs = {} args_symbols = {} diff --git a/tensorneat/examples/brax/walker.py b/tensorneat/examples/brax/walker.py index 1da6b31..3699593 100644 --- a/tensorneat/examples/brax/walker.py +++ b/tensorneat/examples/brax/walker.py @@ -4,27 +4,50 @@ from algorithm.neat import * from problem.rl_env import BraxEnv from utils import Act +import jax, jax.numpy as jnp + + +def split_right_left(randkey, forward_func, obs): + right_obs_keys = jnp.array([2, 3, 4, 11, 12, 13]) + left_obs_keys = jnp.array([5, 6, 7, 14, 15, 16]) + right_action_keys = jnp.array([0, 1, 2]) + left_action_keys = jnp.array([3, 4, 5]) + + right_foot_obs = obs + left_foot_obs = obs + left_foot_obs = left_foot_obs.at[right_obs_keys].set(obs[left_obs_keys]) + left_foot_obs = left_foot_obs.at[left_obs_keys].set(obs[right_obs_keys]) + + right_action, left_action = jax.vmap(forward_func)(jnp.stack([right_foot_obs, left_foot_obs])) + # print(right_action.shape) + # print(left_action.shape) + + return jnp.concatenate([right_action, left_action]) + + if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( genome=DefaultGenome( num_inputs=17, - num_outputs=6, + num_outputs=3, max_nodes=50, max_conns=100, node_gene=DefaultNodeGene( activation_options=(Act.tanh,), activation_default=Act.tanh, ), - output_transform=Act.tanh + output_transform=Act.tanh, ), - pop_size=10000, + pop_size=1000, species_size=10, ), ), problem=BraxEnv( env_name="walker2d", + max_step=1000, + action_policy=split_right_left ), generation_limit=10000, fitness_target=5000, diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index 798d424..2299df2 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -1,5 +1,8 @@ +import json +import os + import jax, jax.numpy as jnp -import time +import datetime, time import numpy as np from algorithm import BaseAlgorithm @@ -19,7 +22,8 @@ class Pipeline(StatefulBaseClass): generation_limit: int = 1000, pre_update: bool = False, update_batch_size: int = 10000, - save_path=None, + save_dir=None, + is_save: bool = False, ): assert problem.jitable, "Currently, problem must be jitable" @@ -56,7 +60,17 @@ class Pipeline(StatefulBaseClass): assert not problem.record_episode, "record_episode must be False" elif isinstance(problem, FuncFit): assert not problem.return_data, "return_data must be False" - self.save_path = save_path + self.is_save = is_save + + if is_save: + if save_dir is None: + now = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + self.save_dir = f"./{self.__class__.__name__} {now}" + else: + self.save_dir = save_dir + print(f"save to {self.save_dir}") + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) def setup(self, state=State()): print("initializing") @@ -72,6 +86,15 @@ class Pipeline(StatefulBaseClass): state = self.algorithm.setup(state) state = self.problem.setup(state) + + if self.is_save: + # self.save(state=state, path=os.path.join(self.save_dir, "pipeline.pkl")) + with open(os.path.join(self.save_dir, "config.txt"), "w") as f: + f.write(json.dumps(self.show_config(), indent=4)) + # create log file + with open(os.path.join(self.save_dir, "log.txt"), "w") as f: + f.write("Generation,Max,Min,Mean,Std,Cost Time\n") + print("initializing finished") return state @@ -183,16 +206,17 @@ class Pipeline(StatefulBaseClass): self.best_fitness = fitnesses[max_idx] self.best_genome = pop[0][max_idx], pop[1][max_idx] + if self.is_save: + best_genome = jax.device_get(self.best_genome) + with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f: + np.savez( + f, + nodes=best_genome[0], + conns=best_genome[1], + fitness=self.best_fitness, + ) + # save best if save path is not None - if self.save_path is not None: - best_genome = jax.device_get(self.best_genome) - with open(self.save_path, "wb") as f: - np.savez( - f, - nodes=best_genome[0], - conns=best_genome[1], - fitness=self.best_fitness, - ) member_count = jax.device_get(self.algorithm.member_count(state)) species_sizes = [int(i) for i in member_count if i > 0] @@ -222,6 +246,13 @@ class Pipeline(StatefulBaseClass): f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n", ) + # append log + if self.is_save: + with open(os.path.join(self.save_dir, "log.txt"), "a") as f: + f.write( + f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n" + ) + def show(self, state, best, *args, **kwargs): transformed = self.algorithm.transform(state, best) self.problem.show( diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index 1dbe8b1..709c613 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl_env/rl_jit.py @@ -43,7 +43,7 @@ class RLEnv(BaseProblem): assert sample_episodes > 0, "sample_size must be greater than 0" self.sample_policy = sample_policy self.sample_episodes = sample_episodes - self.obs_normalization = obs_normalization + self.obs_normalization = obs_normalization def setup(self, state=State()): if self.obs_normalization: diff --git a/tensorneat/tmp.ipynb b/tensorneat/tmp.ipynb deleted file mode 100644 index 600dc3c..0000000 --- a/tensorneat/tmp.ipynb +++ /dev/null @@ -1,142 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "outputs": [ - { - "data": { - "text/plain": "" - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax, jax.numpy as jnp\n", - "from algorithm.neat import *\n", - "from utils import Act, Agg\n", - "genome = DefaultGenome(\n", - " num_inputs=27,\n", - " num_outputs=8,\n", - " max_nodes=100,\n", - " max_conns=200,\n", - " node_gene=DefaultNodeGene(\n", - " activation_options=(Act.tanh,),\n", - " activation_default=Act.tanh,\n", - " ),\n", - " output_transform=Act.tanh,\n", - ")\n", - "state = genome.setup()\n", - "genome" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-09T12:08:22.569123400Z", - "start_time": "2024-06-09T12:08:19.331863800Z" - } - }, - "id": "b2b214a5454c4814" - }, - { - "cell_type": "code", - "execution_count": 4, - "outputs": [], - "source": [ - "state = state.register(data=jnp.zeros((1, 27)))\n", - "# try to save the genome object\n", - "import pickle\n", - "\n", - "with open('genome.pkl', 'wb') as f:\n", - " genome.__dict__[\"state\"] = state\n", - " pickle.dump(genome, f)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-09T12:09:01.943445900Z", - "start_time": "2024-06-09T12:09:01.919416Z" - } - }, - "id": "28348dfc458e8473" - }, - { - "cell_type": "code", - "execution_count": 13, - "outputs": [], - "source": [ - "# try to load the genome object\n", - "with open('genome.pkl', 'rb') as f:\n", - " genome = pickle.load(f)\n", - " state = genome.state\n", - " del genome.__dict__[\"state\"]" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-09T12:10:28.621539400Z", - "start_time": "2024-06-09T12:10:28.612540100Z" - } - }, - "id": "c91be9fe3d2b5d5d" - }, - { - "cell_type": "code", - "execution_count": 15, - "outputs": [ - { - "data": { - "text/plain": "State ({'data': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)})" - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "state" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-09T12:10:34.103124Z", - "start_time": "2024-06-09T12:10:34.096124300Z" - } - }, - "id": "6852e4e58b81dd9" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - }, - "id": "97a50322218a0427" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tensorneat/utils/activation/act_jnp.py b/tensorneat/utils/activation/act_jnp.py index 7c655e0..2da5d7b 100644 --- a/tensorneat/utils/activation/act_jnp.py +++ b/tensorneat/utils/activation/act_jnp.py @@ -14,7 +14,8 @@ class Act: @staticmethod def tanh(z): - return jnp.tanh(0.6 * z) + z = jnp.clip(0.6*z, -3, 3) + return jnp.tanh(z) @staticmethod def sin(z): diff --git a/tensorneat/utils/activation/act_sympy.py b/tensorneat/utils/activation/act_sympy.py index 7cb043d..9a21b00 100644 --- a/tensorneat/utils/activation/act_sympy.py +++ b/tensorneat/utils/activation/act_sympy.py @@ -45,11 +45,15 @@ class SympySigmoid(sp.Function): class SympyTanh(sp.Function): @classmethod def eval(cls, z): - return sp.tanh(0.6 * z) + if z.is_Number: + z = SympyClip(0.6 * z, -3, 3) + return sp.tanh(z) + return None @staticmethod def numerical_eval(z, backend=np): - return backend.tanh(0.6 * z) + z = backend.clip(0.6*z, -3, 3) + return backend.tanh(z) class SympySin(sp.Function): diff --git a/tensorneat/utils/stateful_class.py b/tensorneat/utils/stateful_class.py index a1c8090..41ce2cc 100644 --- a/tensorneat/utils/stateful_class.py +++ b/tensorneat/utils/stateful_class.py @@ -1,3 +1,4 @@ +import json from typing import Optional from . import State import pickle @@ -18,6 +19,15 @@ class StatefulBaseClass: with open(path, "wb") as f: pickle.dump(self, f) + def show_config(self): + config = {} + for key, value in self.__dict__.items(): + if isinstance(value, StatefulBaseClass): + config[str(key)] = value.show_config() + else: + config[str(key)] = str(value) + return config + @classmethod def load(cls, path: str, with_state: bool = False, warning: bool = True): with open(path, "rb") as f: