diff --git a/examples/brax/ant.py b/examples/brax/ant.py deleted file mode 100644 index 9b45e8c..0000000 --- a/examples/brax/ant.py +++ /dev/null @@ -1,39 +0,0 @@ -from pipeline import Pipeline -from algorithm.neat import * - -from problem.rl_env import BraxEnv -from tensorneat.common import Act - -if __name__ == "__main__": - pipeline = Pipeline( - algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=27, - num_outputs=8, - max_nodes=100, - max_conns=200, - node_gene=DefaultNodeGene( - activation_options=(Act.tanh,), - activation_default=Act.tanh, - ), - output_transform=Act.tanh, - ), - pop_size=1000, - species_size=10, - compatibility_threshold=3.5, - survival_threshold=0.01, - ), - ), - problem=BraxEnv( - env_name="ant", - ), - generation_limit=10000, - fitness_target=5000, - ) - - # initialize state - state = pipeline.setup() - # print(state) - # run until terminate - state, best = pipeline.auto_run(state) diff --git a/examples/brax/half_cheetah.py b/examples/brax/half_cheetah.py deleted file mode 100644 index 330ed97..0000000 --- a/examples/brax/half_cheetah.py +++ /dev/null @@ -1,48 +0,0 @@ -import jax - -from pipeline import Pipeline -from algorithm.neat import * - -from problem.rl_env import BraxEnv -from tensorneat.common import Act - - -def sample_policy(randkey, obs): - return jax.random.uniform(randkey, (6,), minval=-1, maxval=1) - - -if __name__ == "__main__": - pipeline = Pipeline( - algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=17, - num_outputs=6, - max_nodes=50, - max_conns=100, - node_gene=DefaultNodeGene( - activation_options=(Act.tanh,), - activation_default=Act.tanh, - ), - output_transform=Act.tanh, - ), - pop_size=1000, - species_size=10, - ), - ), - problem=BraxEnv( - env_name="halfcheetah", - max_step=1000, - obs_normalization=True, - sample_episodes=1000, - sample_policy=sample_policy, - ), - generation_limit=10000, - fitness_target=5000, - ) - - # initialize state - state = pipeline.setup() - # print(state) - # run until terminate - state, best = pipeline.auto_run(state) diff --git a/examples/brax/halfcheetah.py b/examples/brax/halfcheetah.py new file mode 100644 index 0000000..86ef415 --- /dev/null +++ b/examples/brax/halfcheetah.py @@ -0,0 +1,51 @@ +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, BiasNode + +from tensorneat.problem.rl import BraxEnv +from tensorneat.common import Act, Agg + +import jax + + +def random_sample_policy(randkey, obs): + return jax.random.uniform(randkey, (6,), minval=-1.0, maxval=1.0) + + +if __name__ == "__main__": + pipeline = Pipeline( + algorithm=NEAT( + pop_size=1000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=1.0, + genome=DefaultGenome( + max_nodes=100, + max_conns=200, + num_inputs=17, + num_outputs=6, + init_hidden_layers=(), + node_gene=BiasNode( + activation_options=Act.tanh, + aggregation_options=Agg.sum, + ), + output_transform=Act.standard_tanh, + ), + ), + problem=BraxEnv( + env_name="halfcheetah", + max_step=1000, + obs_normalization=True, + sample_episodes=1000, + sample_policy=random_sample_policy, + ), + seed=42, + generation_limit=100, + fitness_target=8000, + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) diff --git a/examples/brax/reacher.py b/examples/brax/reacher.py deleted file mode 100644 index bb331e5..0000000 --- a/examples/brax/reacher.py +++ /dev/null @@ -1,37 +0,0 @@ -from pipeline import Pipeline -from algorithm.neat import * - -from problem.rl_env import BraxEnv -from tensorneat.common import Act - -if __name__ == "__main__": - pipeline = Pipeline( - algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=11, - num_outputs=2, - max_nodes=50, - max_conns=100, - node_gene=DefaultNodeGene( - activation_options=(Act.tanh,), - activation_default=Act.tanh, - ), - output_transform=Act.tanh, - ), - pop_size=100, - species_size=10, - ), - ), - problem=BraxEnv( - env_name="reacher", - ), - generation_limit=10000, - fitness_target=5000, - ) - - # initialize state - state = pipeline.setup() - # print(state) - # run until terminate - state, best = pipeline.auto_run(state) diff --git a/examples/brax/show_test.py b/examples/brax/show_test.py deleted file mode 100644 index c50920e..0000000 --- a/examples/brax/show_test.py +++ /dev/null @@ -1,19 +0,0 @@ -import jax -from problem.rl_env import BraxEnv - - -def random_policy(randkey, forward_func, obs): - return jax.random.uniform(randkey, (6,), minval=-1, maxval=1) - - -if __name__ == "__main__": - problem = BraxEnv(env_name="walker2d", max_step=1000, action_policy=random_policy) - state = problem.setup() - randkey = jax.random.key(0) - problem.show( - state, - randkey, - act_func=lambda state, params, obs: obs, - params=None, - save_path="walker2d_random_policy", - ) diff --git a/examples/brax/walker2d.py b/examples/brax/walker2d.py index 3f03168..6c42ce0 100644 --- a/examples/brax/walker2d.py +++ b/examples/brax/walker2d.py @@ -9,7 +9,7 @@ import jax, jax.numpy as jnp def random_sample_policy(randkey, obs): - return jax.random.uniform(randkey, (6,)) + return jax.random.uniform(randkey, (6,), minval=-1.0, maxval=1.0) if __name__ == "__main__": diff --git a/examples/gymnax/arcbot.py b/examples/gymnax/arcbot.py index dc0bd4d..03819f8 100644 --- a/examples/gymnax/arcbot.py +++ b/examples/gymnax/arcbot.py @@ -1,36 +1,45 @@ import jax.numpy as jnp -from pipeline import Pipeline -from algorithm.neat import * +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, BiasNode + +from tensorneat.problem.rl import GymNaxEnv +from tensorneat.common import Act, Agg + -from problem.rl_env import GymNaxEnv if __name__ == "__main__": + # the network has 3 outputs, the max one will be the action + # as the action of acrobot is {0, 1, 2} + pipeline = Pipeline( algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=6, - num_outputs=3, - max_nodes=50, - max_conns=100, - output_transform=lambda out: jnp.argmax( - out - ), # the action of acrobot is {0, 1, 2} + pop_size=1000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=1.0, + genome=DefaultGenome( + num_inputs=6, + num_outputs=3, + init_hidden_layers=(), + node_gene=BiasNode( + activation_options=Act.tanh, + aggregation_options=Agg.sum, ), - pop_size=10000, - species_size=10, + output_transform=jnp.argmax, ), ), problem=GymNaxEnv( env_name="Acrobot-v1", ), - generation_limit=10000, - fitness_target=-62, + seed=42, + generation_limit=100, + fitness_target=-60, ) # initialize state state = pipeline.setup() - # print(state) + # run until terminate state, best = pipeline.auto_run(state) diff --git a/examples/gymnax/cartpole.py b/examples/gymnax/cartpole.py index a479a2a..f256ffc 100644 --- a/examples/gymnax/cartpole.py +++ b/examples/gymnax/cartpole.py @@ -1,41 +1,46 @@ import jax.numpy as jnp -from pipeline import Pipeline -from algorithm.neat import * +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, BiasNode -from problem.rl_env import GymNaxEnv +from tensorneat.problem.rl import GymNaxEnv +from tensorneat.common import Act, Agg -def action_policy(randkey, forward_func, obs): - return jnp.argmax(forward_func(obs)) - if __name__ == "__main__": + # the network has 2 outputs, the max one will be the action + # as the action of cartpole is {0, 1} + pipeline = Pipeline( algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=4, - num_outputs=2, - max_nodes=50, - max_conns=100, - # output_transform=lambda out: jnp.argmax( - # out - # ), # the action of cartpole is {0, 1} + pop_size=1000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=1.0, + genome=DefaultGenome( + num_inputs=4, + num_outputs=2, + init_hidden_layers=(), + node_gene=BiasNode( + activation_options=Act.tanh, + aggregation_options=Agg.sum, ), - pop_size=10000, - species_size=10, + output_transform=jnp.argmax, ), ), problem=GymNaxEnv( - env_name="CartPole-v1", repeat_times=5, action_policy=action_policy + env_name="CartPole-v1", + repeat_times=5, ), - generation_limit=10000, + seed=42, + generation_limit=100, fitness_target=500, ) # initialize state state = pipeline.setup() - # print(state) + # run until terminate state, best = pipeline.auto_run(state) diff --git a/examples/gymnax/cartpole_hyperneat.py b/examples/gymnax/cartpole_hyperneat.py index 30c92e2..786dcc1 100644 --- a/examples/gymnax/cartpole_hyperneat.py +++ b/examples/gymnax/cartpole_hyperneat.py @@ -1,70 +1,45 @@ -import jax +import jax.numpy as jnp -from pipeline import Pipeline -from algorithm.neat import * -from algorithm.hyperneat import * +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.algorithm.hyperneat import HyperNEAT, FullSubstrate +from tensorneat.genome import DefaultGenome from tensorneat.common import Act -from problem.rl_env import GymNaxEnv +from tensorneat.problem import GymNaxEnv if __name__ == "__main__": + + # the num of input_coors is 5 + # 4 is for cartpole inputs, 1 is for bias pipeline = Pipeline( algorithm=HyperNEAT( substrate=FullSubstrate( - input_coors=[ - (-1, -1), - (-0.5, -1), - (0, -1), - (0.5, -1), - (1, -1), - ], # 4(problem inputs) + 1(bias) - hidden_coors=[ - (-1, -0.5), - (0.333, -0.5), - (-0.333, -0.5), - (1, -0.5), - (-1, 0), - (0.333, 0), - (-0.333, 0), - (1, 0), - (-1, 0.5), - (0.333, 0.5), - (-0.333, 0.5), - (1, 0.5), - ], - output_coors=[ - (-1, 1), - (1, 1), # one output - ], + input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)), + hidden_coors=((-1, 0), (0, 0), (1, 0)), + output_coors=((-1, 1), (1, 1)), ), neat=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=4, # [*coor1, *coor2] - num_outputs=1, # the weight of connection between two coor1 and coor2 - max_nodes=50, - max_conns=100, - node_gene=DefaultNodeGene( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - output_transform=Act.tanh, # the activation function for output node in NEAT - ), - pop_size=10000, - species_size=10, - compatibility_threshold=3.5, - survival_threshold=0.03, + pop_size=10000, + species_size=20, + survival_threshold=0.01, + genome=DefaultGenome( + num_inputs=4, # size of query coors + num_outputs=1, + init_hidden_layers=(), + output_transform=Act.standard_tanh, ), ), - activation=Act.tanh, # the activation function for output node in HyperNEAT + activation=Act.tanh, activate_time=10, - output_transform=jax.numpy.argmax, # action of cartpole is in {0, 1} + output_transform=jnp.argmax, ), problem=GymNaxEnv( env_name="CartPole-v1", + repeat_times=5, ), generation_limit=300, - fitness_target=500, + fitness_target=-1e-6, ) # initialize state diff --git a/examples/gymnax/mountain_car.py b/examples/gymnax/mountain_car.py deleted file mode 100644 index d1ea062..0000000 --- a/examples/gymnax/mountain_car.py +++ /dev/null @@ -1,36 +0,0 @@ -import jax.numpy as jnp - -from pipeline import Pipeline -from algorithm.neat import * - -from problem.rl_env import GymNaxEnv - -if __name__ == "__main__": - pipeline = Pipeline( - algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=2, - num_outputs=3, - max_nodes=50, - max_conns=100, - output_transform=lambda out: jnp.argmax( - out - ), # the action of mountain car is {0, 1, 2} - ), - pop_size=10000, - species_size=10, - ), - ), - problem=GymNaxEnv( - env_name="MountainCar-v0", - ), - generation_limit=10000, - fitness_target=-86, - ) - - # initialize state - state = pipeline.setup() - # print(state) - # run until terminate - state, best = pipeline.auto_run(state) diff --git a/examples/gymnax/mountain_car_continuous.py b/examples/gymnax/mountain_car_continuous.py index 5420123..a6123b8 100644 --- a/examples/gymnax/mountain_car_continuous.py +++ b/examples/gymnax/mountain_car_continuous.py @@ -1,37 +1,43 @@ -from pipeline import Pipeline -from algorithm.neat import * +import jax.numpy as jnp + +from tensorneat.pipeline import Pipeline +from tensorneat.algorithm.neat import NEAT +from tensorneat.genome import DefaultGenome, BiasNode + +from tensorneat.problem.rl import GymNaxEnv +from tensorneat.common import Act, Agg + -from problem.rl_env import GymNaxEnv -from tensorneat.common import Act if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=2, - num_outputs=1, - max_nodes=50, - max_conns=100, - node_gene=DefaultNodeGene( - activation_options=(Act.tanh,), - activation_default=Act.tanh, - ), - output_transform=Act.tanh + pop_size=1000, + species_size=20, + survival_threshold=0.1, + compatibility_threshold=1.0, + genome=DefaultGenome( + num_inputs=2, + num_outputs=1, + init_hidden_layers=(), + node_gene=BiasNode( + activation_options=Act.tanh, + aggregation_options=Agg.sum, ), - pop_size=10000, - species_size=10, + output_transform=Act.standard_tanh, ), ), problem=GymNaxEnv( env_name="MountainCarContinuous-v0", + repeat_times=5, ), - generation_limit=10000, + seed=42, + generation_limit=100, fitness_target=99, ) # initialize state state = pipeline.setup() - # print(state) + # run until terminate state, best = pipeline.auto_run(state) diff --git a/examples/gymnax/pendulum.py b/examples/gymnax/pendulum.py deleted file mode 100644 index d370394..0000000 --- a/examples/gymnax/pendulum.py +++ /dev/null @@ -1,38 +0,0 @@ -from pipeline import Pipeline -from algorithm.neat import * - -from problem.rl_env import GymNaxEnv -from tensorneat.common import Act - -if __name__ == "__main__": - pipeline = Pipeline( - algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=3, - num_outputs=1, - max_nodes=50, - max_conns=100, - node_gene=DefaultNodeGene( - activation_options=(Act.tanh,), - activation_default=Act.tanh, - ), - output_transform=lambda out: Act.tanh(out) - * 2, # the action of pendulum is [-2, 2] - ), - pop_size=10000, - species_size=10, - ), - ), - problem=GymNaxEnv( - env_name="Pendulum-v1", - ), - generation_limit=10000, - fitness_target=-10, - ) - - # initialize state - state = pipeline.setup() - # print(state) - # run until terminate - state, best = pipeline.auto_run(state) diff --git a/examples/gymnax/reacher.py b/examples/gymnax/reacher.py deleted file mode 100644 index 489357c..0000000 --- a/examples/gymnax/reacher.py +++ /dev/null @@ -1,33 +0,0 @@ -import jax.numpy as jnp - -from pipeline import Pipeline -from algorithm.neat import * - -from problem.rl_env import GymNaxEnv - -if __name__ == "__main__": - pipeline = Pipeline( - algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=8, - num_outputs=2, - max_nodes=50, - max_conns=100, - ), - pop_size=10000, - species_size=10, - ), - ), - problem=GymNaxEnv( - env_name="Reacher-misc", - ), - generation_limit=10000, - fitness_target=90, - ) - - # initialize state - state = pipeline.setup() - # print(state) - # run until terminate - state, best = pipeline.auto_run(state) diff --git a/examples/jumanji/2048_random_policy.py b/examples/jumanji/2048_random_policy.py deleted file mode 100644 index 51927ec..0000000 --- a/examples/jumanji/2048_random_policy.py +++ /dev/null @@ -1,25 +0,0 @@ -import jax, jax.numpy as jnp -import jax.random -from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048 - -def random_policy(state, params, obs): - key = jax.random.key(obs.sum()) - actions = jax.random.normal(key, (4,)) - # actions = actions.at[2:].set(-9999) - # return jnp.array([4, 4, 0, 1]) - # return jnp.array([1, 2, 3, 4]) - # return actions - return actions - - -if __name__ == "__main__": - problem = Jumanji_2048( - max_step=10000, repeat_times=1000, guarantee_invalid_action=False - ) - state = problem.setup() - jit_evaluate = jax.jit( - lambda state, randkey: problem.evaluate(state, randkey, random_policy, None) - ) - randkey = jax.random.PRNGKey(0) - reward = jit_evaluate(state, randkey) - print(reward) diff --git a/examples/jumanji/2048_test.ipynb b/examples/jumanji/2048_test.ipynb deleted file mode 100644 index d6b22a3..0000000 --- a/examples/jumanji/2048_test.ipynb +++ /dev/null @@ -1,1874 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 6, - "id": "initial_id", - "metadata": { - "collapsed": true, - "ExecuteTime": { - "end_time": "2024-06-05T07:40:13.841629100Z", - "start_time": "2024-06-05T07:40:13.076164500Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "initializing\n", - "initializing finished\n" - ] - } - ], - "source": [ - "import jax.numpy as jnp\n", - "\n", - "from pipeline import Pipeline\n", - "from algorithm.neat import *\n", - "from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n", - "\n", - "from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048\n", - "from tensorneat.utils import Act, Agg\n", - "\n", - "pipeline = Pipeline(\n", - " algorithm=NEAT(\n", - " species=DefaultSpecies(\n", - " genome=DefaultGenome(\n", - " num_inputs=16,\n", - " num_outputs=4,\n", - " max_nodes=100,\n", - " max_conns=1000,\n", - " node_gene=NodeGeneWithoutResponse(\n", - " activation_default=Act.sigmoid,\n", - " activation_options=(\n", - " Act.sigmoid,\n", - " Act.relu,\n", - " Act.tanh,\n", - " Act.identity,\n", - " ),\n", - " aggregation_default=Agg.sum,\n", - " aggregation_options=(Agg.sum,),\n", - " activation_replace_rate=0.02,\n", - " aggregation_replace_rate=0.02,\n", - " bias_mutate_rate=0.03,\n", - " bias_init_std=0.5,\n", - " bias_mutate_power=0.2,\n", - " bias_replace_rate=0.01,\n", - " ),\n", - " conn_gene=DefaultConnGene(\n", - " weight_mutate_rate=0.015,\n", - " weight_replace_rate=0.003,\n", - " weight_mutate_power=0.5,\n", - " ),\n", - " mutation=DefaultMutation(\n", - " node_add=0.1, conn_add=0.2, conn_delete=0.2\n", - " ),\n", - " ),\n", - " pop_size=1000,\n", - " species_size=5,\n", - " survival_threshold=0.1,\n", - " max_stagnation=7,\n", - " genome_elitism=3,\n", - " compatibility_threshold=1.2,\n", - " ),\n", - " ),\n", - " problem=Jumanji_2048(max_step=10000, repeat_times=5),\n", - " generation_limit=100,\n", - " fitness_target=13000,\n", - " save_path=\"2048.pkl\",\n", - ")\n", - "state = pipeline.setup()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "outputs": [], - "source": [ - "import numpy as np\n", - "\n", - "data = np.load('2048.npz')\n", - "nodes, conns = data['nodes'], data['conns']" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-05T07:40:13.932015100Z", - "start_time": "2024-06-05T07:40:13.876631500Z" - } - }, - "id": "a0915ecf8179f347" - }, - { - "cell_type": "code", - "execution_count": 8, - "outputs": [], - "source": [ - "genome = pipeline.algorithm.species.genome\n", - "transformed = genome.transform(state, nodes, conns)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-05T07:40:14.585804800Z", - "start_time": "2024-06-05T07:40:14.568805Z" - } - }, - "id": "cd1fa65e8a9d6e13" - }, - { - "cell_type": "code", - "execution_count": 9, - "outputs": [], - "source": [ - "def policy(board):\n", - " action_scores = genome.forward(state, transformed, board)\n", - " return action_scores" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-05T07:40:15.124383600Z", - "start_time": "2024-06-05T07:40:15.118384200Z" - } - }, - "id": "61bc1895af304651" - }, - { - "cell_type": "code", - "execution_count": 14, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [0, 0, 1, 0],\n", - " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [1, 0, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [0, 0, 0, 1],\n", - " [1, 1, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [1, 1, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 1, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [2, 1, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(4, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [2, 2, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(4, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 2, 0, 0],\n", - " [0, 1, 0, 0],\n", - " [2, 2, 1, 1]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(4, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 1, 2],\n", - " [0, 0, 0, 1],\n", - " [0, 0, 3, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [0, 0, 0, 2],\n", - " [0, 0, 1, 1],\n", - " [0, 0, 3, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 0, 2],\n", - " [0, 1, 2, 1],\n", - " [0, 0, 3, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 0, 0, 2],\n", - " [0, 0, 2, 1],\n", - " [0, 1, 3, 2]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 0],\n", - " [2, 0, 0, 0],\n", - " [2, 1, 0, 0],\n", - " [1, 3, 2, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 0, 1],\n", - " [3, 2, 0, 0],\n", - " [1, 3, 2, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [1, 0, 0, 0],\n", - " [3, 2, 0, 0],\n", - " [1, 3, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 1, 0, 1],\n", - " [3, 2, 0, 0],\n", - " [1, 3, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [1, 1, 0, 0],\n", - " [3, 2, 0, 0],\n", - " [1, 3, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 1, 0, 0],\n", - " [3, 2, 1, 0],\n", - " [1, 3, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [1, 1, 0, 0],\n", - " [3, 2, 1, 1],\n", - " [1, 3, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 2, 0, 0],\n", - " [3, 2, 1, 1],\n", - " [1, 3, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 0, 1],\n", - " [3, 3, 1, 2],\n", - " [1, 3, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(24., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 1, 0],\n", - " [3, 0, 1, 1],\n", - " [1, 4, 2, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 1, 0],\n", - " [3, 0, 2, 1],\n", - " [1, 4, 2, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 0, 1],\n", - " [3, 0, 1, 1],\n", - " [1, 4, 3, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 2],\n", - " [1, 0, 0, 0],\n", - " [3, 0, 1, 2],\n", - " [1, 4, 3, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 1, 0],\n", - " [3, 0, 1, 3],\n", - " [1, 4, 3, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 0, 1],\n", - " [3, 0, 2, 0],\n", - " [1, 4, 3, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 0, 1],\n", - " [3, 0, 2, 1],\n", - " [1, 4, 3, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [1, 0, 0, 0],\n", - " [3, 0, 2, 2],\n", - " [1, 4, 3, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 1, 2],\n", - " [3, 0, 2, 2],\n", - " [1, 4, 3, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 1, 1, 0],\n", - " [3, 0, 2, 3],\n", - " [1, 4, 3, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [1, 0, 1, 0],\n", - " [3, 1, 2, 3],\n", - " [1, 4, 3, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 0, 2, 0],\n", - " [3, 1, 2, 3],\n", - " [1, 4, 3, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 1, 0, 1],\n", - " [3, 1, 3, 3],\n", - " [1, 4, 3, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 1, 1],\n", - " [3, 2, 0, 3],\n", - " [1, 4, 4, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [1, 0, 0, 1],\n", - " [3, 2, 1, 3],\n", - " [1, 4, 4, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [2, 0, 0, 1],\n", - " [3, 2, 1, 3],\n", - " [1, 4, 4, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [2, 1, 0, 2],\n", - " [3, 2, 1, 3],\n", - " [1, 4, 4, 4]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [0, 2, 1, 2],\n", - " [3, 2, 1, 3],\n", - " [0, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 1, 0, 2],\n", - " [1, 3, 2, 3],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 1, 2],\n", - " [1, 3, 2, 3],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [0, 0, 1, 2],\n", - " [2, 3, 2, 3],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 1, 1, 2],\n", - " [2, 3, 2, 3],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 1, 2, 2],\n", - " [2, 3, 2, 3],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 1, 0, 2],\n", - " [2, 3, 3, 3],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 0, 1],\n", - " [0, 0, 2, 2],\n", - " [0, 2, 3, 4],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 1],\n", - " [0, 0, 2, 2],\n", - " [0, 3, 3, 4],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 0, 2, 2],\n", - " [1, 3, 3, 4],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [0, 0, 2, 2],\n", - " [2, 3, 3, 4],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 1],\n", - " [0, 1, 2, 2],\n", - " [2, 3, 3, 4],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [1, 1, 2, 2],\n", - " [2, 3, 3, 4],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 1],\n", - " [1, 2, 2, 2],\n", - " [2, 3, 3, 4],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 1],\n", - " [2, 2, 2, 2],\n", - " [2, 3, 3, 4],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 1],\n", - " [1, 2, 2, 2],\n", - " [3, 3, 3, 4],\n", - " [3, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 1],\n", - " [0, 2, 2, 2],\n", - " [1, 3, 3, 4],\n", - " [4, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(28., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 0],\n", - " [3, 2, 0, 2],\n", - " [1, 4, 4, 0],\n", - " [4, 1, 4, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 0],\n", - " [3, 2, 1, 0],\n", - " [1, 4, 0, 2],\n", - " [4, 1, 5, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 1],\n", - " [3, 2, 0, 0],\n", - " [1, 4, 1, 2],\n", - " [4, 1, 5, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 0],\n", - " [3, 2, 1, 1],\n", - " [1, 4, 1, 2],\n", - " [4, 1, 5, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 1, 0],\n", - " [3, 2, 0, 1],\n", - " [1, 4, 2, 2],\n", - " [4, 1, 5, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 2, 0],\n", - " [3, 2, 1, 1],\n", - " [1, 4, 2, 2],\n", - " [4, 1, 5, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(76., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 1, 2],\n", - " [1, 3, 2, 2],\n", - " [0, 1, 4, 3],\n", - " [0, 4, 1, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 1, 0],\n", - " [0, 3, 2, 3],\n", - " [0, 1, 4, 3],\n", - " [1, 4, 1, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 2, 1, 0],\n", - " [0, 3, 2, 0],\n", - " [0, 1, 4, 4],\n", - " [2, 4, 1, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 1, 0],\n", - " [0, 3, 2, 2],\n", - " [0, 1, 4, 4],\n", - " [3, 4, 1, 6]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(40., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 2, 1],\n", - " [1, 0, 3, 3],\n", - " [0, 0, 1, 5],\n", - " [3, 4, 1, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 1, 2, 3],\n", - " [1, 0, 3, 5],\n", - " [3, 4, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 2, 2, 3],\n", - " [1, 1, 3, 5],\n", - " [3, 4, 2, 6]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 0, 3, 3],\n", - " [1, 2, 3, 5],\n", - " [3, 4, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 1],\n", - " [0, 0, 0, 3],\n", - " [1, 2, 4, 5],\n", - " [3, 4, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 2, 1],\n", - " [0, 0, 1, 3],\n", - " [1, 2, 4, 5],\n", - " [3, 4, 2, 6]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 0],\n", - " [1, 3, 0, 1],\n", - " [1, 2, 4, 5],\n", - " [3, 4, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 0],\n", - " [2, 3, 0, 1],\n", - " [2, 2, 4, 5],\n", - " [3, 4, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 0],\n", - " [1, 3, 0, 1],\n", - " [3, 2, 4, 5],\n", - " [3, 4, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [1, 3, 1, 1],\n", - " [1, 2, 4, 5],\n", - " [4, 4, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [1, 3, 1, 1],\n", - " [2, 2, 4, 5],\n", - " [4, 4, 2, 6]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(44., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 1, 3, 2],\n", - " [0, 3, 4, 5],\n", - " [0, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [0, 1, 3, 2],\n", - " [0, 3, 4, 5],\n", - " [1, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 2, 3, 2],\n", - " [0, 3, 4, 5],\n", - " [1, 5, 2, 6]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 1],\n", - " [0, 2, 3, 2],\n", - " [0, 3, 4, 5],\n", - " [2, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 2],\n", - " [2, 2, 3, 2],\n", - " [0, 3, 4, 5],\n", - " [2, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 2, 3, 3],\n", - " [1, 3, 4, 5],\n", - " [3, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 2, 4],\n", - " [1, 3, 4, 5],\n", - " [3, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 0, 2, 4],\n", - " [2, 3, 4, 5],\n", - " [3, 5, 2, 6]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 0],\n", - " [2, 4, 0, 0],\n", - " [2, 3, 4, 5],\n", - " [3, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [1, 4, 1, 0],\n", - " [3, 3, 4, 5],\n", - " [3, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [1, 4, 1, 0],\n", - " [1, 3, 4, 5],\n", - " [4, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [0, 4, 1, 0],\n", - " [2, 3, 4, 5],\n", - " [4, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 0],\n", - " [0, 4, 1, 1],\n", - " [2, 3, 4, 5],\n", - " [4, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 2, 0],\n", - " [1, 4, 1, 1],\n", - " [2, 3, 4, 5],\n", - " [4, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 0, 0],\n", - " [1, 4, 2, 1],\n", - " [2, 3, 4, 5],\n", - " [4, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 0, 1],\n", - " [2, 4, 2, 1],\n", - " [2, 3, 4, 5],\n", - " [4, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 0, 0],\n", - " [0, 4, 2, 2],\n", - " [3, 3, 4, 5],\n", - " [4, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 2, 0],\n", - " [1, 4, 2, 2],\n", - " [3, 3, 4, 5],\n", - " [4, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 1, 0],\n", - " [1, 4, 3, 2],\n", - " [3, 3, 4, 5],\n", - " [4, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 0],\n", - " [1, 4, 3, 2],\n", - " [4, 4, 5, 2],\n", - " [4, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(72., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [2, 1, 3, 0],\n", - " [1, 5, 5, 3],\n", - " [5, 5, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(64., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 1],\n", - " [2, 0, 3, 0],\n", - " [1, 1, 5, 3],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 1, 0],\n", - " [2, 0, 3, 1],\n", - " [1, 1, 5, 3],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 1],\n", - " [2, 2, 3, 1],\n", - " [1, 1, 5, 3],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 0],\n", - " [2, 2, 3, 2],\n", - " [1, 1, 5, 3],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 2],\n", - " [1, 3, 3, 2],\n", - " [0, 2, 5, 3],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [0, 3, 3, 3],\n", - " [1, 2, 5, 3],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 0],\n", - " [0, 3, 3, 0],\n", - " [1, 2, 5, 4],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 2],\n", - " [0, 0, 1, 4],\n", - " [1, 2, 5, 4],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [0, 0, 1, 2],\n", - " [1, 2, 5, 5],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [0, 1, 1, 2],\n", - " [1, 2, 5, 5],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [2, 1, 1, 2],\n", - " [2, 2, 5, 5],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 1, 1, 2],\n", - " [3, 2, 5, 5],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(68., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 1, 2, 2],\n", - " [1, 3, 2, 6],\n", - " [5, 6, 2, 6]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(136., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 1, 0, 0],\n", - " [1, 3, 2, 2],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 1, 0, 1],\n", - " [1, 3, 2, 2],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 1, 0, 2],\n", - " [1, 3, 2, 2],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 1, 0, 1],\n", - " [2, 3, 2, 3],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 0, 2],\n", - " [2, 3, 2, 3],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [0, 0, 1, 2],\n", - " [2, 3, 2, 3],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [1, 0, 1, 2],\n", - " [2, 3, 2, 3],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [1, 1, 1, 2],\n", - " [2, 3, 2, 3],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 2, 1, 2],\n", - " [2, 3, 2, 3],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 0],\n", - " [1, 2, 1, 2],\n", - " [2, 3, 2, 3],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 0],\n", - " [2, 2, 1, 2],\n", - " [2, 3, 2, 3],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [0, 2, 2, 2],\n", - " [3, 3, 2, 3],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 1],\n", - " [0, 2, 0, 2],\n", - " [3, 3, 3, 3],\n", - " [5, 6, 3, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 1],\n", - " [1, 2, 0, 2],\n", - " [3, 3, 0, 3],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [1, 2, 2, 2],\n", - " [3, 3, 1, 3],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(28., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 2],\n", - " [0, 1, 2, 3],\n", - " [0, 4, 1, 3],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 2, 0],\n", - " [0, 2, 2, 2],\n", - " [0, 4, 1, 4],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 2, 3, 2],\n", - " [1, 4, 1, 4],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [2, 3, 2, 0],\n", - " [1, 4, 1, 4],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [0, 2, 3, 2],\n", - " [1, 4, 1, 4],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 2],\n", - " [1, 2, 3, 2],\n", - " [1, 4, 1, 4],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 2, 3, 3],\n", - " [2, 4, 1, 4],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [0, 1, 2, 4],\n", - " [2, 4, 1, 4],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 1, 2, 1],\n", - " [2, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [0, 2, 2, 1],\n", - " [2, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 2, 3, 1],\n", - " [2, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 2, 3, 2],\n", - " [2, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 1, 0],\n", - " [2, 3, 2, 0],\n", - " [2, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 1, 0],\n", - " [1, 3, 2, 0],\n", - " [3, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 1, 0],\n", - " [2, 3, 2, 0],\n", - " [3, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 2],\n", - " [0, 2, 3, 2],\n", - " [3, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 2, 3, 3],\n", - " [3, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 0],\n", - " [1, 2, 4, 0],\n", - " [3, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 0],\n", - " [2, 2, 4, 0],\n", - " [3, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 0, 0, 0],\n", - " [3, 4, 0, 1],\n", - " [3, 4, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(48., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [2, 0, 1, 1],\n", - " [4, 5, 1, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [2, 0, 0, 1],\n", - " [4, 5, 2, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [2, 0, 1, 1],\n", - " [4, 5, 2, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 0, 2, 2],\n", - " [4, 5, 2, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 0, 1],\n", - " [1, 0, 0, 2],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 1],\n", - " [1, 2, 0, 2],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 0, 0, 1],\n", - " [2, 2, 0, 2],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 1],\n", - " [3, 2, 0, 2],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [3, 2, 1, 2],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 2],\n", - " [3, 2, 1, 2],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [3, 2, 2, 3],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 0],\n", - " [3, 3, 3, 0],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 0],\n", - " [3, 3, 1, 0],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 4, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 1, 0],\n", - " [3, 3, 0, 0],\n", - " [4, 5, 1, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 1, 0],\n", - " [3, 3, 0, 0],\n", - " [4, 5, 2, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 1],\n", - " [3, 3, 1, 0],\n", - " [4, 5, 2, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 1, 0],\n", - " [3, 3, 1, 1],\n", - " [4, 5, 2, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 1],\n", - " [3, 3, 2, 1],\n", - " [4, 5, 2, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 1, 0],\n", - " [3, 3, 0, 2],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 1],\n", - " [3, 3, 1, 2],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 2],\n", - " [1, 4, 1, 2],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [1, 4, 2, 3],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 0, 0, 1],\n", - " [1, 4, 2, 3],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 2, 0],\n", - " [1, 4, 2, 3],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 1, 0],\n", - " [1, 4, 3, 3],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 1],\n", - " [1, 4, 1, 3],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 2, 2],\n", - " [1, 4, 1, 3],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 0],\n", - " [1, 4, 1, 3],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 3, 1, 0],\n", - " [2, 4, 2, 3],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 3, 1],\n", - " [2, 4, 2, 3],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 1],\n", - " [2, 4, 2, 3],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 3, 2],\n", - " [2, 4, 2, 3],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 2, 1],\n", - " [2, 4, 2, 3],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, False, True, False], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 0, 1],\n", - " [3, 4, 3, 3],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 3, 1],\n", - " [0, 3, 4, 4],\n", - " [4, 5, 4, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 1],\n", - " [1, 3, 3, 4],\n", - " [4, 5, 5, 5],\n", - " [5, 6, 5, 7]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(64., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 2, 1],\n", - " [1, 3, 1, 4],\n", - " [4, 5, 3, 5],\n", - " [5, 6, 6, 7]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(128., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 1, 0],\n", - " [1, 3, 1, 4],\n", - " [4, 5, 3, 5],\n", - " [5, 7, 7, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 2, 0, 0],\n", - " [2, 3, 2, 4],\n", - " [4, 5, 3, 5],\n", - " [5, 7, 7, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 1, 0],\n", - " [3, 3, 2, 4],\n", - " [4, 5, 3, 5],\n", - " [5, 7, 7, 1]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(272., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 2, 0],\n", - " [4, 2, 4, 0],\n", - " [4, 5, 3, 5],\n", - " [5, 8, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 2, 0],\n", - " [2, 2, 4, 0],\n", - " [5, 5, 3, 1],\n", - " [5, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(64., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 2, 1],\n", - " [0, 2, 4, 0],\n", - " [2, 5, 3, 1],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 2, 1],\n", - " [0, 2, 4, 0],\n", - " [2, 5, 3, 2],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 2, 0],\n", - " [1, 2, 4, 1],\n", - " [2, 5, 3, 2],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 0, 1],\n", - " [1, 2, 4, 1],\n", - " [2, 5, 3, 2],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [2, 3, 4, 2],\n", - " [2, 5, 3, 2],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 1],\n", - " [0, 3, 4, 0],\n", - " [3, 5, 3, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 0],\n", - " [0, 3, 4, 1],\n", - " [3, 5, 3, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 0, 0, 0],\n", - " [3, 4, 1, 0],\n", - " [3, 5, 4, 1],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [2, 4, 1, 0],\n", - " [4, 5, 4, 1],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [2, 4, 1, 0],\n", - " [4, 5, 4, 2],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [2, 4, 1, 1],\n", - " [4, 5, 4, 2],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [0, 2, 4, 2],\n", - " [4, 5, 4, 2],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(40., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [1, 2, 0, 1],\n", - " [4, 5, 5, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(64., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 1, 2, 1],\n", - " [1, 4, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 2, 0],\n", - " [0, 1, 2, 2],\n", - " [1, 4, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 1, 3, 2],\n", - " [1, 4, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [0, 1, 3, 2],\n", - " [2, 4, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 1, 3, 2],\n", - " [2, 4, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [2, 3, 2, 1],\n", - " [2, 4, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 2, 0],\n", - " [1, 3, 2, 1],\n", - " [3, 4, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [1, 3, 3, 1],\n", - " [3, 4, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 1],\n", - " [1, 4, 1, 0],\n", - " [3, 4, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [2, 0, 1, 1],\n", - " [3, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [2, 0, 2, 1],\n", - " [3, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [2, 0, 2, 2],\n", - " [3, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 1, 0],\n", - " [3, 2, 0, 0],\n", - " [3, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 2],\n", - " [1, 2, 1, 0],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 2, 1, 2],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 1],\n", - " [1, 2, 1, 2],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [2, 2, 1, 2],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 0, 1, 0],\n", - " [3, 1, 2, 0],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 1],\n", - " [3, 1, 2, 0],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 0],\n", - " [3, 2, 2, 1],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 1],\n", - " [3, 3, 1, 0],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 1, 0],\n", - " [3, 3, 1, 1],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 1],\n", - " [3, 3, 2, 1],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 1],\n", - " [3, 3, 2, 2],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(28., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 2, 1, 0],\n", - " [4, 3, 0, 0],\n", - " [4, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 1, 0],\n", - " [2, 3, 1, 0],\n", - " [5, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 1, 0],\n", - " [2, 3, 2, 0],\n", - " [5, 5, 6, 3],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(64., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 0],\n", - " [2, 3, 2, 0],\n", - " [6, 6, 3, 1],\n", - " [6, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(136., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [0, 3, 2, 1],\n", - " [3, 6, 3, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 0],\n", - " [0, 3, 2, 0],\n", - " [3, 6, 3, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 0, 0, 1],\n", - " [3, 2, 0, 0],\n", - " [3, 6, 3, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [2, 2, 0, 1],\n", - " [4, 6, 3, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [2, 2, 0, 2],\n", - " [4, 6, 3, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [2, 2, 1, 0],\n", - " [4, 6, 3, 3],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(24., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 0, 3, 1],\n", - " [0, 4, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 1, 3, 2],\n", - " [1, 4, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [1, 3, 2, 0],\n", - " [1, 4, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [1, 3, 2, 0],\n", - " [2, 4, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [2, 3, 2, 1],\n", - " [2, 4, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 3, 2, 1],\n", - " [3, 4, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [0, 3, 2, 2],\n", - " [3, 4, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 0, 0],\n", - " [1, 3, 2, 2],\n", - " [3, 4, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 2],\n", - " [0, 1, 3, 3],\n", - " [3, 4, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 0, 0],\n", - " [1, 4, 1, 0],\n", - " [3, 4, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [2, 2, 1, 0],\n", - " [3, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 1],\n", - " [0, 0, 3, 1],\n", - " [3, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [1, 0, 3, 2],\n", - " [3, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 0, 0],\n", - " [1, 3, 2, 0],\n", - " [3, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 0, 1],\n", - " [2, 3, 2, 0],\n", - " [3, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 0, 0],\n", - " [2, 3, 2, 1],\n", - " [3, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 1, 2],\n", - " [2, 3, 2, 1],\n", - " [3, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 2, 0, 1],\n", - " [2, 3, 2, 1],\n", - " [3, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 0, 1],\n", - " [3, 3, 2, 2],\n", - " [3, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 1, 1],\n", - " [0, 3, 2, 2],\n", - " [4, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 2, 0, 0],\n", - " [3, 3, 1, 0],\n", - " [4, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(24., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 3],\n", - " [0, 0, 4, 1],\n", - " [4, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 0, 0],\n", - " [4, 1, 1, 0],\n", - " [4, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 3, 1, 0],\n", - " [1, 1, 1, 0],\n", - " [5, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 3, 0, 1],\n", - " [1, 1, 2, 0],\n", - " [5, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 3, 0, 1],\n", - " [1, 1, 2, 1],\n", - " [5, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 0, 0],\n", - " [1, 1, 2, 2],\n", - " [5, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 3, 0, 1],\n", - " [2, 1, 2, 2],\n", - " [5, 5, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(72., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 3, 1],\n", - " [0, 2, 1, 3],\n", - " [0, 6, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 3, 1],\n", - " [2, 2, 1, 3],\n", - " [1, 6, 6, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(136., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 1, 0, 1],\n", - " [3, 1, 3, 0],\n", - " [1, 7, 4, 0],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [4, 2, 3, 0],\n", - " [1, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 1],\n", - " [0, 4, 2, 3],\n", - " [1, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 1],\n", - " [0, 4, 2, 3],\n", - " [2, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [1, 4, 2, 3],\n", - " [2, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 2],\n", - " [1, 4, 2, 3],\n", - " [2, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 0, 2],\n", - " [1, 4, 2, 3],\n", - " [2, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 0, 2],\n", - " [2, 4, 2, 3],\n", - " [2, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 1, 2],\n", - " [1, 4, 2, 3],\n", - " [3, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 2, 1],\n", - " [1, 4, 2, 3],\n", - " [3, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 1, 1],\n", - " [1, 4, 3, 3],\n", - " [3, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 2, 1, 1],\n", - " [1, 4, 4, 0],\n", - " [3, 7, 4, 1],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 2, 2, 0],\n", - " [1, 4, 1, 0],\n", - " [3, 7, 5, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 1, 0],\n", - " [1, 4, 1, 0],\n", - " [3, 7, 5, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 0, 0],\n", - " [1, 4, 2, 1],\n", - " [3, 7, 5, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 3, 2],\n", - " [1, 4, 2, 1],\n", - " [3, 7, 5, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 3, 2],\n", - " [2, 4, 2, 1],\n", - " [3, 7, 5, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 2, 1],\n", - " [2, 4, 2, 1],\n", - " [3, 7, 5, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 0],\n", - " [2, 4, 3, 2],\n", - " [3, 7, 5, 2],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 1],\n", - " [2, 4, 3, 0],\n", - " [3, 7, 5, 3],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 1],\n", - " [2, 4, 3, 1],\n", - " [3, 7, 5, 3],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 1],\n", - " [2, 4, 3, 2],\n", - " [3, 7, 5, 3],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 3, 2],\n", - " [2, 4, 3, 2],\n", - " [3, 7, 5, 3],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(24., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 1],\n", - " [2, 4, 4, 3],\n", - " [3, 7, 5, 3],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 1],\n", - " [2, 4, 4, 1],\n", - " [3, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 1],\n", - " [2, 4, 4, 2],\n", - " [3, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 0, 1],\n", - " [2, 5, 2, 0],\n", - " [3, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [3, 5, 2, 1],\n", - " [3, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 0],\n", - " [0, 5, 2, 2],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 0, 0, 0],\n", - " [5, 3, 0, 1],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 2],\n", - " [1, 5, 3, 1],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 0, 1, 0],\n", - " [1, 5, 3, 1],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 1, 0],\n", - " [1, 5, 3, 1],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 2, 1, 0],\n", - " [1, 5, 3, 1],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 1, 0, 1],\n", - " [1, 5, 3, 1],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 1, 1, 0],\n", - " [1, 5, 3, 2],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 1, 0],\n", - " [1, 5, 3, 2],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, True, False, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 2, 1],\n", - " [1, 5, 3, 2],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([ True, False, True, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "TimeStep(step_type=Array(2, dtype=int8), reward=Array(4., dtype=float32), discount=Array(0., dtype=float32), observation=Observation(board=Array([[1, 3, 2, 1],\n", - " [2, 5, 3, 2],\n", - " [4, 7, 5, 4],\n", - " [7, 8, 1, 5]], dtype=int32), action_mask=Array([False, False, False, False], dtype=bool)), extras={'highest_tile': Array(256, dtype=int32)})\n", - "3716.0\n" - ] - } - ], - "source": [ - "import jax, jumanji\n", - "\n", - "env = jumanji.make(\"Game2048-v1\")\n", - "key = jax.random.PRNGKey(0)\n", - "jit_reset = jax.jit(env.reset)\n", - "jit_step = jax.jit(env.step)\n", - "state, timestep = jax.jit(env.reset)(key)\n", - "jit_policy = jax.jit(policy)\n", - "total_reward = 0\n", - "while True:\n", - " board, action_mask = timestep[\"observation\"]\n", - " action = jit_policy(timestep[\"observation\"][0].reshape(-1))\n", - " score_with_mask = jnp.where(action_mask, action, -jnp.inf)\n", - " action = jnp.argmax(score_with_mask)\n", - " state, timestep = jit_step(state, action)\n", - " done = jnp.all(~timestep[\"observation\"][1])\n", - " print(timestep)\n", - " total_reward += timestep[\"reward\"]\n", - " if done:\n", - " break\n", - "print(total_reward)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-05T07:41:33.703431900Z", - "start_time": "2024-06-05T07:41:26.102578200Z" - } - }, - "id": "f166e09c5be1a8fb" - }, - { - "cell_type": "code", - "execution_count": 17, - "outputs": [], - "source": [ - "import jax.random\n", - "from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048\n", - "\n", - "\n", - "def random_policy(state, params, obs):\n", - " key = jax.random.key(obs.sum())\n", - " actions = jax.random.normal(key, (4,))\n", - " return actions\n", - "\n", - "problem = Jumanji_2048(max_step=10000, repeat_times=10, guarantee_invalid_action=True)\n", - "state = problem.setup()\n", - "jit_evaluate = jax.jit(lambda state, randkey: problem.evaluate(state, randkey, random_policy, None))" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-05T08:06:59.491563700Z", - "start_time": "2024-06-05T08:06:59.465404900Z" - } - }, - "id": "187326d08ac1eeb4" - }, - { - "cell_type": "code", - "execution_count": 24, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1193.2001\n" - ] - } - ], - "source": [ - "\n", - "reward = jit_evaluate(state, randkey)\n", - "print(reward)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-05T08:07:21.630420300Z", - "start_time": "2024-06-05T08:07:21.107419400Z" - } - }, - "id": "4b3506db87568d81" - }, - { - "cell_type": "code", - "execution_count": 34, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [0, 1, 0, 1]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 1],\n", - " [1, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [1, 1, 1, 1]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 1, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [2, 2, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(4, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [3, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [3, 1, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [1, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [3, 2, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [2, 0, 1, 0],\n", - " [3, 2, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [2, 1, 0, 0],\n", - " [3, 2, 0, 1]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 0, 0],\n", - " [0, 0, 0, 0],\n", - " [2, 1, 0, 0],\n", - " [3, 2, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 1, 0],\n", - " [2, 2, 0, 0],\n", - " [3, 0, 0, 1],\n", - " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 1, 1],\n", - " [2, 2, 0, 0],\n", - " [3, 0, 0, 0],\n", - " [0, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [1, 1, 0, 0],\n", - " [2, 2, 0, 0],\n", - " [3, 1, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 1, 1],\n", - " [2, 2, 0, 0],\n", - " [3, 1, 0, 2],\n", - " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 0, 0, 0],\n", - " [2, 3, 1, 1],\n", - " [3, 1, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 0, 0, 0],\n", - " [2, 3, 0, 1],\n", - " [3, 1, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 0, 2, 1],\n", - " [0, 2, 3, 1],\n", - " [0, 3, 1, 3]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 2, 1],\n", - " [0, 2, 3, 2],\n", - " [1, 3, 1, 3]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 2, 1],\n", - " [0, 3, 3, 2],\n", - " [1, 0, 1, 3],\n", - " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 2, 1],\n", - " [1, 2, 3, 2],\n", - " [2, 3, 1, 3]], dtype=int32), action_mask=Array([ True, False, False, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [2, 1, 1, 0],\n", - " [1, 2, 3, 2],\n", - " [2, 3, 1, 3]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 0, 2, 2],\n", - " [1, 2, 3, 2],\n", - " [2, 3, 1, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 2, 1],\n", - " [2, 3, 3, 3],\n", - " [1, 0, 1, 3],\n", - " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 2, 1],\n", - " [2, 3, 3, 4],\n", - " [1, 0, 1, 0],\n", - " [0, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(28., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 3, 1],\n", - " [0, 2, 4, 4],\n", - " [0, 0, 0, 2],\n", - " [1, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [1, 0, 0, 4],\n", - " [0, 1, 3, 2],\n", - " [1, 2, 4, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 1, 0],\n", - " [1, 4, 0, 0],\n", - " [1, 3, 2, 0],\n", - " [1, 2, 4, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 2],\n", - " [0, 0, 1, 4],\n", - " [0, 1, 3, 2],\n", - " [1, 2, 4, 1]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 2, 2],\n", - " [0, 2, 3, 4],\n", - " [0, 0, 4, 2],\n", - " [1, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 2],\n", - " [1, 0, 2, 4],\n", - " [0, 1, 3, 2],\n", - " [2, 2, 4, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 2, 2],\n", - " [2, 2, 3, 4],\n", - " [0, 1, 4, 2],\n", - " [0, 0, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 2, 3],\n", - " [0, 3, 3, 4],\n", - " [1, 1, 4, 2],\n", - " [0, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 2, 3],\n", - " [0, 1, 3, 4],\n", - " [0, 0, 4, 2],\n", - " [1, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 2, 3],\n", - " [1, 1, 3, 4],\n", - " [0, 0, 4, 2],\n", - " [0, 0, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 2, 3],\n", - " [0, 1, 3, 4],\n", - " [0, 1, 4, 3],\n", - " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 2, 3],\n", - " [0, 2, 3, 4],\n", - " [0, 0, 4, 3],\n", - " [0, 0, 1, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 2, 3],\n", - " [0, 2, 3, 4],\n", - " [2, 0, 4, 3],\n", - " [0, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 3, 2, 3],\n", - " [0, 2, 3, 4],\n", - " [0, 0, 4, 3],\n", - " [0, 1, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 3, 2, 3],\n", - " [0, 2, 3, 4],\n", - " [0, 1, 4, 3],\n", - " [1, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 3, 1],\n", - " [2, 3, 4, 0],\n", - " [1, 4, 3, 0],\n", - " [2, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, False], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 3, 1],\n", - " [0, 2, 3, 4],\n", - " [0, 1, 4, 3],\n", - " [0, 0, 1, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 3, 1],\n", - " [2, 3, 4, 0],\n", - " [1, 4, 3, 0],\n", - " [1, 2, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 3, 1],\n", - " [2, 3, 4, 0],\n", - " [2, 4, 3, 1],\n", - " [0, 2, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 3, 1],\n", - " [2, 3, 4, 0],\n", - " [2, 4, 3, 1],\n", - " [2, 1, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 3, 1],\n", - " [2, 3, 4, 0],\n", - " [2, 4, 3, 1],\n", - " [2, 2, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 3, 1],\n", - " [0, 2, 3, 4],\n", - " [2, 4, 3, 1],\n", - " [0, 1, 2, 3]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 3, 1],\n", - " [2, 3, 4, 0],\n", - " [2, 4, 3, 1],\n", - " [1, 2, 3, 1]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 3, 1],\n", - " [1, 2, 3, 4],\n", - " [2, 4, 3, 1],\n", - " [1, 2, 3, 1]], dtype=int32), action_mask=Array([ True, False, True, False], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(44., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 0, 0, 1],\n", - " [1, 3, 0, 1],\n", - " [2, 4, 4, 4],\n", - " [1, 2, 4, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 1, 0, 0],\n", - " [1, 3, 1, 0],\n", - " [2, 5, 4, 1],\n", - " [1, 2, 4, 2]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 1, 1, 1],\n", - " [1, 3, 5, 2],\n", - " [2, 5, 0, 1],\n", - " [1, 2, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 4, 1, 2],\n", - " [1, 3, 5, 2],\n", - " [1, 2, 5, 1],\n", - " [0, 0, 1, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 1, 2, 0],\n", - " [1, 3, 5, 2],\n", - " [1, 2, 5, 1],\n", - " [1, 2, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(80., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 1, 2, 2],\n", - " [2, 3, 6, 2],\n", - " [1, 3, 0, 0],\n", - " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 4, 1, 3],\n", - " [2, 3, 6, 2],\n", - " [0, 0, 1, 3],\n", - " [0, 0, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 3],\n", - " [2, 3, 6, 2],\n", - " [0, 0, 1, 3],\n", - " [0, 0, 0, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 3],\n", - " [0, 3, 6, 2],\n", - " [0, 0, 1, 3],\n", - " [1, 0, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 3],\n", - " [1, 3, 6, 2],\n", - " [0, 0, 1, 3],\n", - " [0, 0, 1, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 3],\n", - " [1, 3, 6, 2],\n", - " [1, 3, 1, 0],\n", - " [1, 2, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 3],\n", - " [2, 4, 6, 2],\n", - " [1, 2, 1, 0],\n", - " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 5, 1, 3],\n", - " [2, 2, 6, 2],\n", - " [2, 0, 1, 0],\n", - " [0, 0, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 1, 1, 0],\n", - " [3, 5, 6, 3],\n", - " [3, 2, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [1, 1, 1, 0],\n", - " [0, 5, 6, 3],\n", - " [4, 2, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 0],\n", - " [0, 0, 1, 2],\n", - " [1, 5, 6, 3],\n", - " [0, 4, 2, 3]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", - " [1, 2, 0, 0],\n", - " [1, 5, 6, 3],\n", - " [4, 2, 3, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 1],\n", - " [0, 2, 1, 0],\n", - " [2, 5, 6, 0],\n", - " [4, 2, 3, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 2, 1, 1],\n", - " [4, 5, 6, 3],\n", - " [0, 2, 3, 0],\n", - " [0, 0, 1, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 0, 0],\n", - " [4, 5, 6, 3],\n", - " [2, 3, 1, 0],\n", - " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 6, 3],\n", - " [4, 5, 1, 0],\n", - " [2, 3, 1, 0],\n", - " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 6, 3],\n", - " [4, 5, 2, 0],\n", - " [2, 3, 0, 0],\n", - " [1, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 6, 3],\n", - " [4, 5, 2, 0],\n", - " [2, 3, 1, 0],\n", - " [2, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 6, 3],\n", - " [1, 4, 5, 2],\n", - " [0, 2, 3, 1],\n", - " [0, 0, 0, 2]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 6, 3],\n", - " [1, 4, 5, 2],\n", - " [2, 3, 1, 1],\n", - " [2, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 6, 3],\n", - " [1, 4, 5, 2],\n", - " [1, 2, 3, 2],\n", - " [0, 0, 0, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 6, 3],\n", - " [1, 4, 5, 2],\n", - " [1, 2, 3, 2],\n", - " [2, 0, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 6, 0],\n", - " [3, 2, 5, 1],\n", - " [2, 4, 3, 3],\n", - " [2, 2, 1, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(24., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 6, 0],\n", - " [0, 2, 5, 1],\n", - " [3, 4, 3, 1],\n", - " [3, 2, 1, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 6],\n", - " [1, 2, 5, 1],\n", - " [3, 4, 3, 1],\n", - " [3, 2, 1, 4]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[6, 0, 0, 1],\n", - " [1, 2, 5, 1],\n", - " [3, 4, 3, 1],\n", - " [3, 2, 1, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 0, 0],\n", - " [6, 2, 5, 1],\n", - " [1, 4, 3, 2],\n", - " [4, 2, 1, 4]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 0, 2, 0],\n", - " [6, 2, 5, 1],\n", - " [1, 4, 3, 2],\n", - " [4, 2, 1, 4]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 2, 1],\n", - " [6, 4, 5, 2],\n", - " [1, 2, 3, 4],\n", - " [4, 1, 1, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 3, 1],\n", - " [6, 4, 5, 2],\n", - " [1, 2, 3, 4],\n", - " [0, 0, 4, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 3, 1],\n", - " [6, 4, 5, 2],\n", - " [1, 2, 3, 4],\n", - " [0, 1, 4, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 1],\n", - " [6, 4, 5, 2],\n", - " [1, 2, 3, 4],\n", - " [1, 4, 2, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 0],\n", - " [2, 4, 5, 1],\n", - " [6, 2, 3, 2],\n", - " [2, 4, 2, 4]], dtype=int32), action_mask=Array([ True, True, False, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 3, 1],\n", - " [2, 4, 5, 1],\n", - " [6, 2, 3, 2],\n", - " [2, 4, 2, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 3, 1],\n", - " [2, 4, 5, 1],\n", - " [6, 2, 3, 2],\n", - " [2, 4, 2, 4]], dtype=int32), action_mask=Array([ True, False, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 2, 3, 1],\n", - " [2, 4, 5, 2],\n", - " [6, 2, 3, 2],\n", - " [2, 4, 2, 4]], dtype=int32), action_mask=Array([ True, False, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "TimeStep(step_type=Array(2, dtype=int8), reward=Array(8., dtype=float32), discount=Array(0., dtype=float32), observation=Observation(board=Array([[1, 2, 3, 1],\n", - " [2, 4, 5, 3],\n", - " [6, 2, 3, 4],\n", - " [2, 4, 2, 1]], dtype=int32), action_mask=Array([False, False, False, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", - "636.0\n" - ] - } - ], - "source": [ - "randkey = jax.random.PRNGKey(14)\n", - "jit_policy = jax.jit(random_policy)\n", - "total_reward = 0\n", - "state, timestep = jax.jit(env.reset)(randkey )\n", - "while True:\n", - " board, action_mask = timestep[\"observation\"]\n", - " action = jit_policy(None, None, timestep[\"observation\"][0].reshape(-1))\n", - " score_with_mask = jnp.where(action_mask, action, -jnp.inf)\n", - " action = jnp.argmax(score_with_mask)\n", - " state, timestep = jit_step(state, action)\n", - " done = jnp.all(~timestep[\"observation\"][1])\n", - " print(timestep)\n", - " total_reward += timestep[\"reward\"]\n", - " if done:\n", - " break\n", - "print(total_reward)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-06-05T08:09:58.242414600Z", - "start_time": "2024-06-05T08:09:56.452642800Z" - } - }, - "id": "8bb888fb742b6b06" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - }, - "id": "3d1b5c8c646d4f07" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/jumanji/train_2048.py b/examples/jumanji/train_2048.py deleted file mode 100644 index e2efeb4..0000000 --- a/examples/jumanji/train_2048.py +++ /dev/null @@ -1,120 +0,0 @@ -import jax, jax.numpy as jnp - -from pipeline import Pipeline -from algorithm.neat import * -from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse -from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048 -from tensorneat.common import Act, Agg - - -def rot_li(li): - return li[1:] + [li[0]] - - -def rot_boards(board): - def rot(a, _): - a = jnp.rot90(a) - return a, a # carry, y - - # carry, np.stack(ys) - _, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32)) - return boards - - -direction = ["up", "right", "down", "left"] -lr_flip_direction = ["up", "left", "down", "right"] - -directions = [] -lr_flip_directions = [] -for _ in range(4): - direction = rot_li(direction) - lr_flip_direction = rot_li(lr_flip_direction) - directions.append(direction.copy()) - lr_flip_directions.append(lr_flip_direction.copy()) - -full_directions = directions + lr_flip_directions - - -def action_policy(forward_func, obs): - board = obs.reshape(4, 4) - lr_flip_board = jnp.fliplr(board) - - boards = rot_boards(board) - lr_flip_boards = rot_boards(lr_flip_board) - # stack - full_boards = jnp.concatenate([boards, lr_flip_boards], axis=0) - scores = jax.vmap(forward_func)(full_boards.reshape(8, -1)) - total_score = {"up": 0, "right": 0, "down": 0, "left": 0} - for i in range(8): - dire = full_directions[i] - for j in range(4): - total_score[dire[j]] += scores[i, j] - - return jnp.array( - [ - total_score["up"], - total_score["right"], - total_score["down"], - total_score["left"], - ] - ) - - -if __name__ == "__main__": - pipeline = Pipeline( - algorithm=NEAT( - species=DefaultSpecies( - genome=DefaultGenome( - num_inputs=16, - num_outputs=4, - max_nodes=100, - max_conns=1000, - node_gene=NodeGeneWithoutResponse( - activation_default=Act.sigmoid, - activation_options=( - Act.sigmoid, - Act.relu, - Act.tanh, - Act.identity, - ), - aggregation_default=Agg.sum, - aggregation_options=(Agg.sum, ), - activation_replace_rate=0.02, - aggregation_replace_rate=0.02, - bias_mutate_rate=0.03, - bias_init_std=0.5, - bias_mutate_power=0.02, - bias_replace_rate=0.01, - ), - conn_gene=DefaultConnGene( - weight_mutate_rate=0.015, - weight_replace_rate=0.03, - weight_mutate_power=0.05, - ), - mutation=DefaultMutation(node_add=0.001, conn_add=0.002), - ), - pop_size=1000, - species_size=5, - survival_threshold=0.01, - max_stagnation=7, - genome_elitism=3, - compatibility_threshold=1.2, - ), - ), - problem=Jumanji_2048( - max_step=1000, - repeat_times=50, - # guarantee_invalid_action=True, - guarantee_invalid_action=False, - action_policy=action_policy, - ), - generation_limit=10000, - fitness_target=13000, - save_path="2048.npz", - ) - - # initialize state - state = pipeline.setup() - # print(state) - # run until terminate - state, best = pipeline.auto_run(state)