add save function in pipeline

This commit is contained in:
wls2002
2024-06-16 21:47:53 +08:00
parent b9d6482d11
commit fb2ae5d2fa
10 changed files with 94 additions and 164 deletions

View File

@@ -1,2 +1,3 @@
from .base import BaseNodeGene from .base import BaseNodeGene
from .default import DefaultNodeGene from .default import DefaultNodeGene
from .default_without_response import NodeGeneWithoutResponse

View File

@@ -1,3 +1,4 @@
from .base import BaseGenome from .base import BaseGenome
from .default import DefaultGenome from .default import DefaultGenome
from .recurrent import RecurrentGenome from .recurrent import RecurrentGenome
from .advance import AdvanceInitialize

View File

@@ -206,14 +206,15 @@ class DefaultGenome(BaseGenome):
input_idx = self.get_input_idx() input_idx = self.get_input_idx()
output_idx = self.get_output_idx() output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"])) order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx]
symbols = {} symbols = {}
for i in network["nodes"]: for i in network["nodes"]:
if i in input_idx: if i in input_idx:
symbols[i] = sp.Symbol(f"i{i}") symbols[i] = sp.Symbol(f"i{i - min(input_idx)}")
elif i in output_idx: elif i in output_idx:
symbols[i] = sp.Symbol(f"o{i}") symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
else: # hidden else: # hidden
symbols[i] = sp.Symbol(f"h{i}") symbols[i] = sp.Symbol(f"h{i - min(hidden_idx)}")
nodes_exprs = {} nodes_exprs = {}
args_symbols = {} args_symbols = {}

View File

