update problem and pipeline

This commit is contained in:
root
2024-07-11 19:34:12 +08:00
parent be6a67d7e2
commit cef27b56bb
14 changed files with 40 additions and 205 deletions

View File

@@ -2,7 +2,7 @@ from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.rl_env import BraxEnv from tensorneat.problem.rl import BraxEnv
from tensorneat.common import Act, Agg from tensorneat.common import Act, Agg
import jax, jax.numpy as jnp import jax, jax.numpy as jnp

View File

@@ -22,7 +22,7 @@ if __name__ == "__main__":
low_bounds=[-1, -1], low_bounds=[-1, -1],
upper_bounds=[1, 1], upper_bounds=[1, 1],
method="sample", method="sample",
num_samples=1000, num_samples=100,
) )
pipeline = Pipeline( pipeline = Pipeline(
@@ -42,7 +42,7 @@ if __name__ == "__main__":
), ),
), ),
problem=custom_problem, problem=custom_problem,
generation_limit=100, generation_limit=50,
fitness_target=-1e-4, fitness_target=-1e-4,
seed=42, seed=42,
) )

View File

@@ -7,8 +7,6 @@ import numpy as np
from tensorneat.algorithm import BaseAlgorithm from tensorneat.algorithm import BaseAlgorithm
from tensorneat.problem import BaseProblem from tensorneat.problem import BaseProblem
from tensorneat.problem.rl_env import RLEnv
from tensorneat.problem.func_fit import FuncFit
from tensorneat.common import State, StatefulBaseClass from tensorneat.common import State, StatefulBaseClass
@@ -20,10 +18,8 @@ class Pipeline(StatefulBaseClass):
seed: int = 42, seed: int = 42,
fitness_target: float = 1, fitness_target: float = 1,
generation_limit: int = 1000, generation_limit: int = 1000,
pre_update: bool = False,
update_batch_size: int = 10000,
save_dir=None,
is_save: bool = False, is_save: bool = False,
save_dir=None,
): ):
assert problem.jitable, "Currently, problem must be jitable" assert problem.jitable, "Currently, problem must be jitable"
@@ -36,7 +32,6 @@ class Pipeline(StatefulBaseClass):
np.random.seed(self.seed) np.random.seed(self.seed)
# TODO: make each algorithm's input_num and output_num
assert ( assert (
algorithm.num_inputs == self.problem.input_shape[-1] algorithm.num_inputs == self.problem.input_shape[-1]
), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}" ), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
@@ -44,22 +39,6 @@ class Pipeline(StatefulBaseClass):
self.best_genome = None self.best_genome = None
self.best_fitness = float("-inf") self.best_fitness = float("-inf")
self.generation_timestamp = None self.generation_timestamp = None
self.pre_update = pre_update
self.update_batch_size = update_batch_size
if pre_update:
if isinstance(problem, RLEnv):
assert problem.record_episode, "record_episode must be True"
self.fetch_data = lambda episode: episode["obs"]
elif isinstance(problem, FuncFit):
assert problem.return_data, "return_data must be True"
self.fetch_data = lambda data: data
else:
raise NotImplementedError
else:
if isinstance(problem, RLEnv):
assert not problem.record_episode, "record_episode must be False"
elif isinstance(problem, FuncFit):
assert not problem.return_data, "return_data must be False"
self.is_save = is_save self.is_save = is_save
if is_save: if is_save:
@@ -79,14 +58,6 @@ class Pipeline(StatefulBaseClass):
print("initializing") print("initializing")
state = state.register(randkey=jax.random.PRNGKey(self.seed)) state = state.register(randkey=jax.random.PRNGKey(self.seed))
if self.pre_update:
# initial with mean = 0 and std = 1
state = state.register(
data=jax.random.normal(
state.randkey, (self.update_batch_size, self.algorithm.num_inputs)
)
)
state = self.algorithm.setup(state) state = self.algorithm.setup(state)
state = self.problem.setup(state) state = self.problem.setup(state)
@@ -112,49 +83,9 @@ class Pipeline(StatefulBaseClass):
state, pop state, pop
) )
if self.pre_update: fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
# update the population state, keys, self.algorithm.forward, pop_transformed
_, pop_transformed = jax.vmap( )
self.algorithm.update_by_batch, in_axes=(None, None, 0)
)(state, state.data, pop_transformed)
# raw_data: (Pop, Batch, num_inputs)
fitnesses, raw_data = jax.vmap(
self.problem.evaluate, in_axes=(None, 0, None, 0)
)(state, keys, self.algorithm.forward, pop_transformed)
# update population
pop_nodes, pop_conns = jax.vmap(self.algorithm.restore, in_axes=(None, 0))(
state, pop_transformed
)
state = state.update(pop_nodes=pop_nodes, pop_conns=pop_conns)
# update data for next generation
data = self.fetch_data(raw_data)
assert (
data.ndim == 3
and data.shape[0] == self.pop_size
and data.shape[2] == self.algorithm.num_inputs
)
# reshape to (Pop * Batch, num_inputs)
data = data.reshape(
data.shape[0] * data.shape[1], self.algorithm.num_inputs
)
# shuffle
data = jax.random.permutation(randkey_, data, axis=0)
# cutoff or expand
if data.shape[0] >= self.update_batch_size:
data = data[: self.update_batch_size] # cutoff
else:
data = (
jnp.full(state.data.shape, jnp.nan).at[: data.shape[0]].set(data)
) # expand
state = state.update(data=data)
else:
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, keys, self.algorithm.forward, pop_transformed
)
# replace nan with -inf # replace nan with -inf
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses) fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)

View File

@@ -1 +1,3 @@
from .base import BaseProblem from .base import BaseProblem
from .rl import *
from .func_fit import *

View File

@@ -1,4 +1,4 @@
from .func_fit import FuncFit
from .xor import XOR from .xor import XOR
from .xor3d import XOR3d from .xor3d import XOR3d
from .custom import CustomFuncFit from .custom import CustomFuncFit
from .func_fit import FuncFit

View File

@@ -1,6 +1,4 @@
from typing import Callable, Union, List, Tuple, Sequence from typing import Callable, Union, List, Tuple
import jax
from jax import vmap, Array, numpy as jnp from jax import vmap, Array, numpy as jnp
import numpy as np import numpy as np

View File

@@ -1,19 +1,18 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from ..base import BaseProblem
from tensorneat.common import State from tensorneat.common import State
from .. import BaseProblem
class FuncFit(BaseProblem): class FuncFit(BaseProblem):
jitable = True jitable = True
def __init__(self, error_method: str = "mse", return_data: bool = False): def __init__(self, error_method: str = "mse"):
super().__init__() super().__init__()
assert error_method in {"mse", "rmse", "mae", "mape"} assert error_method in {"mse", "rmse", "mae", "mape"}
self.error_method = error_method self.error_method = error_method
self.return_data = return_data
def setup(self, state: State = State()): def setup(self, state: State = State()):
return state return state
@@ -39,21 +38,16 @@ class FuncFit(BaseProblem):
else: else:
raise NotImplementedError raise NotImplementedError
if self.return_data: return -loss
return -loss, self.inputs
else:
return -loss
def show(self, state, randkey, act_func, params, *args, **kwargs): def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, None, 0))( predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs state, params, self.inputs
) )
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
if self.return_data: fitness = self.evaluate(state, randkey, act_func, params)
loss, _ = self.evaluate(state, randkey, act_func, params)
else: loss = -fitness
loss = self.evaluate(state, randkey, act_func, params)
loss = -loss
msg = "" msg = ""
for i in range(inputs.shape[0]): for i in range(inputs.shape[0]):

View File

@@ -0,0 +1,3 @@
from .gymnax import GymNaxEnv
from .brax import BraxEnv
from .rl_jit import RLEnv

View File

@@ -6,7 +6,7 @@ from .rl_jit import RLEnv
class GymNaxEnv(RLEnv): class GymNaxEnv(RLEnv):
def __init__(self, env_name, *args, **kwargs): def __init__(self, env_name, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered" assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax."
self.env, self.env_params = gymnax.make(env_name) self.env, self.env_params = gymnax.make(env_name)
def env_step(self, randkey, env_state, action): def env_step(self, randkey, env_state, action):
@@ -24,4 +24,4 @@ class GymNaxEnv(RLEnv):
return self.env.action_space(self.env_params).shape return self.env.action_space(self.env_params).shape
def show(self, state, randkey, act_func, params, *args, **kwargs): def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).") raise NotImplementedError

View File

@@ -1,11 +1,11 @@
from typing import Callable from typing import Callable
import jax import jax
import jax.numpy as jnp from jax import vmap, numpy as jnp
import numpy as np import numpy as np
from ..base import BaseProblem
from tensorneat.common import State from tensorneat.common import State
from .. import BaseProblem
class RLEnv(BaseProblem): class RLEnv(BaseProblem):
@@ -15,7 +15,6 @@ class RLEnv(BaseProblem):
self, self,
max_step=1000, max_step=1000,
repeat_times=1, repeat_times=1,
record_episode=False,
action_policy: Callable = None, action_policy: Callable = None,
obs_normalization: bool = False, obs_normalization: bool = False,
sample_policy: Callable = None, sample_policy: Callable = None,
@@ -34,7 +33,6 @@ class RLEnv(BaseProblem):
super().__init__() super().__init__()
self.max_step = max_step self.max_step = max_step
self.record_episode = record_episode
self.repeat_times = repeat_times self.repeat_times = repeat_times
self.action_policy = action_policy self.action_policy = action_policy
@@ -57,11 +55,11 @@ class RLEnv(BaseProblem):
) # ignore act_func ) # ignore act_func
def sample(rk): def sample(rk):
return self.evaluate_once( return self._evaluate_once(
state, rk, dummy_act_func, None, dummy_sample_func, True state, rk, dummy_act_func, None, dummy_sample_func, True
) )
rewards, episodes = jax.jit(jax.vmap(sample))(keys) rewards, episodes = jax.jit(vmap(sample))(keys)
obs = jax.device_get(episodes["obs"]) # shape: (sample_episodes, max_step, *input_shape) obs = jax.device_get(episodes["obs"]) # shape: (sample_episodes, max_step, *input_shape)
obs = obs.reshape( obs = obs.reshape(
@@ -88,47 +86,21 @@ class RLEnv(BaseProblem):
def evaluate(self, state: State, randkey, act_func: Callable, params): def evaluate(self, state: State, randkey, act_func: Callable, params):
keys = jax.random.split(randkey, self.repeat_times) keys = jax.random.split(randkey, self.repeat_times)
if self.record_episode: rewards = vmap(
rewards, episodes = jax.vmap( self._evaluate_once, in_axes=(None, 0, None, None, None, None, None)
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None) )(
)( state,
state, keys,
keys, act_func,
act_func, params,
params, self.action_policy,
self.action_policy, False,
True, self.obs_normalization,
self.obs_normalization, )
)
episodes["obs"] = episodes["obs"].reshape( return rewards.mean()
self.max_step * self.repeat_times, *self.input_shape
)
episodes["action"] = episodes["action"].reshape(
self.max_step * self.repeat_times, *self.output_shape
)
episodes["reward"] = episodes["reward"].reshape(
self.max_step * self.repeat_times,
)
return rewards.mean(), episodes def _evaluate_once(
else:
rewards = jax.vmap(
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
)(
state,
keys,
act_func,
params,
self.action_policy,
False,
self.obs_normalization,
)
return rewards.mean()
def evaluate_once(
self, self,
state, state,
randkey, randkey,

View File

@@ -1,3 +0,0 @@
from .gymnax_env import GymNaxEnv
from .brax_env import BraxEnv
from .rl_jit import RLEnv

View File

@@ -1,62 +0,0 @@
import jax, jax.numpy as jnp
import jumanji
from tensorneat.common import State
from ..rl_jit import RLEnv
class Jumanji_2048(RLEnv):
def __init__(
self, guarantee_invalid_action=True, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.guarantee_invalid_action = guarantee_invalid_action
self.env = jumanji.make("Game2048-v1")
def env_step(self, randkey, env_state, action):
action_mask = env_state["action_mask"]
###################################################################
# action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
###################################################################
if self.guarantee_invalid_action:
score_with_mask = jnp.where(action_mask, action, -jnp.inf)
action = jnp.argmax(score_with_mask)
else:
action = jnp.argmax(action)
done = ~action_mask[action]
env_state, timestep = self.env.step(env_state, action)
reward = timestep["reward"]
board, action_mask = timestep["observation"]
extras = timestep["extras"]
done = done | (jnp.sum(action_mask) == 0) # all actions of invalid
return board.reshape(-1), env_state, reward, done, extras
def env_reset(self, randkey):
env_state, timestep = self.env.reset(randkey)
step_type = timestep["step_type"]
reward = timestep["reward"]
discount = timestep["discount"]
observation = timestep["observation"]
extras = timestep["extras"]
board, action_mask = observation
return board.reshape(-1), env_state
@property
def input_shape(self):
return (16,)
@property
def output_shape(self):
return (4,)
def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")