add save function in pipeline

This commit is contained in:
wls2002
2024-06-16 21:47:53 +08:00
parent b9d6482d11
commit fb2ae5d2fa
10 changed files with 94 additions and 164 deletions

View File

@@ -1,2 +1,3 @@
from .base import BaseNodeGene
from .default import DefaultNodeGene
from .default_without_response import NodeGeneWithoutResponse

View File

@@ -1,3 +1,4 @@
from .base import BaseGenome
from .default import DefaultGenome
from .recurrent import RecurrentGenome
from .advance import AdvanceInitialize

View File

@@ -206,14 +206,15 @@ class DefaultGenome(BaseGenome):
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx]
symbols = {}
for i in network["nodes"]:
if i in input_idx:
symbols[i] = sp.Symbol(f"i{i}")
symbols[i] = sp.Symbol(f"i{i - min(input_idx)}")
elif i in output_idx:
symbols[i] = sp.Symbol(f"o{i}")
symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
else: # hidden
symbols[i] = sp.Symbol(f"h{i}")
symbols[i] = sp.Symbol(f"h{i - min(hidden_idx)}")
nodes_exprs = {}
args_symbols = {}

View File

@@ -4,27 +4,50 @@ from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
import jax, jax.numpy as jnp
def split_right_left(randkey, forward_func, obs):
right_obs_keys = jnp.array([2, 3, 4, 11, 12, 13])
left_obs_keys = jnp.array([5, 6, 7, 14, 15, 16])
right_action_keys = jnp.array([0, 1, 2])
left_action_keys = jnp.array([3, 4, 5])
right_foot_obs = obs
left_foot_obs = obs
left_foot_obs = left_foot_obs.at[right_obs_keys].set(obs[left_obs_keys])
left_foot_obs = left_foot_obs.at[left_obs_keys].set(obs[right_obs_keys])
right_action, left_action = jax.vmap(forward_func)(jnp.stack([right_foot_obs, left_foot_obs]))
# print(right_action.shape)
# print(left_action.shape)
return jnp.concatenate([right_action, left_action])
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=17,
num_outputs=6,
num_outputs=3,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=Act.tanh
output_transform=Act.tanh,
),
pop_size=10000,
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name="walker2d",
max_step=1000,
action_policy=split_right_left
),
generation_limit=10000,
fitness_target=5000,

View File

