update problem and pipeline

This commit is contained in:
root
2024-07-11 19:34:12 +08:00
parent be6a67d7e2
commit cef27b56bb
14 changed files with 40 additions and 205 deletions

View File

@@ -2,7 +2,7 @@ from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.rl_env import BraxEnv
from tensorneat.problem.rl import BraxEnv
from tensorneat.common import Act, Agg
import jax, jax.numpy as jnp

View File

@@ -22,7 +22,7 @@ if __name__ == "__main__":
low_bounds=[-1, -1],
upper_bounds=[1, 1],
method="sample",
num_samples=1000,
num_samples=100,
)
pipeline = Pipeline(
@@ -42,7 +42,7 @@ if __name__ == "__main__":
),
),
problem=custom_problem,
generation_limit=100,
generation_limit=50,
fitness_target=-1e-4,
seed=42,
)

View File

@@ -7,8 +7,6 @@ import numpy as np
from tensorneat.algorithm import BaseAlgorithm
from tensorneat.problem import BaseProblem
from tensorneat.problem.rl_env import RLEnv
from tensorneat.problem.func_fit import FuncFit
from tensorneat.common import State, StatefulBaseClass
@@ -20,10 +18,8 @@ class Pipeline(StatefulBaseClass):
seed: int = 42,
fitness_target: float = 1,
generation_limit: int = 1000,
pre_update: bool = False,
update_batch_size: int = 10000,
save_dir=None,
is_save: bool = False,
save_dir=None,
):
assert problem.jitable, "Currently, problem must be jitable"
@@ -36,7 +32,6 @@ class Pipeline(StatefulBaseClass):
np.random.seed(self.seed)
# TODO: make each algorithm's input_num and output_num
assert (
algorithm.num_inputs == self.problem.input_shape[-1]
), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
@@ -44,22 +39,6 @@ class Pipeline(StatefulBaseClass):
self.best_genome = None
self.best_fitness = float("-inf")
self.generation_timestamp = None
self.pre_update = pre_update
self.update_batch_size = update_batch_size
if pre_update:
if isinstance(problem, RLEnv):
assert problem.record_episode, "record_episode must be True"
self.fetch_data = lambda episode: episode["obs"]
elif isinstance(problem, FuncFit):
assert problem.return_data, "return_data must be True"
self.fetch_data = lambda data: data
else:
raise NotImplementedError
else:
if isinstance(problem, RLEnv):
assert not problem.record_episode, "record_episode must be False"
elif isinstance(problem, FuncFit):
assert not problem.return_data, "return_data must be False"
self.is_save = is_save
if is_save:
@@ -79,14 +58,6 @@ class Pipeline(StatefulBaseClass):
print("initializing")
state = state.register(randkey=jax.random.PRNGKey(self.seed))
if self.pre_update:
# initial with mean = 0 and std = 1
state = state.register(
data=jax.random.normal(
state.randkey, (self.update_batch_size, self.algorithm.num_inputs)
)
)
state = self.algorithm.setup(state)
state = self.problem.setup(state)
@@ -112,49 +83,9 @@ class Pipeline(StatefulBaseClass):
state, pop
)
if self.pre_update:
# update the population
_, pop_transformed = jax.vmap(
self.algorithm.update_by_batch, in_axes=(None, None, 0)
)(state, state.data, pop_transformed)
# raw_data: (Pop, Batch, num_inputs)
fitnesses, raw_data = jax.vmap(
self.problem.evaluate, in_axes=(None, 0, None, 0)
)(state, keys, self.algorithm.forward, pop_transformed)
# update population
pop_nodes, pop_conns = jax.vmap(self.algorithm.restore, in_axes=(None, 0))(
state, pop_transformed
)
state = state.update(pop_nodes=pop_nodes, pop_conns=pop_conns)
# update data for next generation
data = self.fetch_data(raw_data)
assert (
data.ndim == 3
and data.shape[0] == self.pop_size
and data.shape[2] == self.algorithm.num_inputs
)
# reshape to (Pop * Batch, num_inputs)
data = data.reshape(
data.shape[0] * data.shape[1], self.algorithm.num_inputs
)
# shuffle
data = jax.random.permutation(randkey_, data, axis=0)
# cutoff or expand
if data.shape[0] >= self.update_batch_size:
data = data[: self.update_batch_size] # cutoff
else:
data = (
jnp.full(state.data.shape, jnp.nan).at[: data.shape[0]].set(data)
) # expand
state = state.update(data=data)
else:
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, keys, self.algorithm.forward, pop_transformed
)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, keys, self.algorithm.forward, pop_transformed
)
# replace nan with -inf
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)

View File

@@ -1 +1,3 @@
from .base import BaseProblem
from .rl import *
from .func_fit import *

View File

@@ -1,4 +1,4 @@
from .func_fit import FuncFit
from .xor import XOR
from .xor3d import XOR3d
from .custom import CustomFuncFit
from .custom import CustomFuncFit
from .func_fit import FuncFit

View File

@@ -1,6 +1,4 @@
from typing import Callable, Union, List, Tuple, Sequence
import jax
from typing import Callable, Union, List, Tuple
from jax import vmap, Array, numpy as jnp
import numpy as np

View File

@@ -1,19 +1,18 @@
import jax
import jax.numpy as jnp
from ..base import BaseProblem
from tensorneat.common import State
from .. import BaseProblem
class FuncFit(BaseProblem):
jitable = True
def __init__(self, error_method: str = "mse", return_data: bool = False):
def __init__(self, error_method: str = "mse"):
super().__init__()
assert error_method in {"mse", "rmse", "mae", "mape"}
self.error_method = error_method
self.return_data = return_data
def setup(self, state: State = State()):
return state
@@ -39,21 +38,16 @@ class FuncFit(BaseProblem):
else:
raise NotImplementedError
if self.return_data:
return -loss, self.inputs
else:
return -loss
return -loss
def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
if self.return_data:
loss, _ = self.evaluate(state, randkey, act_func, params)
else:
loss = self.evaluate(state, randkey, act_func, params)
loss = -loss
fitness = self.evaluate(state, randkey, act_func, params)
loss = -fitness
msg = ""
for i in range(inputs.shape[0]):

View File

@@ -0,0 +1,3 @@
from .gymnax import GymNaxEnv
from .brax import BraxEnv
from .rl_jit import RLEnv

View File

@@ -6,7 +6,7 @@ from .rl_jit import RLEnv
class GymNaxEnv(RLEnv):
def __init__(self, env_name, *args, **kwargs):
super().__init__(*args, **kwargs)
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax."
self.env, self.env_params = gymnax.make(env_name)
def env_step(self, randkey, env_state, action):
@@ -24,4 +24,4 @@ class GymNaxEnv(RLEnv):
return self.env.action_space(self.env_params).shape
def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
raise NotImplementedError

View File

@@ -1,11 +1,11 @@
from typing import Callable
import jax
import jax.numpy as jnp
from jax import vmap, numpy as jnp
import numpy as np
from ..base import BaseProblem
from tensorneat.common import State
from .. import BaseProblem
class RLEnv(BaseProblem):
@@ -15,7 +15,6 @@ class RLEnv(BaseProblem):
self,
max_step=1000,
repeat_times=1,
record_episode=False,
action_policy: Callable = None,
obs_normalization: bool = False,
sample_policy: Callable = None,
@@ -34,7 +33,6 @@ class RLEnv(BaseProblem):
super().__init__()
self.max_step = max_step
self.record_episode = record_episode
self.repeat_times = repeat_times
self.action_policy = action_policy
@@ -57,11 +55,11 @@ class RLEnv(BaseProblem):
) # ignore act_func
def sample(rk):
return self.evaluate_once(
return self._evaluate_once(
state, rk, dummy_act_func, None, dummy_sample_func, True
)
rewards, episodes = jax.jit(jax.vmap(sample))(keys)
rewards, episodes = jax.jit(vmap(sample))(keys)
obs = jax.device_get(episodes["obs"]) # shape: (sample_episodes, max_step, *input_shape)
obs = obs.reshape(
@@ -88,47 +86,21 @@ class RLEnv(BaseProblem):
def evaluate(self, state: State, randkey, act_func: Callable, params):
keys = jax.random.split(randkey, self.repeat_times)
if self.record_episode:
rewards, episodes = jax.vmap(
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
)(
state,
keys,
act_func,
params,
self.action_policy,
True,
self.obs_normalization,
)
rewards = vmap(
self._evaluate_once, in_axes=(None, 0, None, None, None, None, None)
)(
state,
keys,
act_func,
params,
self.action_policy,
False,
self.obs_normalization,
)
episodes["obs"] = episodes["obs"].reshape(
self.max_step * self.repeat_times, *self.input_shape
)
episodes["action"] = episodes["action"].reshape(
self.max_step * self.repeat_times, *self.output_shape
)
episodes["reward"] = episodes["reward"].reshape(
self.max_step * self.repeat_times,
)
return rewards.mean()
return rewards.mean(), episodes
else:
rewards = jax.vmap(
self.evaluate_once, in_axes=(None, 0, None, None, None, None, None)
)(
state,
keys,
act_func,
params,
self.action_policy,
False,
self.obs_normalization,
)
return rewards.mean()
def evaluate_once(
def _evaluate_once(
self,
state,
randkey,

View File

@@ -1,3 +0,0 @@
from .gymnax_env import GymNaxEnv
from .brax_env import BraxEnv
from .rl_jit import RLEnv

View File

@@ -1,62 +0,0 @@
import jax, jax.numpy as jnp
import jumanji
from tensorneat.common import State
from ..rl_jit import RLEnv
class Jumanji_2048(RLEnv):
def __init__(
self, guarantee_invalid_action=True, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.guarantee_invalid_action = guarantee_invalid_action
self.env = jumanji.make("Game2048-v1")
def env_step(self, randkey, env_state, action):
action_mask = env_state["action_mask"]
###################################################################
# action = jnp.concatenate([action, jnp.full((4 - action.shape[0], ), -99999)])
###################################################################
if self.guarantee_invalid_action:
score_with_mask = jnp.where(action_mask, action, -jnp.inf)
action = jnp.argmax(score_with_mask)
else:
action = jnp.argmax(action)
done = ~action_mask[action]
env_state, timestep = self.env.step(env_state, action)
reward = timestep["reward"]
board, action_mask = timestep["observation"]
extras = timestep["extras"]
done = done | (jnp.sum(action_mask) == 0) # all actions are invalid
return board.reshape(-1), env_state, reward, done, extras
def env_reset(self, randkey):
env_state, timestep = self.env.reset(randkey)
step_type = timestep["step_type"]
reward = timestep["reward"]
discount = timestep["discount"]
observation = timestep["observation"]
extras = timestep["extras"]
board, action_mask = observation
return board.reshape(-1), env_state
@property
def input_shape(self):
return (16,)
@property
def output_shape(self):
return (4,)
def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")