odify genome for the official release
This commit is contained in:
25
examples/jumanji/2048_random_policy.py
Normal file
25
examples/jumanji/2048_random_policy.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import jax, jax.numpy as jnp
|
||||
import jax.random
|
||||
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
|
||||
|
||||
def random_policy(state, params, obs):
|
||||
key = jax.random.key(obs.sum())
|
||||
actions = jax.random.normal(key, (4,))
|
||||
# actions = actions.at[2:].set(-9999)
|
||||
# return jnp.array([4, 4, 0, 1])
|
||||
# return jnp.array([1, 2, 3, 4])
|
||||
# return actions
|
||||
return actions
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
problem = Jumanji_2048(
|
||||
max_step=10000, repeat_times=1000, guarantee_invalid_action=False
|
||||
)
|
||||
state = problem.setup()
|
||||
jit_evaluate = jax.jit(
|
||||
lambda state, randkey: problem.evaluate(state, randkey, random_policy, None)
|
||||
)
|
||||
randkey = jax.random.PRNGKey(0)
|
||||
reward = jit_evaluate(state, randkey)
|
||||
print(reward)
|
||||
Reference in New Issue
Block a user