@@ -1,5 +1,8 @@
import json
import os
import jax, jax.numpy as jnp
import time
import datetime, time
import numpy as np
from algorithm import BaseAlgorithm
@@ -19,7 +22,8 @@ class Pipeline(StatefulBaseClass):
generation_limit: int = 1000,
pre_update: bool = False,
update_batch_size: int = 10000,
save_path=None,
save_dir=None,
is_save: bool = False,
):
assert problem.jitable, "Currently, problem must be jitable"
@@ -56,7 +60,17 @@ class Pipeline(StatefulBaseClass):
assert not problem.record_episode, "record_episode must be False"
elif isinstance(problem, FuncFit):
assert not problem.return_data, "return_data must be False"
self.save_path = save_path
self.is_save = is_save
if is_save:
if save_dir is None:
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.save_dir = f"./{self.__class__.__name__} {now}"
else:
self.save_dir = save_dir
print(f"save to {self.save_dir}")
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
def setup(self, state=State()):
print("initializing")
@@ -72,6 +86,15 @@ class Pipeline(StatefulBaseClass):
state = self.algorithm.setup(state)
state = self.problem.setup(state)
if self.is_save:
# self.save(state=state, path=os.path.join(self.save_dir, "pipeline.pkl"))
with open(os.path.join(self.save_dir, "config.txt"), "w") as f:
f.write(json.dumps(self.show_config(), indent=4))
# create log file
with open(os.path.join(self.save_dir, "log.txt"), "w") as f:
f.write("Generation,Max,Min,Mean,Std,Cost Time\n")
print("initializing finished")
return state
@@ -183,16 +206,17 @@ class Pipeline(StatefulBaseClass):
self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx]
if self.is_save:
best_genome = jax.device_get(self.best_genome)
with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
# save best if save path is not None
if self.save_path is not None:
best_genome = jax.device_get(self.best_genome)
with open(self.save_path, "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
member_count = jax.device_get(self.algorithm.member_count(state))
species_sizes = [int(i) for i in member_count if i > 0]
@@ -222,6 +246,13 @@ class Pipeline(StatefulBaseClass):
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
)
# append log
if self.is_save:
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
f.write(
f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)
def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(state, best)
self.problem.show(

View File

@@ -43,7 +43,7 @@ class RLEnv(BaseProblem):
assert sample_episodes > 0, "sample_size must be greater than 0"
self.sample_policy = sample_policy
self.sample_episodes = sample_episodes
self.obs_normalization = obs_normalization
self.obs_normalization = obs_normalization
def setup(self, state=State()):
if self.obs_normalization:

View File

@@ -1,142 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"outputs": [
{
"data": {
"text/plain": "<algorithm.neat.genome.default.DefaultGenome at 0x7f6709872650>"
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax, jax.numpy as jnp\n",
"from algorithm.neat import *\n",
"from utils import Act, Agg\n",
"genome = DefaultGenome(\n",
" num_inputs=27,\n",
" num_outputs=8,\n",
" max_nodes=100,\n",
" max_conns=200,\n",
" node_gene=DefaultNodeGene(\n",
" activation_options=(Act.tanh,),\n",
" activation_default=Act.tanh,\n",
" ),\n",
" output_transform=Act.tanh,\n",
")\n",
"state = genome.setup()\n",
"genome"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:08:22.569123400Z",
"start_time": "2024-06-09T12:08:19.331863800Z"
}
},
"id": "b2b214a5454c4814"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"state = state.register(data=jnp.zeros((1, 27)))\n",
"# try to save the genome object\n",
"import pickle\n",
"\n",
"with open('genome.pkl', 'wb') as f:\n",
" genome.__dict__[\"state\"] = state\n",
" pickle.dump(genome, f)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:09:01.943445900Z",
"start_time": "2024-06-09T12:09:01.919416Z"
}
},
"id": "28348dfc458e8473"
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"# try to load the genome object\n",
"with open('genome.pkl', 'rb') as f:\n",
" genome = pickle.load(f)\n",
" state = genome.state\n",
" del genome.__dict__[\"state\"]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:10:28.621539400Z",
"start_time": "2024-06-09T12:10:28.612540100Z"
}
},
"id": "c91be9fe3d2b5d5d"
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [
{
"data": {
"text/plain": "State ({'data': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)})"
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-09T12:10:34.103124Z",
"start_time": "2024-06-09T12:10:34.096124300Z"
}
},
"id": "6852e4e58b81dd9"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "97a50322218a0427"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -14,7 +14,8 @@ class Act:
@staticmethod
def tanh(z):
return jnp.tanh(0.6 * z)
z = jnp.clip(0.6*z, -3, 3)
return jnp.tanh(z)
@staticmethod
def sin(z):

View File

@@ -45,11 +45,15 @@ class SympySigmoid(sp.Function):
class SympyTanh(sp.Function):
@classmethod
def eval(cls, z):
return sp.tanh(0.6 * z)
if z.is_Number:
z = SympyClip(0.6 * z, -3, 3)
return sp.tanh(z)
return None
@staticmethod
def numerical_eval(z, backend=np):
return backend.tanh(0.6 * z)
z = backend.clip(0.6*z, -3, 3)
return backend.tanh(z)
class SympySin(sp.Function):

View File

@@ -1,3 +1,4 @@
import json
from typing import Optional
from . import State
import pickle
@@ -18,6 +19,15 @@ class StatefulBaseClass:
with open(path, "wb") as f:
pickle.dump(self, f)
def show_config(self):
config = {}
for key, value in self.__dict__.items():
if isinstance(value, StatefulBaseClass):
config[str(key)] = value.show_config()
else:
config[str(key)] = str(value)
return config
@classmethod
def load(cls, path: str, with_state: bool = False, warning: bool = True):
with open(path, "rb") as f: