Files
tensorneat-mend/tensorneat/examples/jumanji/2048_random_policy.py
wls2002 b3e442c688 add sympy support; which can transfer your network into sympy expression;
add visualize in genome;
add related tests.
2024-06-12 21:36:35 +08:00

26 lines
763 B
Python

import jax, jax.numpy as jnp
import jax.random
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
def random_policy(state, params, obs):
    """Return four pseudo-random action scores derived from the observation.

    The PRNG key is seeded with ``obs.sum()``, so the output is a
    deterministic function of that sum: two observations with equal sums
    receive identical scores.

    Args:
        state: Not read by this policy; kept so the signature matches the
            callable handed to ``problem.evaluate``.
        params: Not read by this policy.
        obs: Array-like observation exposing ``.sum()``; only its sum is used.

    Returns:
        A ``(4,)`` array of standard-normal samples.
    """
    seed = obs.sum()
    # jax.random.key only accepts integer seeds. The dtype is known at trace
    # time, so this Python-level branch is safe under jax.jit. Integer
    # observations take the original path unchanged.
    if not jnp.issubdtype(seed.dtype, jnp.integer):
        seed = seed.astype(jnp.int32)
    key = jax.random.key(seed)
    return jax.random.normal(key, (4,))
def main():
    """Evaluate ``random_policy`` on the Jumanji 2048 problem and print the result.

    Builds the problem, jit-compiles a wrapper around ``problem.evaluate``
    that plugs in ``random_policy`` with ``None`` as its params, runs it once
    with a fixed PRNG key, and prints whatever ``evaluate`` returns.
    """
    env = Jumanji_2048(
        max_step=10000,
        repeat_times=1000,
        guarantee_invalid_action=False,
    )
    initial_state = env.setup()

    def evaluate_random(state, randkey):
        # The random policy has no parameters, hence the trailing None.
        return env.evaluate(state, randkey, random_policy, None)

    evaluate_jit = jax.jit(evaluate_random)
    seed_key = jax.random.PRNGKey(0)
    print(evaluate_jit(initial_state, seed_key))


if __name__ == "__main__":
    main()