add sympy support; which can transfer your network into sympy expression;
add visualize in genome; add related tests.
This commit is contained in:
@@ -2,19 +2,19 @@ import jax, jax.numpy as jnp
|
||||
import jax.random
|
||||
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
|
||||
|
||||
|
||||
def random_policy(state, params, obs):
|
||||
# key = jax.random.key(obs.sum())
|
||||
# actions = jax.random.normal(key, (4,))
|
||||
key = jax.random.key(obs.sum())
|
||||
actions = jax.random.normal(key, (4,))
|
||||
# actions = actions.at[2:].set(-9999)
|
||||
return jnp.array([4, 4, 0, 1])
|
||||
# return jnp.array([4, 4, 0, 1])
|
||||
# return jnp.array([1, 2, 3, 4])
|
||||
# return actions
|
||||
return actions
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
problem = Jumanji_2048(
|
||||
max_step=10000, repeat_times=1000, guarantee_invalid_action=True
|
||||
max_step=10000, repeat_times=1000, guarantee_invalid_action=False
|
||||
)
|
||||
state = problem.setup()
|
||||
jit_evaluate = jax.jit(
|
||||
|
||||
Reference in New Issue
Block a user