add jumanji env;

add repeat times for rl_env
This commit is contained in:
wls2002
2024-06-05 14:24:17 +08:00
parent edfb0596e7
commit 10ec1c2df9
10 changed files with 1615 additions and 7 deletions

View File

@@ -0,0 +1,193 @@
from typing import Tuple
import jax, jax.numpy as jnp
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
from . import BaseNodeGene
class MinMaxNode(BaseNodeGene):
"""
Node with normalization before activation.
"""
# alpha and beta is used for normalization, just like BatchNorm
# norm: z = act(agg(inputs) + bias)
# z = (z - min) / (max - min) * (max_out - min_out) + min_out
custom_attrs = ["bias", "aggregation", "activation", "min", "max"]
eps = 1e-6
def __init__(
self,
bias_init_mean: float = 0.0,
bias_init_std: float = 1.0,
bias_mutate_power: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
aggregation_default: callable = Agg.sum,
aggregation_options: Tuple = (Agg.sum,),
aggregation_replace_rate: float = 0.1,
activation_default: callable = Act.sigmoid,
activation_options: Tuple = (Act.sigmoid,),
activation_replace_rate: float = 0.1,
output_range: Tuple[float, float] = (-1, 1),
update_hidden_node: bool = False,
):
super().__init__()
self.bias_init_mean = bias_init_mean
self.bias_init_std = bias_init_std
self.bias_mutate_power = bias_mutate_power
self.bias_mutate_rate = bias_mutate_rate
self.bias_replace_rate = bias_replace_rate
self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
self.activation_default = activation_options.index(activation_default)
self.activation_options = activation_options
self.activation_indices = jnp.arange(len(activation_options))
self.activation_replace_rate = activation_replace_rate
self.output_range = output_range
assert (
len(self.output_range) == 2 and self.output_range[0] < self.output_range[1]
)
self.update_hidden_node = update_hidden_node
def new_identity_attrs(self, state):
return jnp.array(
[0, self.aggregation_default, -1, 0, 1]
) # activation=-1 means Act.identity; min=0, max=1 will do not influence
def new_random_attrs(self, state, randkey):
k1, k2, k3, k4, k5 = jax.random.split(randkey, num=5)
bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
agg = jax.random.randint(k2, (), 0, len(self.aggregation_options))
act = jax.random.randint(k3, (), 0, len(self.activation_options))
return jnp.array([bias, agg, act, 0, 1])
def mutate(self, state, randkey, attrs):
k1, k2, k3, k4, k5 = jax.random.split(randkey, num=5)
bias, act, agg, min_, max_ = attrs
bias = mutate_float(
k1,
bias,
self.bias_init_mean,
self.bias_init_std,
self.bias_mutate_power,
self.bias_mutate_rate,
self.bias_replace_rate,
)
agg = mutate_int(
k2, agg, self.aggregation_indices, self.aggregation_replace_rate
)
act = mutate_int(k3, act, self.activation_indices, self.activation_replace_rate)
return jnp.array([bias, agg, act, min_, max_])
def distance(self, state, attrs1, attrs2):
bias1, agg1, act1, min1, max1 = attrs1
bias2, agg2, act2, min1, max1 = attrs2
return (
jnp.abs(bias1 - bias2) # bias
+ (agg1 != agg2) # aggregation
+ (act1 != act2) # activation
)
def forward(self, state, attrs, inputs, is_output_node=False):
"""
post_act = (agg(inputs) + bias - mean) / std * alpha + beta
"""
bias, agg, act, min_, max_ = attrs
z = agg_func(agg, inputs, self.aggregation_options)
z = bias + z
# the last output node should not be activated
z = jax.lax.cond(
is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
)
if self.update_hidden_node:
z = (z - min_) / (max_ - min_) # transform to 01
z = (
z * (self.output_range[1] - self.output_range[0]) + self.output_range[0]
) # transform to output_range
return z
def input_transform(self, state, attrs, inputs):
"""
make transform in the input node.
the normalization also need be done in the first node.
"""
bias, agg, act, min_, max_ = attrs
inputs = (inputs - min_) / (max_ - min_) # transform to 01
inputs = (
inputs * (self.output_range[1] - self.output_range[0])
+ self.output_range[0]
)
return inputs
def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
bias, agg, act, min_, max_ = attrs
batch_z = jax.vmap(agg_func, in_axes=(None, 0, None))(
agg, batch_inputs, self.aggregation_options
)
batch_z = bias + batch_z
batch_z = jax.lax.cond(
is_output_node,
lambda: batch_z,
lambda: jax.vmap(act_func, in_axes=(None, 0, None))(
act, batch_z, self.activation_options
),
)
if self.update_hidden_node:
# calculate min, max
min_ = jnp.min(jnp.where(jnp.isnan(batch_z), jnp.inf, batch_z))
max_ = jnp.max(jnp.where(jnp.isnan(batch_z), -jnp.inf, batch_z))
batch_z = (batch_z - min_) / (max_ - min_) # transform to 01
batch_z = (
batch_z * (self.output_range[1] - self.output_range[0])
+ self.output_range[0]
)
# update mean and std to the attrs
attrs = attrs.at[3].set(min_)
attrs = attrs.at[4].set(max_)
return batch_z, attrs
def update_input_transform(self, state, attrs, batch_inputs):
"""
update the attrs for transformation in the input node.
default: do nothing
"""
bias, agg, act, min_, max_ = attrs
# calculate min, max
min_ = jnp.min(jnp.where(jnp.isnan(batch_inputs), jnp.inf, batch_inputs))
max_ = jnp.max(jnp.where(jnp.isnan(batch_inputs), -jnp.inf, batch_inputs))
batch_inputs = (batch_inputs - min_) / (max_ - min_) # transform to 01
batch_inputs = (
batch_inputs * (self.output_range[1] - self.output_range[0])
+ self.output_range[0]
)
# update mean and std to the attrs
attrs = attrs.at[3].set(min_)
attrs = attrs.at[4].set(max_)
return batch_inputs, attrs

