update problem and pipeline
This commit is contained in:
@@ -2,7 +2,7 @@ from tensorneat.pipeline import Pipeline
|
|||||||
from tensorneat.algorithm.neat import NEAT
|
from tensorneat.algorithm.neat import NEAT
|
||||||
from tensorneat.genome import DefaultGenome, BiasNode
|
from tensorneat.genome import DefaultGenome, BiasNode
|
||||||
|
|
||||||
from tensorneat.problem.rl_env import BraxEnv
|
from tensorneat.problem.rl import BraxEnv
|
||||||
from tensorneat.common import Act, Agg
|
from tensorneat.common import Act, Agg
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ if __name__ == "__main__":
|
|||||||
low_bounds=[-1, -1],
|
low_bounds=[-1, -1],
|
||||||
upper_bounds=[1, 1],
|
upper_bounds=[1, 1],
|
||||||
method="sample",
|
method="sample",
|
||||||
num_samples=1000,
|
num_samples=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
@@ -42,7 +42,7 @@ if __name__ == "__main__":
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
problem=custom_problem,
|
problem=custom_problem,
|
||||||
generation_limit=100,
|
generation_limit=50,
|
||||||
fitness_target=-1e-4,
|
fitness_target=-1e-4,
|
||||||
seed=42,
|
seed=42,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorneat.algorithm import BaseAlgorithm
|
from tensorneat.algorithm import BaseAlgorithm
|
||||||
from tensorneat.problem import BaseProblem
|
from tensorneat.problem import BaseProblem
|
||||||
from tensorneat.problem.rl_env import RLEnv
|
|
||||||
from tensorneat.problem.func_fit import FuncFit
|
|
||||||
from tensorneat.common import State, StatefulBaseClass
|
from tensorneat.common import State, StatefulBaseClass
|
||||||
|
|
||||||
|
|
||||||
@@ -20,10 +18,8 @@ class Pipeline(StatefulBaseClass):
|
|||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
fitness_target: float = 1,
|
fitness_target: float = 1,
|
||||||
generation_limit: int = 1000,
|
generation_limit: int = 1000,
|
||||||
pre_update: bool = False,
|
|
||||||
update_batch_size: int = 10000,
|
|
||||||
save_dir=None,
|
|
||||||
is_save: bool = False,
|
is_save: bool = False,
|
||||||
|
save_dir=None,
|
||||||
):
|
):
|
||||||
assert problem.jitable, "Currently, problem must be jitable"
|
assert problem.jitable, "Currently, problem must be jitable"
|
||||||
|
|
||||||
@@ -36,7 +32,6 @@ class Pipeline(StatefulBaseClass):
|
|||||||
|
|
||||||
np.random.seed(self.seed)
|
np.random.seed(self.seed)
|
||||||
|
|
||||||
# TODO: make each algorithm's input_num and output_num
|
|
||||||
assert (
|
assert (
|
||||||
algorithm.num_inputs == self.problem.input_shape[-1]
|
algorithm.num_inputs == self.problem.input_shape[-1]
|
||||||
), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
|
), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
|
||||||
@@ -44,22 +39,6 @@ class Pipeline(StatefulBaseClass):
|
|||||||
self.best_genome = None
|
self.best_genome = None
|
||||||
self.best_fitness = float("-inf")
|
self.best_fitness = float("-inf")
|
||||||
self.generation_timestamp = None
|
self.generation_timestamp = None
|
||||||
self.pre_update = pre_update
|
|
||||||
self.update_batch_size = update_batch_size
|
|
||||||
if pre_update:
|
|
||||||
if isinstance(problem, RLEnv):
|
|
||||||
assert problem.record_episode, "record_episode must be True"
|
|
||||||
self.fetch_data = lambda episode: episode["obs"]
|
|
||||||
elif isinstance(problem, FuncFit):
|
|
||||||
assert problem.return_data, "return_data must be True"
|
|
||||||
self.fetch_data = lambda data: data
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
|
||||||
if isinstance(problem, RLEnv):
|
|
||||||
assert not problem.record_episode, "record_episode must be False"
|
|
||||||
elif isinstance(problem, FuncFit):
|
|
||||||
assert not problem.return_data, "return_data must be False"
|
|
||||||
self.is_save = is_save
|
self.is_save = is_save
|
||||||
|
|
||||||
if is_save:
|
if is_save:
|
||||||
@@ -79,14 +58,6 @@ class Pipeline(StatefulBaseClass):
|
|||||||
print("initializing")
|
print("initializing")
|
||||||
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
||||||
|
|
||||||
if self.pre_update:
|
|
||||||
# initial with mean = 0 and std = 1
|
|
||||||
state = state.register(
|
|
||||||
data=jax.random.normal(
|
|
||||||
state.randkey, (self.update_batch_size, self.algorithm.num_inputs)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
state = self.algorithm.setup(state)
|
state = self.algorithm.setup(state)
|
||||||
state = self.problem.setup(state)
|
state = self.problem.setup(state)
|
||||||
|
|
||||||
@@ -112,49 +83,9 @@ class Pipeline(StatefulBaseClass):
|
|||||||
state, pop
|
state, pop
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.pre_update:
|
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||||
# update the population
|
state, keys, self.algorithm.forward, pop_transformed
|
||||||
_, pop_transformed = jax.vmap(
|
)
|
||||||
self.algorithm.update_by_batch, in_axes=(None, None, 0)
|
|
||||||
)(state, state.data, pop_transformed)
|
|
||||||
|
|
||||||
# raw_data: (Pop, Batch, num_inputs)
|
|
||||||
fitnesses, raw_data = jax.vmap(
|
|
||||||
self.problem.evaluate, in_axes=(None, 0, None, 0)
|
|
||||||
)(state, keys, self.algorithm.forward, pop_transformed)
|
|
||||||
|
|
||||||
# update population
|
|
||||||
pop_nodes, pop_conns = jax.vmap(self.algorithm.restore, in_axes=(None, 0))(
|
|
||||||
state, pop_transformed
|
|
||||||
)
|
|
||||||
state = state.update(pop_nodes=pop_nodes, pop_conns=pop_conns)
|
|
||||||
|
|
||||||
# update data for next generation
|
|
||||||
data = self.fetch_data(raw_data)
|
|
||||||
assert (
|
|
||||||
data.ndim == 3
|
|
||||||
and data.shape[0] == self.pop_size
|
|
||||||
and data.shape[2] == self.algorithm.num_inputs
|
|
||||||
)
|
|
||||||
# reshape to (Pop * Batch, num_inputs)
|
|
||||||
data = data.reshape(
|
|
||||||
data.shape[0] * data.shape[1], self.algorithm.num_inputs
|
|
||||||
)
|
|
||||||
# shuffle
|
|
||||||
data = jax.random.permutation(randkey_, data, axis=0)
|
|
||||||
# cutoff or expand
|
|
||||||
if data.shape[0] >= self.update_batch_size:
|
|
||||||
data = data[: self.update_batch_size] # cutoff
|
|
||||||
else:
|
|
||||||
data = (
|
|
||||||
jnp.full(state.data.shape, jnp.nan).at[: data.shape[0]].set(data)
|
|
||||||
) # expand
|
|
||||||
state = state.update(data=data)
|
|
||||||
|
|
||||||
else:
|
|
||||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
|
||||||
state, keys, self.algorithm.forward, pop_transformed
|
|
||||||
)
|
|
||||||
|
|
||||||
# replace nan with -inf
|
# replace nan with -inf
|
||||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
||||||
|
|||||||
@@ -1 +1,3 @@
|
|||||||
from .base import BaseProblem
|
from .base import BaseProblem
|
||||||
|
from .rl import *
|
||||||
|
from .func_fit import *
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .func_fit import FuncFit
|
|
||||||
from .xor import XOR
|
from .xor import XOR
|
||||||
from .xor3d import XOR3d
|
from .xor3d import XOR3d
|
||||||
from .custom import CustomFuncFit
|
from .custom import CustomFuncFit
|
||||||
|
from .func_fit import FuncFit
|
||||||
@@ -1,6 +1,4 @@
|
|||||||
from typing import Callable, Union, List, Tuple, Sequence
|
from typing import Callable, Union, List, Tuple
|
||||||
|
|
||||||
import jax
|
|
||||||
from jax import vmap, Array, numpy as jnp
|
from jax import vmap, Array, numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,18 @@
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from ..base import BaseProblem
|
||||||
from tensorneat.common import State
|
from tensorneat.common import State
|
||||||
from .. import BaseProblem
|
|
||||||
|
|
||||||
|
|
||||||
class FuncFit(BaseProblem):
|
class FuncFit(BaseProblem):
|
||||||
jitable = True
|
jitable = True
|
||||||
|
|
||||||
def __init__(self, error_method: str = "mse", return_data: bool = False):
|
def __init__(self, error_method: str = "mse"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert error_method in {"mse", "rmse", "mae", "mape"}
|
assert error_method in {"mse", "rmse", "mae", "mape"}
|
||||||
self.error_method = error_method
|
self.error_method = error_method
|
||||||
self.return_data = return_data
|
|
||||||
|
|
||||||
def setup(self, state: State = State()):
|
def setup(self, state: State = State()):
|
||||||
return state
|
return state
|
||||||
@@ -39,21 +38,16 @@ class FuncFit(BaseProblem):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
if self.return_data:
|
return -loss
|
||||||
return -loss, self.inputs
|
|
||||||
else:
|
|
||||||
return -loss
|
|
||||||
|
|
||||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||||
state, params, self.inputs
|
state, params, self.inputs
|
||||||
)
|
)
|
||||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||||
if self.return_data:
|
fitness = self.evaluate(state, randkey, act_func, params)
|
||||||
loss, _ = self.evaluate(state, randkey, act_func, params)
|
|
||||||
else:
|
loss = -fitness
|
||||||
loss = self.evaluate(state, randkey, act_func, params)
|
|
||||||
loss = -loss
|
|
||||||
|
|
||||||
msg = ""
|
msg = ""
|
||||||
for i in range(inputs.shape[0]):
|
for i in range(inputs.shape[0]):
|
||||||
|
|||||||
3
tensorneat/problem/rl/__init__.py
Normal file
3
tensorneat/problem/rl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .gymnax import GymNaxEnv
|
||||||
|
from .brax import BraxEnv
|
||||||
|
from .rl_jit import RLEnv
|
||||||
@@ -6,7 +6,7 @@ from .rl_jit import RLEnv
|
|||||||
class GymNaxEnv(RLEnv):
|
class GymNaxEnv(RLEnv):
|
||||||
def __init__(self, env_name, *args, **kwargs):
|
def __init__(self, env_name, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
|
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax."
|
||||||
self.env, self.env_params = gymnax.make(env_name)
|
self.env, self.env_params = gymnax.make(env_name)
|
||||||
|
|
||||||
def env_step(self, randkey, env_state, action):
|
def env_step(self, randkey, env_state, action):
|
||||||
@@ -24,4 +24,4 @@ class GymNaxEnv(RLEnv):
|
|||||||
return self.env.action_space(self.env_params).shape
|
return self.env.action_space(self.env_params).shape
|
||||||
|
|
||||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||||
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
|
raise NotImplementedError
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
from jax import vmap, numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from ..base import BaseProblem
|
||||||
from tensorneat.common import State
|
from tensorneat.common import State
|
||||||
from .. import BaseProblem
|
|
||||||
|
|
||||||
|
|
||||||
class RLEnv(BaseProblem):
|
class RLEnv(BaseProblem):
|
||||||
@@ -15,7 +15,6 @@ class RLEnv(BaseProblem):
|
|||||||
self,
|
self,
|
||||||
max_step=1000,
|
max_step=1000,
|
||||||
repeat_times=1,
|
repeat_times=1,
|
||||||
record_episode=False,
|
|
||||||
action_policy: Callable = None,
|
action_policy: Callable = None,
|
||||||
obs_normalization: bool = False,
|
obs_normalization: bool = False,
|
||||||
sample_policy: Callable = None,
|
sample_policy: Callable = None,
|
||||||
@@ -34,7 +33,6 @@ class RLEnv(BaseProblem):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_step = max_step
|
self.max_step = max_step
|
||||||
self.record_episode = record_episode
|
|
||||||
self.repeat_times = repeat_times
|
self.repeat_times = repeat_times
|
||||||
self.action_policy = action_policy
|
self.action_policy = action_policy
|
||||||
|
|
||||||
@@ -57,11 +55,11 @@ class RLEnv(BaseProblem):
|
|||||||
) # ignore act_func
|
) # ignore act_func
|
||||||
|
|
||||||
def sample(rk):
|
def sample(rk):
|
||||||
return self.evaluate_once(
|
return self._evaluate_once(
|
||||||
state, rk, dummy_act_func, None, dummy_sample_func, True
|
state, rk, dummy_act_func, None, dummy_sample_func, True
|
||||||
)
|
)
|
||||||
|
|
||||||
rewards, episodes = jax.jit(jax.vmap(sample))(keys)
|
rewards, episodes = jax.jit(vmap(sample))(keys)
|
||||||
|
|
||||||
obs = jax.device_get(episodes["obs"]) # shape: (sample_episodes, max_step, *input_shape)
|
obs = jax.device_get(episodes["obs"]) # shape: (sample_episodes, max_step, *input_shape)
|
||||||
obs = obs.reshape(
|
obs = obs.reshape(
|
||||||
@@ -88,47 +86,21 @@ class RLEnv(BaseProblem):
|
|||||||
|
|
||||||
def evaluate(self, state: State, randkey, act_func: Callable, params):
|
def evaluate(self, state: State, randkey, act_func: Callable, params):
|
||||||
keys = jax.random.split(randkey, self.repeat_times)
|
keys = jax.random.split(randkey, self.repeat_times)
|
||||||
if self.record_episode:
|
rewards = vmap(
|
||||||
rewards, episodes = jax.vmap(
|
self._evaluate_once, in_axes=(None, 0, None, None, None, None, None)
|
||||||
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
|
)(
|
||||||
)(
|
state,
|
||||||
state,
|
keys,
|
||||||
keys,
|
act_func,
|
||||||
act_func,
|
params,
|
||||||
params,
|
self.action_policy,
|
||||||
self.action_policy,
|
False,
|
||||||
True,
|
self.obs_normalization,
|
||||||
self.obs_normalization,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
episodes["obs"] = episodes["obs"].reshape(
|
return rewards.mean()
|
||||||
self.max_step * self.repeat_times, *self.input_shape
|
|
||||||
)
|
|
||||||
episodes["action"] = episodes["action"].reshape(
|
|
||||||
self.max_step * self.repeat_times, *self.output_shape
|
|
||||||
)
|
|
||||||
episodes["reward"] = episodes["reward"].reshape(
|
|
||||||
self.max_step * self.repeat_times,
|
|
||||||
)
|
|
||||||
|
|
||||||
return rewards.mean(), episodes
|
def _evaluate_once(
|
||||||
|
|
||||||
else:
|
|
||||||
rewards = jax.vmap(
|
|
||||||
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
|
|
||||||
)(
|
|
||||||
state,
|
|
||||||
keys,
|
|
||||||
act_func,
|
|
||||||
params,
|
|
||||||
self.action_policy,
|
|
||||||
False,
|
|
||||||
self.obs_normalization,
|
|
||||||
)
|
|
||||||
|
|
||||||
return rewards.mean()
|
|
||||||
|
|
||||||
def evaluate_once(
|
|
||||||
self,
|
self,
|
||||||
state,
|
state,
|
||||||
randkey,
|
randkey,
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
from .gymnax_env import GymNaxEnv
|
|
||||||
from .brax_env import BraxEnv
|
|
||||||
from .rl_jit import RLEnv
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
import jax, jax.numpy as jnp
|
|
||||||
import jumanji
|
|
||||||
|
|
||||||
from tensorneat.common import State
|
|
||||||
from ..rl_jit import RLEnv
|
|
||||||
|
|
||||||
|
|
||||||
class Jumanji_2048(RLEnv):
|
|
||||||
def __init__(
|
|
||||||
self, guarantee_invalid_action=True, *args, **kwargs
|
|
||||||
):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.guarantee_invalid_action = guarantee_invalid_action
|
|
||||||
self.env = jumanji.make("Game2048-v1")
|
|
||||||
|
|
||||||
def env_step(self, randkey, env_state, action):
|
|
||||||
action_mask = env_state["action_mask"]
|
|
||||||
|
|
||||||
###################################################################
|
|
||||||
|
|
||||||
# action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
|
|
||||||
|
|
||||||
###################################################################
|
|
||||||
if self.guarantee_invalid_action:
|
|
||||||
score_with_mask = jnp.where(action_mask, action, -jnp.inf)
|
|
||||||
action = jnp.argmax(score_with_mask)
|
|
||||||
else:
|
|
||||||
action = jnp.argmax(action)
|
|
||||||
|
|
||||||
done = ~action_mask[action]
|
|
||||||
|
|
||||||
env_state, timestep = self.env.step(env_state, action)
|
|
||||||
reward = timestep["reward"]
|
|
||||||
|
|
||||||
board, action_mask = timestep["observation"]
|
|
||||||
extras = timestep["extras"]
|
|
||||||
|
|
||||||
done = done | (jnp.sum(action_mask) == 0) # all actions of invalid
|
|
||||||
|
|
||||||
return board.reshape(-1), env_state, reward, done, extras
|
|
||||||
|
|
||||||
def env_reset(self, randkey):
|
|
||||||
env_state, timestep = self.env.reset(randkey)
|
|
||||||
step_type = timestep["step_type"]
|
|
||||||
reward = timestep["reward"]
|
|
||||||
discount = timestep["discount"]
|
|
||||||
observation = timestep["observation"]
|
|
||||||
extras = timestep["extras"]
|
|
||||||
board, action_mask = observation
|
|
||||||
|
|
||||||
return board.reshape(-1), env_state
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_shape(self):
|
|
||||||
return (16,)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_shape(self):
|
|
||||||
return (4,)
|
|
||||||
|
|
||||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
|
||||||
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
|
|
||||||
Reference in New Issue
Block a user