add save function in pipeline
This commit is contained in:
@@ -1,2 +1,3 @@
|
|||||||
from .base import BaseNodeGene
|
from .base import BaseNodeGene
|
||||||
from .default import DefaultNodeGene
|
from .default import DefaultNodeGene
|
||||||
|
from .default_without_response import NodeGeneWithoutResponse
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
from .base import BaseGenome
|
from .base import BaseGenome
|
||||||
from .default import DefaultGenome
|
from .default import DefaultGenome
|
||||||
from .recurrent import RecurrentGenome
|
from .recurrent import RecurrentGenome
|
||||||
|
from .advance import AdvanceInitialize
|
||||||
@@ -206,14 +206,15 @@ class DefaultGenome(BaseGenome):
|
|||||||
input_idx = self.get_input_idx()
|
input_idx = self.get_input_idx()
|
||||||
output_idx = self.get_output_idx()
|
output_idx = self.get_output_idx()
|
||||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||||
|
hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx]
|
||||||
symbols = {}
|
symbols = {}
|
||||||
for i in network["nodes"]:
|
for i in network["nodes"]:
|
||||||
if i in input_idx:
|
if i in input_idx:
|
||||||
symbols[i] = sp.Symbol(f"i{i}")
|
symbols[i] = sp.Symbol(f"i{i - min(input_idx)}")
|
||||||
elif i in output_idx:
|
elif i in output_idx:
|
||||||
symbols[i] = sp.Symbol(f"o{i}")
|
symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
|
||||||
else: # hidden
|
else: # hidden
|
||||||
symbols[i] = sp.Symbol(f"h{i}")
|
symbols[i] = sp.Symbol(f"h{i - min(hidden_idx)}")
|
||||||
|
|
||||||
nodes_exprs = {}
|
nodes_exprs = {}
|
||||||
args_symbols = {}
|
args_symbols = {}
|
||||||
|
|||||||
@@ -4,27 +4,50 @@ from algorithm.neat import *
|
|||||||
from problem.rl_env import BraxEnv
|
from problem.rl_env import BraxEnv
|
||||||
from utils import Act
|
from utils import Act
|
||||||
|
|
||||||
|
import jax, jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
|
def split_right_left(randkey, forward_func, obs):
|
||||||
|
right_obs_keys = jnp.array([2, 3, 4, 11, 12, 13])
|
||||||
|
left_obs_keys = jnp.array([5, 6, 7, 14, 15, 16])
|
||||||
|
right_action_keys = jnp.array([0, 1, 2])
|
||||||
|
left_action_keys = jnp.array([3, 4, 5])
|
||||||
|
|
||||||
|
right_foot_obs = obs
|
||||||
|
left_foot_obs = obs
|
||||||
|
left_foot_obs = left_foot_obs.at[right_obs_keys].set(obs[left_obs_keys])
|
||||||
|
left_foot_obs = left_foot_obs.at[left_obs_keys].set(obs[right_obs_keys])
|
||||||
|
|
||||||
|
right_action, left_action = jax.vmap(forward_func)(jnp.stack([right_foot_obs, left_foot_obs]))
|
||||||
|
# print(right_action.shape)
|
||||||
|
# print(left_action.shape)
|
||||||
|
|
||||||
|
return jnp.concatenate([right_action, left_action])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
algorithm=NEAT(
|
algorithm=NEAT(
|
||||||
species=DefaultSpecies(
|
species=DefaultSpecies(
|
||||||
genome=DefaultGenome(
|
genome=DefaultGenome(
|
||||||
num_inputs=17,
|
num_inputs=17,
|
||||||
num_outputs=6,
|
num_outputs=3,
|
||||||
max_nodes=50,
|
max_nodes=50,
|
||||||
max_conns=100,
|
max_conns=100,
|
||||||
node_gene=DefaultNodeGene(
|
node_gene=DefaultNodeGene(
|
||||||
activation_options=(Act.tanh,),
|
activation_options=(Act.tanh,),
|
||||||
activation_default=Act.tanh,
|
activation_default=Act.tanh,
|
||||||
),
|
),
|
||||||
output_transform=Act.tanh
|
output_transform=Act.tanh,
|
||||||
),
|
),
|
||||||
pop_size=10000,
|
pop_size=1000,
|
||||||
species_size=10,
|
species_size=10,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
problem=BraxEnv(
|
problem=BraxEnv(
|
||||||
env_name="walker2d",
|
env_name="walker2d",
|
||||||
|
max_step=1000,
|
||||||
|
action_policy=split_right_left
|
||||||
),
|
),
|
||||||
generation_limit=10000,
|
generation_limit=10000,
|
||||||
fitness_target=5000,
|
fitness_target=5000,
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
import time
|
import datetime, time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithm import BaseAlgorithm
|
from algorithm import BaseAlgorithm
|
||||||
@@ -19,7 +22,8 @@ class Pipeline(StatefulBaseClass):
|
|||||||
generation_limit: int = 1000,
|
generation_limit: int = 1000,
|
||||||
pre_update: bool = False,
|
pre_update: bool = False,
|
||||||
update_batch_size: int = 10000,
|
update_batch_size: int = 10000,
|
||||||
save_path=None,
|
save_dir=None,
|
||||||
|
is_save: bool = False,
|
||||||
):
|
):
|
||||||
assert problem.jitable, "Currently, problem must be jitable"
|
assert problem.jitable, "Currently, problem must be jitable"
|
||||||
|
|
||||||
@@ -56,7 +60,17 @@ class Pipeline(StatefulBaseClass):
|
|||||||
assert not problem.record_episode, "record_episode must be False"
|
assert not problem.record_episode, "record_episode must be False"
|
||||||
elif isinstance(problem, FuncFit):
|
elif isinstance(problem, FuncFit):
|
||||||
assert not problem.return_data, "return_data must be False"
|
assert not problem.return_data, "return_data must be False"
|
||||||
self.save_path = save_path
|
self.is_save = is_save
|
||||||
|
|
||||||
|
if is_save:
|
||||||
|
if save_dir is None:
|
||||||
|
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
self.save_dir = f"./{self.__class__.__name__} {now}"
|
||||||
|
else:
|
||||||
|
self.save_dir = save_dir
|
||||||
|
print(f"save to {self.save_dir}")
|
||||||
|
if not os.path.exists(self.save_dir):
|
||||||
|
os.makedirs(self.save_dir)
|
||||||
|
|
||||||
def setup(self, state=State()):
|
def setup(self, state=State()):
|
||||||
print("initializing")
|
print("initializing")
|
||||||
@@ -72,6 +86,15 @@ class Pipeline(StatefulBaseClass):
|
|||||||
|
|
||||||
state = self.algorithm.setup(state)
|
state = self.algorithm.setup(state)
|
||||||
state = self.problem.setup(state)
|
state = self.problem.setup(state)
|
||||||
|
|
||||||
|
if self.is_save:
|
||||||
|
# self.save(state=state, path=os.path.join(self.save_dir, "pipeline.pkl"))
|
||||||
|
with open(os.path.join(self.save_dir, "config.txt"), "w") as f:
|
||||||
|
f.write(json.dumps(self.show_config(), indent=4))
|
||||||
|
# create log file
|
||||||
|
with open(os.path.join(self.save_dir, "log.txt"), "w") as f:
|
||||||
|
f.write("Generation,Max,Min,Mean,Std,Cost Time\n")
|
||||||
|
|
||||||
print("initializing finished")
|
print("initializing finished")
|
||||||
return state
|
return state
|
||||||
|
|
||||||
@@ -183,16 +206,17 @@ class Pipeline(StatefulBaseClass):
|
|||||||
self.best_fitness = fitnesses[max_idx]
|
self.best_fitness = fitnesses[max_idx]
|
||||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||||
|
|
||||||
|
if self.is_save:
|
||||||
|
best_genome = jax.device_get(self.best_genome)
|
||||||
|
with open(os.path.join(self.save_dir, f"{int(self.algorithm.generation(state))}.npz"), "wb") as f:
|
||||||
|
np.savez(
|
||||||
|
f,
|
||||||
|
nodes=best_genome[0],
|
||||||
|
conns=best_genome[1],
|
||||||
|
fitness=self.best_fitness,
|
||||||
|
)
|
||||||
|
|
||||||
# save best if save path is not None
|
# save best if save path is not None
|
||||||
if self.save_path is not None:
|
|
||||||
best_genome = jax.device_get(self.best_genome)
|
|
||||||
with open(self.save_path, "wb") as f:
|
|
||||||
np.savez(
|
|
||||||
f,
|
|
||||||
nodes=best_genome[0],
|
|
||||||
conns=best_genome[1],
|
|
||||||
fitness=self.best_fitness,
|
|
||||||
)
|
|
||||||
|
|
||||||
member_count = jax.device_get(self.algorithm.member_count(state))
|
member_count = jax.device_get(self.algorithm.member_count(state))
|
||||||
species_sizes = [int(i) for i in member_count if i > 0]
|
species_sizes = [int(i) for i in member_count if i > 0]
|
||||||
@@ -222,6 +246,13 @@ class Pipeline(StatefulBaseClass):
|
|||||||
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# append log
|
||||||
|
if self.is_save:
|
||||||
|
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||||
|
f.write(
|
||||||
|
f"{self.algorithm.generation(state)},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||||
|
)
|
||||||
|
|
||||||
def show(self, state, best, *args, **kwargs):
|
def show(self, state, best, *args, **kwargs):
|
||||||
transformed = self.algorithm.transform(state, best)
|
transformed = self.algorithm.transform(state, best)
|
||||||
self.problem.show(
|
self.problem.show(
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class RLEnv(BaseProblem):
|
|||||||
assert sample_episodes > 0, "sample_size must be greater than 0"
|
assert sample_episodes > 0, "sample_size must be greater than 0"
|
||||||
self.sample_policy = sample_policy
|
self.sample_policy = sample_policy
|
||||||
self.sample_episodes = sample_episodes
|
self.sample_episodes = sample_episodes
|
||||||
self.obs_normalization = obs_normalization
|
self.obs_normalization = obs_normalization
|
||||||
|
|
||||||
def setup(self, state=State()):
|
def setup(self, state=State()):
|
||||||
if self.obs_normalization:
|
if self.obs_normalization:
|
||||||
|
|||||||
@@ -1,142 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": "<algorithm.neat.genome.default.DefaultGenome at 0x7f6709872650>"
|
|
||||||
},
|
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import jax, jax.numpy as jnp\n",
|
|
||||||
"from algorithm.neat import *\n",
|
|
||||||
"from utils import Act, Agg\n",
|
|
||||||
"genome = DefaultGenome(\n",
|
|
||||||
" num_inputs=27,\n",
|
|
||||||
" num_outputs=8,\n",
|
|
||||||
" max_nodes=100,\n",
|
|
||||||
" max_conns=200,\n",
|
|
||||||
" node_gene=DefaultNodeGene(\n",
|
|
||||||
" activation_options=(Act.tanh,),\n",
|
|
||||||
" activation_default=Act.tanh,\n",
|
|
||||||
" ),\n",
|
|
||||||
" output_transform=Act.tanh,\n",
|
|
||||||
")\n",
|
|
||||||
"state = genome.setup()\n",
|
|
||||||
"genome"
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"ExecuteTime": {
|
|
||||||
"end_time": "2024-06-09T12:08:22.569123400Z",
|
|
||||||
"start_time": "2024-06-09T12:08:19.331863800Z"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"id": "b2b214a5454c4814"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"state = state.register(data=jnp.zeros((1, 27)))\n",
|
|
||||||
"# try to save the genome object\n",
|
|
||||||
"import pickle\n",
|
|
||||||
"\n",
|
|
||||||
"with open('genome.pkl', 'wb') as f:\n",
|
|
||||||
" genome.__dict__[\"state\"] = state\n",
|
|
||||||
" pickle.dump(genome, f)"
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"ExecuteTime": {
|
|
||||||
"end_time": "2024-06-09T12:09:01.943445900Z",
|
|
||||||
"start_time": "2024-06-09T12:09:01.919416Z"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"id": "28348dfc458e8473"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 13,
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# try to load the genome object\n",
|
|
||||||
"with open('genome.pkl', 'rb') as f:\n",
|
|
||||||
" genome = pickle.load(f)\n",
|
|
||||||
" state = genome.state\n",
|
|
||||||
" del genome.__dict__[\"state\"]"
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"ExecuteTime": {
|
|
||||||
"end_time": "2024-06-09T12:10:28.621539400Z",
|
|
||||||
"start_time": "2024-06-09T12:10:28.612540100Z"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"id": "c91be9fe3d2b5d5d"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 15,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": "State ({'data': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)})"
|
|
||||||
},
|
|
||||||
"execution_count": 15,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"state"
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"ExecuteTime": {
|
|
||||||
"end_time": "2024-06-09T12:10:34.103124Z",
|
|
||||||
"start_time": "2024-06-09T12:10:34.096124300Z"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"id": "6852e4e58b81dd9"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": [],
|
|
||||||
"source": [],
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false
|
|
||||||
},
|
|
||||||
"id": "97a50322218a0427"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 2
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython2",
|
|
||||||
"version": "2.7.6"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -14,7 +14,8 @@ class Act:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tanh(z):
|
def tanh(z):
|
||||||
return jnp.tanh(0.6 * z)
|
z = jnp.clip(0.6*z, -3, 3)
|
||||||
|
return jnp.tanh(z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sin(z):
|
def sin(z):
|
||||||
|
|||||||
@@ -45,11 +45,15 @@ class SympySigmoid(sp.Function):
|
|||||||
class SympyTanh(sp.Function):
|
class SympyTanh(sp.Function):
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, z):
|
def eval(cls, z):
|
||||||
return sp.tanh(0.6 * z)
|
if z.is_Number:
|
||||||
|
z = SympyClip(0.6 * z, -3, 3)
|
||||||
|
return sp.tanh(z)
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def numerical_eval(z, backend=np):
|
def numerical_eval(z, backend=np):
|
||||||
return backend.tanh(0.6 * z)
|
z = backend.clip(0.6*z, -3, 3)
|
||||||
|
return backend.tanh(z)
|
||||||
|
|
||||||
|
|
||||||
class SympySin(sp.Function):
|
class SympySin(sp.Function):
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from . import State
|
from . import State
|
||||||
import pickle
|
import pickle
|
||||||
@@ -18,6 +19,15 @@ class StatefulBaseClass:
|
|||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
pickle.dump(self, f)
|
pickle.dump(self, f)
|
||||||
|
|
||||||
|
def show_config(self):
|
||||||
|
config = {}
|
||||||
|
for key, value in self.__dict__.items():
|
||||||
|
if isinstance(value, StatefulBaseClass):
|
||||||
|
config[str(key)] = value.show_config()
|
||||||
|
else:
|
||||||
|
config[str(key)] = str(value)
|
||||||
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path: str, with_state: bool = False, warning: bool = True):
|
def load(cls, path: str, with_state: bool = False, warning: bool = True):
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
|
|||||||
Reference in New Issue
Block a user