View File

@@ -24,6 +24,7 @@ if __name__ == "__main__":
), ),
problem=GymNaxEnv( problem=GymNaxEnv(
env_name="CartPole-v1", env_name="CartPole-v1",
repeat_times=5
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=500, fitness_target=500,

View File

@@ -0,0 +1,46 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
from utils import Act, Agg
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=16,
num_outputs=4,
max_nodes=100,
max_conns=1000,
node_gene=DefaultNodeGene(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid, Act.relu, Act.tanh, Act.identity, Act.inv),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, Agg.mean, Agg.max, Agg.product),
),
mutation=DefaultMutation(
node_add=0.03,
conn_add=0.03,
)
),
pop_size=10000,
species_size=100,
survival_threshold=0.01,
),
),
problem=Jumanji_2048(
max_step=10000,
repeat_times=5
),
generation_limit=10000,
fitness_target=13000,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

File diff suppressed because it is too large Load Diff

View File

@@ -46,7 +46,7 @@ class FuncFit(BaseProblem):
def show(self, state, randkey, act_func, params, *args, **kwargs): def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, None, 0))( predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs, params state, params, self.inputs
) )
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
if self.return_data: if self.return_data:

View File

@@ -5,8 +5,8 @@ from .rl_jit import RLEnv
class BraxEnv(RLEnv): class BraxEnv(RLEnv):
def __init__(self, max_step=1000, record_episode=False, env_name: str = "ant", backend: str = "generalized"): def __init__(self, max_step=1000, repeat_times=1, record_episode=False, env_name: str = "ant", backend: str = "generalized"):
super().__init__(max_step, record_episode) super().__init__(max_step, repeat_times, record_episode)
self.env = envs.create(env_name=env_name, backend=backend) self.env = envs.create(env_name=env_name, backend=backend)
def env_step(self, randkey, env_state, action): def env_step(self, randkey, env_state, action):

View File

@@ -4,8 +4,8 @@ from .rl_jit import RLEnv
class GymNaxEnv(RLEnv): class GymNaxEnv(RLEnv):
def __init__(self, env_name, max_step=1000, record_episode=False): def __init__(self, env_name, max_step=1000, repeat_times=1, record_episode=False):
super().__init__(max_step, record_episode) super().__init__(max_step, repeat_times, record_episode)
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered" assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
self.env, self.env_params = gymnax.make(env_name) self.env, self.env_params = gymnax.make(env_name)

View File

@@ -0,0 +1,56 @@
import jax, jax.numpy as jnp
import jumanji
from utils import State
from ..rl_jit import RLEnv
class Jumanji_2048(RLEnv):
def __init__(
self, max_step=1000, repeat_times=1, record_episode=False, guarantee_invalid_action=True
):
super().__init__(max_step, repeat_times, record_episode)
self.guarantee_invalid_action = guarantee_invalid_action
self.env = jumanji.make("Game2048-v1")
def env_step(self, randkey, env_state, action):
action_mask = env_state["action_mask"]
if self.guarantee_invalid_action:
score_with_mask = jnp.where(action_mask, action, -jnp.inf)
action = jnp.argmax(score_with_mask)
else:
action = jnp.argmax(action)
done = ~action_mask[action]
env_state, timestep = self.env.step(env_state, action)
reward = timestep["reward"]
board, action_mask = timestep["observation"]
extras = timestep["extras"]
done = done | (jnp.sum(action_mask) == 0) # all actions of invalid
return board.reshape(-1), env_state, reward, done, extras
def env_reset(self, randkey):
env_state, timestep = self.env.reset(randkey)
step_type = timestep["step_type"]
reward = timestep["reward"]
discount = timestep["discount"]
observation = timestep["observation"]
extras = timestep["extras"]
board, action_mask = observation
return board.reshape(-1), env_state
@property
def input_shape(self):
return (16,)
@property
def output_shape(self):
return (4,)
def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")

View File

@@ -1,20 +1,47 @@
from functools import partial from functools import partial
from typing import Callable
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from utils import State
from .. import BaseProblem from .. import BaseProblem
class RLEnv(BaseProblem): class RLEnv(BaseProblem):
jitable = True jitable = True
def __init__(self, max_step=1000, record_episode=False): def __init__(self, max_step=1000, repeat_times=1, record_episode=False):
super().__init__() super().__init__()
self.max_step = max_step self.max_step = max_step
self.record_episode = record_episode self.record_episode = record_episode
self.repeat_times = repeat_times
def evaluate(self, state, randkey, act_func, params): def evaluate(self, state: State, randkey, act_func: Callable, params):
keys = jax.random.split(randkey, self.repeat_times)
if self.record_episode:
rewards, episodes = jax.vmap(
self.evaluate_once, in_axes=(None, 0, None, None)
)(state, keys, act_func, params)
episodes["obs"] = episodes["obs"].reshape(
self.max_step * self.repeat_times, *self.input_shape
)
episodes["action"] = episodes["action"].reshape(
self.max_step * self.repeat_times, *self.output_shape
)
episodes["reward"] = episodes["reward"].reshape(
self.max_step * self.repeat_times,
)
return rewards.mean(), episodes
else:
rewards = jax.vmap(self.evaluate_once, in_axes=(None, 0, None, None))(
state, keys, act_func, params
)
return rewards.mean()
def evaluate_once(self, state, randkey, act_func, params):
rng_reset, rng_episode = jax.random.split(randkey) rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset) init_obs, init_env_state = self.reset(rng_reset)