@@ -4,27 +4,50 @@ from algorithm.neat import *
from problem.rl_env import BraxEnv from problem.rl_env import BraxEnv
from utils import Act from utils import Act
import jax, jax.numpy as jnp
def split_right_left(randkey, forward_func, obs):
right_obs_keys = jnp.array([2, 3, 4, 11, 12, 13])
left_obs_keys = jnp.array([5, 6, 7, 14, 15, 16])
right_action_keys = jnp.array([0, 1, 2])
left_action_keys = jnp.array([3, 4, 5])
right_foot_obs = obs
left_foot_obs = obs
left_foot_obs = left_foot_obs.at[right_obs_keys].set(obs[left_obs_keys])
left_foot_obs = left_foot_obs.at[left_obs_keys].set(obs[right_obs_keys])
right_action, left_action = jax.vmap(forward_func)(jnp.stack([right_foot_obs, left_foot_obs]))
# print(right_action.shape)
# print(left_action.shape)
return jnp.concatenate([right_action, left_action])
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
genome=DefaultGenome( genome=DefaultGenome(
num_inputs=17, num_inputs=17,
num_outputs=6, num_outputs=3,
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
activation_default=Act.tanh, activation_default=Act.tanh,
), ),
output_transform=Act.tanh output_transform=Act.tanh,
), ),
pop_size=10000, pop_size=1000,
species_size=10, species_size=10,
), ),
), ),
problem=BraxEnv( problem=BraxEnv(
env_name="walker2d", env_name="walker2d",
max_step=1000,
action_policy=split_right_left
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=5000, fitness_target=5000,

View File

@@ -1,5 +1,8 @@
import json
import os
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
import time import datetime, time
import numpy as np import numpy as np
from algorithm import BaseAlgorithm from algorithm import BaseAlgorithm
@@ -19,7 +22,8 @@ class Pipeline(StatefulBaseClass):
generation_limit: int = 1000, generation_limit: int = 1000,
pre_update: bool = False, pre_update: bool = False,
update_batch_size: int = 10000, update_batch_size: int = 10000,
save_path=None, save_dir=None,
is_save: bool = False,
): ):
assert problem.jitable, "Currently, problem must be jitable" assert problem.jitable, "Currently, problem must be jitable"
@@ -56,7 +60,17 @@ class Pipeline(StatefulBaseClass):
assert not problem.record_episode, "record_episode must be False" assert not problem.record_episode, "record_episode must be False"
elif isinstance(problem, FuncFit): elif isinstance(problem, FuncFit):
assert not problem.return_data, "return_data must be False" assert not problem.return_data, "return_data must be False"
self.save_path = save_path self.is_save = is_save
if is_save:
if save_dir is None:
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.save_dir = f"./{self.__class__.__name__} {now}"
else:
self.save_dir = save_dir
print(f"save to {self.save_dir}")
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
def setup(self, state=State()): def setup(self, state=State()):
print("initializing") print("initializing")
@@ -72,6 +86,15 @@ class Pipeline(StatefulBaseClass):
state = self.algorithm.setup(state) state = self.algorithm.setup(state)
state = self.problem.setup(state) state = self.problem.setup(state)
if self.is_save:
# self.save(state=state, path=os.path.join(self.save_dir, "pipeline.pkl"))
with open(os.path.join(self.save_dir, "config.txt"), "w") as f:
f.write(json.dumps(self.show_config(), indent=4))
# create log file
with open(os.path.join(self.save_dir, "log.txt"), "w") as f:
f.write("Generation,Max,Min,Mean,Std,Cost Time\n")
print("initializing finished") print("initializing finished")
return state return state
@@ -183,16 +206,17 @@ class Pipeline(StatefulBaseClass):
self.best_fitness = fitnesses[max_idx] self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx] self.best_genome = pop[0][max_idx], pop[1][max_idx]
if self.is_save:
best_genome = jax.device_get(self.best_genome)
with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
# save best if save path is not None # save best if save path is not None
if self.save_path is not None:
best_genome = jax.device_get(self.best_genome)
with open(self.save_path, "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
member_count = jax.device_get(self.algorithm.member_count(state)) member_count = jax.device_get(self.algorithm.member_count(state))
species_sizes = [int(i) for i in member_count if i > 0] species_sizes = [int(i) for i in member_count if i > 0]
@@ -222,6 +246,13 @@ class Pipeline(StatefulBaseClass):
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n", f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
) )
# append log
if self.is_save:
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
f.write(
f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)
def show(self, state, best, *args, **kwargs): def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(state, best) transformed = self.algorithm.transform(state, best)
self.problem.show( self.problem.show(

View File

@@ -43,7 +43,7 @@ class RLEnv(BaseProblem):
assert sample_episodes > 0, "sample_size must be greater than 0" assert sample_episodes > 0, "sample_size must be greater than 0"
self.sample_policy = sample_policy self.sample_policy = sample_policy
self.sample_episodes = sample_episodes self.sample_episodes = sample_episodes
self.obs_normalization = obs_normalization self.obs_normalization = obs_normalization
def setup(self, state=State()): def setup(self, state=State()):
if self.obs_normalization: if self.obs_normalization:

View File

@@ -1,142 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"outputs": [
{
"data": {
"text/plain": "<algorithm.neat.genome.default.DefaultGenome at 0x7f6709872650>"
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax, jax.numpy as jnp\n",
"from algorithm.neat import *\n",
"from utils import Act, Agg\n",
"genome = DefaultGenome(\n",
" num_inputs=27,\n",
" num_outputs=8,\n",
" max_nodes=100,\n",
" max_conns=200,\n",
" node_gene=DefaultNodeGene(\n",
" activation_options=(Act.tanh,),\n",
" activation_default=Act.tanh,\n",
" ),\n",
" output_transform=Act.tanh,\n",
")\n",
"state = genome.setup()\n",
"genome"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:08:22.569123400Z",
"start_time": "2024-06-09T12:08:19.331863800Z"
}
},
"id": "b2b214a5454c4814"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"state = state.register(data=jnp.zeros((1, 27)))\n",
"# try to save the genome object\n",
"import pickle\n",
"\n",
"with open('genome.pkl', 'wb') as f:\n",
" genome.__dict__[\"state\"] = state\n",
" pickle.dump(genome, f)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:09:01.943445900Z",
"start_time": "2024-06-09T12:09:01.919416Z"
}
},
"id": "28348dfc458e8473"
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"# try to load the genome object\n",
"with open('genome.pkl', 'rb') as f:\n",
" genome = pickle.load(f)\n",
" state = genome.state\n",
" del genome.__dict__[\"state\"]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:10:28.621539400Z",
"start_time": "2024-06-09T12:10:28.612540100Z"
}
},
"id": "c91be9fe3d2b5d5d"
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [
{
"data": {
"text/plain": "State ({'data': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)})"
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:10:34.103124Z",
"start_time": "2024-06-09T12:10:34.096124300Z"
}
},
"id": "6852e4e58b81dd9"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "97a50322218a0427"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -14,7 +14,8 @@ class Act:
@staticmethod @staticmethod
def tanh(z): def tanh(z):
return jnp.tanh(0.6 * z) z = jnp.clip(0.6*z, -3, 3)
return jnp.tanh(z)
@staticmethod @staticmethod
def sin(z): def sin(z):

View File

@@ -45,11 +45,15 @@ class SympySigmoid(sp.Function):
class SympyTanh(sp.Function): class SympyTanh(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
return sp.tanh(0.6 * z) if z.is_Number:
z = SympyClip(0.6 * z, -3, 3)
return sp.tanh(z)
return None
@staticmethod @staticmethod
def numerical_eval(z, backend=np): def numerical_eval(z, backend=np):
return backend.tanh(0.6 * z) z = backend.clip(0.6*z, -3, 3)
return backend.tanh(z)
class SympySin(sp.Function): class SympySin(sp.Function):

View File

@@ -1,3 +1,4 @@
import json
from typing import Optional from typing import Optional
from . import State from . import State
import pickle import pickle
@@ -18,6 +19,15 @@ class StatefulBaseClass:
with open(path, "wb") as f: with open(path, "wb") as f:
pickle.dump(self, f) pickle.dump(self, f)
def show_config(self):
config = {}
for key, value in self.__dict__.items():
if isinstance(value, StatefulBaseClass):
config[str(key)] = value.show_config()
else:
config[str(key)] = str(value)
return config
@classmethod @classmethod
def load(cls, path: str, with_state: bool = False, warning: bool = True): def load(cls, path: str, with_state: bool = False, warning: bool = True):
with open(path, "rb") as f: with open(path, "rb") as f: