add save function in pipeline
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
from .base import BaseNodeGene
|
||||
from .default import DefaultNodeGene
|
||||
from .default_without_response import NodeGeneWithoutResponse
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .base import BaseGenome
|
||||
from .default import DefaultGenome
|
||||
from .recurrent import RecurrentGenome
|
||||
from .advance import AdvanceInitialize
|
||||
@@ -206,14 +206,15 @@ class DefaultGenome(BaseGenome):
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx]
|
||||
symbols = {}
|
||||
for i in network["nodes"]:
|
||||
if i in input_idx:
|
||||
symbols[i] = sp.Symbol(f"i{i}")
|
||||
symbols[i] = sp.Symbol(f"i{i - min(input_idx)}")
|
||||
elif i in output_idx:
|
||||
symbols[i] = sp.Symbol(f"o{i}")
|
||||
symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
|
||||
else: # hidden
|
||||
symbols[i] = sp.Symbol(f"h{i}")
|
||||
symbols[i] = sp.Symbol(f"h{i - min(hidden_idx)}")
|
||||
|
||||
nodes_exprs = {}
|
||||
args_symbols = {}
|
||||
|
||||
@@ -4,27 +4,50 @@ from algorithm.neat import *
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
|
||||
def split_right_left(randkey, forward_func, obs):
|
||||
right_obs_keys = jnp.array([2, 3, 4, 11, 12, 13])
|
||||
left_obs_keys = jnp.array([5, 6, 7, 14, 15, 16])
|
||||
right_action_keys = jnp.array([0, 1, 2])
|
||||
left_action_keys = jnp.array([3, 4, 5])
|
||||
|
||||
right_foot_obs = obs
|
||||
left_foot_obs = obs
|
||||
left_foot_obs = left_foot_obs.at[right_obs_keys].set(obs[left_obs_keys])
|
||||
left_foot_obs = left_foot_obs.at[left_obs_keys].set(obs[right_obs_keys])
|
||||
|
||||
right_action, left_action = jax.vmap(forward_func)(jnp.stack([right_foot_obs, left_foot_obs]))
|
||||
# print(right_action.shape)
|
||||
# print(left_action.shape)
|
||||
|
||||
return jnp.concatenate([right_action, left_action])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
genome=DefaultGenome(
|
||||
num_inputs=17,
|
||||
num_outputs=6,
|
||||
num_outputs=3,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
),
|
||||
output_transform=Act.tanh
|
||||
output_transform=Act.tanh,
|
||||
),
|
||||
pop_size=10000,
|
||||
pop_size=1000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
env_name="walker2d",
|
||||
max_step=1000,
|
||||
action_policy=split_right_left
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=5000,
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
import time
|
||||
import datetime, time
|
||||
import numpy as np
|
||||
|
||||
from algorithm import BaseAlgorithm
|
||||
@@ -19,7 +22,8 @@ class Pipeline(StatefulBaseClass):
|
||||
generation_limit: int = 1000,
|
||||
pre_update: bool = False,
|
||||
update_batch_size: int = 10000,
|
||||
save_path=None,
|
||||
save_dir=None,
|
||||
is_save: bool = False,
|
||||
):
|
||||
assert problem.jitable, "Currently, problem must be jitable"
|
||||
|
||||
@@ -56,7 +60,17 @@ class Pipeline(StatefulBaseClass):
|
||||
assert not problem.record_episode, "record_episode must be False"
|
||||
elif isinstance(problem, FuncFit):
|
||||
assert not problem.return_data, "return_data must be False"
|
||||
self.save_path = save_path
|
||||
self.is_save = is_save
|
||||
|
||||
if is_save:
|
||||
if save_dir is None:
|
||||
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.save_dir = f"./{self.__class__.__name__} {now}"
|
||||
else:
|
||||
self.save_dir = save_dir
|
||||
print(f"save to {self.save_dir}")
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
|
||||
def setup(self, state=State()):
|
||||
print("initializing")
|
||||
@@ -72,6 +86,15 @@ class Pipeline(StatefulBaseClass):
|
||||
|
||||
state = self.algorithm.setup(state)
|
||||
state = self.problem.setup(state)
|
||||
|
||||
if self.is_save:
|
||||
# self.save(state=state, path=os.path.join(self.save_dir, "pipeline.pkl"))
|
||||
with open(os.path.join(self.save_dir, "config.txt"), "w") as f:
|
||||
f.write(json.dumps(self.show_config(), indent=4))
|
||||
# create log file
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "w") as f:
|
||||
f.write("Generation,Max,Min,Mean,Std,Cost Time\n")
|
||||
|
||||
print("initializing finished")
|
||||
return state
|
||||
|
||||
@@ -183,16 +206,17 @@ class Pipeline(StatefulBaseClass):
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||
|
||||
if self.is_save:
|
||||
best_genome = jax.device_get(self.best_genome)
|
||||
with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
# save best if save path is not None
|
||||
if self.save_path is not None:
|
||||
best_genome = jax.device_get(self.best_genome)
|
||||
with open(self.save_path, "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
member_count = jax.device_get(self.algorithm.member_count(state))
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
@@ -222,6 +246,13 @@ class Pipeline(StatefulBaseClass):
|
||||
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||
)
|
||||
|
||||
# append log
|
||||
if self.is_save:
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
|
||||
def show(self, state, best, *args, **kwargs):
|
||||
transformed = self.algorithm.transform(state, best)
|
||||
self.problem.show(
|
||||
|
||||
@@ -43,7 +43,7 @@ class RLEnv(BaseProblem):
|
||||
assert sample_episodes > 0, "sample_size must be greater than 0"
|
||||
self.sample_policy = sample_policy
|
||||
self.sample_episodes = sample_episodes
|
||||
self.obs_normalization = obs_normalization
|
||||
self.obs_normalization = obs_normalization
|
||||
|
||||
def setup(self, state=State()):
|
||||
if self.obs_normalization:
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "<algorithm.neat.genome.default.DefaultGenome at 0x7f6709872650>"
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax, jax.numpy as jnp\n",
|
||||
"from algorithm.neat import *\n",
|
||||
"from utils import Act, Agg\n",
|
||||
"genome = DefaultGenome(\n",
|
||||
" num_inputs=27,\n",
|
||||
" num_outputs=8,\n",
|
||||
" max_nodes=100,\n",
|
||||
" max_conns=200,\n",
|
||||
" node_gene=DefaultNodeGene(\n",
|
||||
" activation_options=(Act.tanh,),\n",
|
||||
" activation_default=Act.tanh,\n",
|
||||
" ),\n",
|
||||
" output_transform=Act.tanh,\n",
|
||||
")\n",
|
||||
"state = genome.setup()\n",
|
||||
"genome"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-09T12:08:22.569123400Z",
|
||||
"start_time": "2024-06-09T12:08:19.331863800Z"
|
||||
}
|
||||
},
|
||||
"id": "b2b214a5454c4814"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"state = state.register(data=jnp.zeros((1, 27)))\n",
|
||||
"# try to save the genome object\n",
|
||||
"import pickle\n",
|
||||
"\n",
|
||||
"with open('genome.pkl', 'wb') as f:\n",
|
||||
" genome.__dict__[\"state\"] = state\n",
|
||||
" pickle.dump(genome, f)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-09T12:09:01.943445900Z",
|
||||
"start_time": "2024-06-09T12:09:01.919416Z"
|
||||
}
|
||||
},
|
||||
"id": "28348dfc458e8473"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# try to load the genome object\n",
|
||||
"with open('genome.pkl', 'rb') as f:\n",
|
||||
" genome = pickle.load(f)\n",
|
||||
" state = genome.state\n",
|
||||
" del genome.__dict__[\"state\"]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-09T12:10:28.621539400Z",
|
||||
"start_time": "2024-06-09T12:10:28.612540100Z"
|
||||
}
|
||||
},
|
||||
"id": "c91be9fe3d2b5d5d"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "State ({'data': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)})"
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"state"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-09T12:10:34.103124Z",
|
||||
"start_time": "2024-06-09T12:10:34.096124300Z"
|
||||
}
|
||||
},
|
||||
"id": "6852e4e58b81dd9"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"id": "97a50322218a0427"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -14,7 +14,8 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def tanh(z):
|
||||
return jnp.tanh(0.6 * z)
|
||||
z = jnp.clip(0.6*z, -3, 3)
|
||||
return jnp.tanh(z)
|
||||
|
||||
@staticmethod
|
||||
def sin(z):
|
||||
|
||||
@@ -45,11 +45,15 @@ class SympySigmoid(sp.Function):
|
||||
class SympyTanh(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.tanh(0.6 * z)
|
||||
if z.is_Number:
|
||||
z = SympyClip(0.6 * z, -3, 3)
|
||||
return sp.tanh(z)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def numerical_eval(z, backend=np):
|
||||
return backend.tanh(0.6 * z)
|
||||
z = backend.clip(0.6*z, -3, 3)
|
||||
return backend.tanh(z)
|
||||
|
||||
|
||||
class SympySin(sp.Function):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from . import State
|
||||
import pickle
|
||||
@@ -18,6 +19,15 @@ class StatefulBaseClass:
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
def show_config(self):
|
||||
config = {}
|
||||
for key, value in self.__dict__.items():
|
||||
if isinstance(value, StatefulBaseClass):
|
||||
config[str(key)] = value.show_config()
|
||||
else:
|
||||
config[str(key)] = str(value)
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str, with_state: bool = False, warning: bool = True):
|
||||
with open(path, "rb") as f:
|
||||
|
||||
Reference in New Issue
Block a user