from typing import Tuple

import jax, jax.numpy as jnp

from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
from . import BaseNodeGene


class MinMaxNode(BaseNodeGene):
    """
    Node gene whose output is min-max normalized after activation.

    Forward computation:
        z = act(agg(inputs) + bias)
        z = (z - min) / (max - min)               # rescale to [0, 1]
        z = z * (out_hi - out_lo) + out_lo        # rescale to output_range

    ``min`` and ``max`` are per-node statistics estimated from batches
    (see ``update_by_batch``), playing a role similar to BatchNorm's
    moving statistics.
    """

    # Attribute layout — every method below relies on this exact order:
    #   0: bias, 1: aggregation index, 2: activation index, 3: min, 4: max
    custom_attrs = ["bias", "aggregation", "activation", "min", "max"]
    # Guards the (max - min) denominator against division by zero when a
    # batch is constant (previously declared but never used).
    eps = 1e-6

    def __init__(
        self,
        bias_init_mean: float = 0.0,
        bias_init_std: float = 1.0,
        bias_mutate_power: float = 0.5,
        bias_mutate_rate: float = 0.7,
        bias_replace_rate: float = 0.1,
        aggregation_default: callable = Agg.sum,
        aggregation_options: Tuple = (Agg.sum,),
        aggregation_replace_rate: float = 0.1,
        activation_default: callable = Act.sigmoid,
        activation_options: Tuple = (Act.sigmoid,),
        activation_replace_rate: float = 0.1,
        output_range: Tuple[float, float] = (-1, 1),
        update_hidden_node: bool = False,
    ):
        super().__init__()
        self.bias_init_mean = bias_init_mean
        self.bias_init_std = bias_init_std
        self.bias_mutate_power = bias_mutate_power
        self.bias_mutate_rate = bias_mutate_rate
        self.bias_replace_rate = bias_replace_rate

        # defaults are stored as indices into the option tuples
        self.aggregation_default = aggregation_options.index(aggregation_default)
        self.aggregation_options = aggregation_options
        self.aggregation_indices = jnp.arange(len(aggregation_options))
        self.aggregation_replace_rate = aggregation_replace_rate

        self.activation_default = activation_options.index(activation_default)
        self.activation_options = activation_options
        self.activation_indices = jnp.arange(len(activation_options))
        self.activation_replace_rate = activation_replace_rate

        self.output_range = output_range
        assert (
            len(self.output_range) == 2 and self.output_range[0] < self.output_range[1]
        ), "output_range must be a (low, high) pair with low < high"
        self.update_hidden_node = update_hidden_node

    def new_identity_attrs(self, state):
        # activation = -1 selects Act.identity; min=0, max=1 make the
        # min-max rescaling neutral, so the node passes values through.
        return jnp.array([0, self.aggregation_default, -1, 0, 1])

    def new_random_attrs(self, state, randkey):
        """Sample random (bias, agg, act) attrs; min/max start neutral at (0, 1)."""
        k1, k2, k3, _, _ = jax.random.split(randkey, num=5)
        bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
        agg = jax.random.randint(k2, (), 0, len(self.aggregation_options))
        act = jax.random.randint(k3, (), 0, len(self.activation_options))
        return jnp.array([bias, agg, act, 0, 1])

    def mutate(self, state, randkey, attrs):
        """Mutate bias/aggregation/activation; min/max statistics are preserved."""
        k1, k2, k3, _, _ = jax.random.split(randkey, num=5)
        # BUGFIX: attrs are stored as (bias, aggregation, activation, min, max);
        # the previous version unpacked ``act`` and ``agg`` swapped, so each was
        # mutated with the other's option indices and written back transposed.
        bias, agg, act, min_, max_ = attrs

        bias = mutate_float(
            k1,
            bias,
            self.bias_init_mean,
            self.bias_init_std,
            self.bias_mutate_power,
            self.bias_mutate_rate,
            self.bias_replace_rate,
        )

        agg = mutate_int(
            k2, agg, self.aggregation_indices, self.aggregation_replace_rate
        )

        act = mutate_int(k3, act, self.activation_indices, self.activation_replace_rate)

        return jnp.array([bias, agg, act, min_, max_])

    def distance(self, state, attrs1, attrs2):
        """Gene distance: |Δbias| + aggregation mismatch + activation mismatch."""
        bias1, agg1, act1, min1, max1 = attrs1
        # BUGFIX: previously unpacked attrs2 into ``min1, max1`` again, shadowing
        # attrs1's values. min/max are intentionally excluded from the distance.
        bias2, agg2, act2, min2, max2 = attrs2
        return (
            jnp.abs(bias1 - bias2)  # bias
            + (agg1 != agg2)  # aggregation
            + (act1 != act2)  # activation
        )

    def forward(self, state, attrs, inputs, is_output_node=False):
        """
        z = act(agg(inputs) + bias); hidden nodes (when update_hidden_node is
        enabled) are then min-max rescaled into output_range.
        """
        bias, agg, act, min_, max_ = attrs

        z = agg_func(agg, inputs, self.aggregation_options)
        z = bias + z

        # the final output node is returned without activation
        z = jax.lax.cond(
            is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
        )

        if self.update_hidden_node:
            # eps prevents division by zero when max_ == min_
            z = (z - min_) / (max_ - min_ + self.eps)  # rescale to [0, 1]
            z = (
                z * (self.output_range[1] - self.output_range[0]) + self.output_range[0]
            )  # rescale to output_range

        return z

    def input_transform(self, state, attrs, inputs):
        """
        Apply the min-max rescaling to raw inputs at an input node, so the
        normalization is also performed at the first layer of the network.
        """
        bias, agg, act, min_, max_ = attrs
        inputs = (inputs - min_) / (max_ - min_ + self.eps)  # rescale to [0, 1]
        inputs = (
            inputs * (self.output_range[1] - self.output_range[0])
            + self.output_range[0]
        )
        return inputs

    def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
        """
        Forward a whole batch and refresh the stored min/max statistics
        (NaN entries are ignored when computing the extrema).

        Returns (batch_outputs, updated_attrs).
        """
        bias, agg, act, min_, max_ = attrs

        batch_z = jax.vmap(agg_func, in_axes=(None, 0, None))(
            agg, batch_inputs, self.aggregation_options
        )

        batch_z = bias + batch_z

        batch_z = jax.lax.cond(
            is_output_node,
            lambda: batch_z,
            lambda: jax.vmap(act_func, in_axes=(None, 0, None))(
                act, batch_z, self.activation_options
            ),
        )

        if self.update_hidden_node:
            # batch extrema, ignoring NaNs
            min_ = jnp.min(jnp.where(jnp.isnan(batch_z), jnp.inf, batch_z))
            max_ = jnp.max(jnp.where(jnp.isnan(batch_z), -jnp.inf, batch_z))

            batch_z = (batch_z - min_) / (max_ - min_ + self.eps)  # rescale to [0, 1]
            batch_z = (
                batch_z * (self.output_range[1] - self.output_range[0])
                + self.output_range[0]
            )

            # persist the refreshed statistics into the attrs
            attrs = attrs.at[3].set(min_)
            attrs = attrs.at[4].set(max_)

        return batch_z, attrs

    def update_input_transform(self, state, attrs, batch_inputs):
        """
        Refresh min/max from a batch of raw inputs and rescale the batch into
        output_range (used for the transformation at input nodes).

        Returns (transformed_batch, updated_attrs).
        """
        bias, agg, act, min_, max_ = attrs

        # batch extrema, ignoring NaNs
        min_ = jnp.min(jnp.where(jnp.isnan(batch_inputs), jnp.inf, batch_inputs))
        max_ = jnp.max(jnp.where(jnp.isnan(batch_inputs), -jnp.inf, batch_inputs))

        batch_inputs = (batch_inputs - min_) / (max_ - min_ + self.eps)
        batch_inputs = (
            batch_inputs * (self.output_range[1] - self.output_range[0])
            + self.output_range[0]
        )

        # persist the refreshed statistics into the attrs
        attrs = attrs.at[3].set(min_)
        attrs = attrs.at[4].set(max_)

        return batch_inputs, attrs
fitness_target=13000, + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/jumanji/2048_test.ipynb b/tensorneat/examples/jumanji/2048_test.ipynb new file mode 100644 index 0000000..e779bfd --- /dev/null +++ b/tensorneat/examples/jumanji/2048_test.ipynb @@ -0,0 +1,1285 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-06-05T05:07:22.736605400Z", + "start_time": "2024-06-05T05:06:39.100164300Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initializing\n", + "initializing finished\n", + "start compile\n", + "compile finished, cost time: 18.307454s\n", + "Generation: 1.0, Cost time: 4551.03ms\n", + " \tnode counts: max: 21, min: 21, mean: 21.00\n", + " \tconn counts: max: 20, min: 20, mean: 20.00\n", + " \tspecies: 1, [10000]\n", + " \tfitness: valid cnt: 10000, max: 10124.0000, min: 44.0000, mean: 1758.1263, std: 1212.6823\n", + "Generation: 2.0, Cost time: 4636.33ms\n", + " \tnode counts: max: 22, min: 21, mean: 21.03\n", + " \tconn counts: max: 22, min: 20, mean: 20.05\n", + " \tspecies: 1, [10000]\n", + " \tfitness: valid cnt: 10000, max: 11000.0000, min: 48.0000, mean: 1870.1300, std: 1263.3086\n", + "Generation: 3.0, Cost time: 6271.12ms\n", + " \tnode counts: max: 23, min: 21, mean: 21.03\n", + " \tconn counts: max: 22, min: 20, mean: 20.05\n", + " \tspecies: 1, [10000]\n", + " \tfitness: valid cnt: 10000, max: 14624.0000, min: 28.0000, mean: 1943.9924, std: 1293.7146\n", + "\n", + "Fitness limit reached!\n" + ] + } + ], + "source": [ + "import jax.numpy as jnp\n", + "\n", + "from pipeline import Pipeline\n", + "from algorithm.neat import *\n", + "\n", + "from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048\n", + "from utils import Act, Agg\n", + "\n", + "if __name__ == 
\"__main__\":\n", + " pipeline = Pipeline(\n", + " algorithm=NEAT(\n", + " species=DefaultSpecies(\n", + " genome=DefaultGenome(\n", + " num_inputs=16,\n", + " num_outputs=4,\n", + " max_nodes=100,\n", + " max_conns=1000,\n", + " node_gene=DefaultNodeGene(\n", + " activation_default=Act.sigmoid,\n", + " activation_options=(Act.sigmoid, Act.relu, Act.tanh, Act.identity, Act.inv),\n", + " aggregation_default=Agg.sum,\n", + " aggregation_options=(Agg.sum, Agg.mean, Agg.max, Agg.product),\n", + " ),\n", + " mutation=DefaultMutation(\n", + " node_add=0.03,\n", + " conn_add=0.03,\n", + " )\n", + " ),\n", + " pop_size=10000,\n", + " species_size=100,\n", + " survival_threshold=0.01,\n", + " ),\n", + " ),\n", + " problem=Jumanji_2048(\n", + " max_step=1000,\n", + " ),\n", + " generation_limit=10000,\n", + " fitness_target=13000,\n", + " )\n", + "\n", + " # initialize state\n", + " state = pipeline.setup()\n", + " # print(state)\n", + " # run until terminate\n", + " state, best = pipeline.auto_run(state)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [], + "source": [ + "genome = pipeline.algorithm.genome" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-05T05:08:14.332101Z", + "start_time": "2024-06-05T05:08:14.324101300Z" + } + }, + "id": "a0915ecf8179f347" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [], + "source": [ + "transformed = genome.transform(state, *best)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-05T05:08:49.132030500Z", + "start_time": "2024-06-05T05:08:48.495809200Z" + } + }, + "id": "cd1fa65e8a9d6e13" + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [], + "source": [ + "def policy(board):\n", + " action_scores = genome.forward(state, transformed, board)\n", + " return action_scores" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-05T05:09:32.355055100Z", + "start_time": 
"2024-06-05T05:09:32.350057Z" + } + }, + "id": "61bc1895af304651" + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 1, 0, 0],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(2, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 0, 2],\n", + " [0, 1, 0, 0],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(4, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 1, 2],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(4, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 1, 2, 2],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(4, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 3],\n", + " [0, 0, 0, 0],\n", + " [2, 0, 0, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + 
"TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 0, 1, 3],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 0, 1],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 0, 1, 3],\n", + " [0, 0, 0, 1],\n", + " [0, 0, 0, 0],\n", + " [0, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 1, 3],\n", + " [0, 0, 0, 1],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 2, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 1, 1, 3],\n", + " [0, 0, 2, 1],\n", + " [0, 1, 0, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 2, 1, 3],\n", + " [1, 0, 2, 1],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 3, 1, 3],\n", + " [0, 1, 2, 1],\n", + " [0, 1, 0, 0],\n", + " [0, 0, 0, 
0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 3, 1, 3],\n", + " [0, 2, 2, 1],\n", + " [1, 0, 0, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 3],\n", + " [0, 2, 2, 1],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 3],\n", + " [0, 2, 2, 1],\n", + " [0, 0, 1, 0],\n", + " [0, 0, 2, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 3],\n", + " [0, 0, 3, 1],\n", + " [0, 0, 1, 1],\n", + " [0, 0, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[1, 3, 1, 3],\n", + " [0, 0, 3, 2],\n", + " [0, 0, 1, 2],\n", + " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), 
discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 3],\n", + " [0, 0, 3, 3],\n", + " [0, 1, 1, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(8, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 4],\n", + " [0, 1, 3, 1],\n", + " [0, 0, 1, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 4],\n", + " [0, 1, 3, 1],\n", + " [0, 0, 0, 1],\n", + " [0, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 4],\n", + " [0, 2, 3, 2],\n", + " [0, 0, 0, 0],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 4],\n", + " [0, 2, 3, 2],\n", + " [0, 0, 0, 1],\n", + " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 4],\n", + " [1, 2, 3, 2],\n", + " [0, 0, 0, 1],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([ True, False, True, True], 
dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 4],\n", + " [1, 2, 3, 2],\n", + " [0, 0, 0, 2],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 4],\n", + " [1, 2, 3, 3],\n", + " [0, 0, 0, 1],\n", + " [0, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 4],\n", + " [1, 2, 3, 3],\n", + " [0, 1, 0, 1],\n", + " [0, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 4],\n", + " [1, 2, 3, 3],\n", + " [0, 2, 1, 1],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 3, 1, 4],\n", + " [1, 3, 3, 3],\n", + " [0, 0, 1, 1],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), 
observation=Observation(board=Array([[2, 4, 1, 4],\n", + " [1, 0, 3, 3],\n", + " [0, 1, 1, 2],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 4],\n", + " [1, 1, 3, 3],\n", + " [0, 0, 1, 2],\n", + " [0, 0, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 4],\n", + " [1, 1, 3, 3],\n", + " [0, 1, 2, 2],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 4],\n", + " [1, 2, 3, 3],\n", + " [0, 1, 2, 2],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(24., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 4],\n", + " [1, 1, 2, 4],\n", + " [0, 0, 1, 3],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(16, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 5],\n", + " [1, 1, 2, 3],\n", + " [0, 1, 1, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': 
Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 5],\n", + " [1, 2, 2, 3],\n", + " [0, 0, 1, 0],\n", + " [0, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 5],\n", + " [1, 2, 2, 3],\n", + " [0, 1, 1, 0],\n", + " [0, 0, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 5],\n", + " [1, 2, 2, 3],\n", + " [0, 1, 2, 1],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 5],\n", + " [1, 2, 3, 3],\n", + " [0, 1, 1, 1],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 5],\n", + " [1, 1, 2, 4],\n", + " [0, 0, 1, 2],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 5],\n", + " [0, 2, 2, 4],\n", 
+ " [0, 0, 1, 2],\n", + " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 5],\n", + " [1, 2, 2, 4],\n", + " [0, 0, 1, 2],\n", + " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 4, 1, 5],\n", + " [2, 2, 2, 4],\n", + " [1, 0, 1, 2],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 5],\n", + " [1, 2, 2, 4],\n", + " [0, 0, 1, 2],\n", + " [0, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 5],\n", + " [1, 2, 2, 4],\n", + " [0, 1, 1, 2],\n", + " [0, 0, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 5],\n", + " [1, 2, 2, 4],\n", + " [0, 1, 2, 2],\n", + " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), 
reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 5],\n", + " [2, 2, 3, 4],\n", + " [1, 1, 0, 2],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 5],\n", + " [0, 3, 3, 4],\n", + " [0, 0, 2, 2],\n", + " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 5],\n", + " [1, 3, 3, 4],\n", + " [0, 0, 2, 2],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(24., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 5],\n", + " [0, 1, 4, 4],\n", + " [0, 1, 0, 3],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 5],\n", + " [0, 2, 4, 4],\n", + " [0, 1, 0, 3],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 5],\n", + " [0, 0, 2, 5],\n", + " [1, 0, 1, 3],\n", + " [0, 0, 0, 1]], dtype=int32), 
action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(32, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(64., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 0, 2, 3],\n", + " [0, 1, 1, 1],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 1, 2, 3],\n", + " [0, 0, 1, 1],\n", + " [2, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 1, 2, 3],\n", + " [2, 0, 1, 1],\n", + " [0, 2, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 1, 2, 3],\n", + " [2, 2, 1, 1],\n", + " [0, 0, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 1, 2, 3],\n", + " [2, 2, 2, 1],\n", + " [0, 0, 2, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., 
dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 1, 3, 3],\n", + " [2, 2, 2, 1],\n", + " [0, 0, 0, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(28., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [0, 1, 2, 4],\n", + " [0, 2, 3, 1],\n", + " [0, 0, 0, 2]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 2, 4, 0],\n", + " [2, 3, 1, 1],\n", + " [2, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 2, 4, 1],\n", + " [3, 3, 1, 0],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 2, 4, 2],\n", + " [3, 3, 1, 0],\n", + " [0, 0, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 2, 4, 2],\n", + " [3, 3, 2, 0],\n", + " [0, 0, 1, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), 
extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 2, 4, 2],\n", + " [0, 0, 4, 2],\n", + " [0, 1, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(40., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 2, 5, 3],\n", + " [0, 1, 1, 1],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 2, 5, 3],\n", + " [0, 0, 1, 2],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [1, 2, 5, 3],\n", + " [1, 2, 1, 0],\n", + " [1, 0, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [2, 3, 5, 3],\n", + " [1, 1, 1, 0],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 
6],\n", + " [2, 3, 5, 3],\n", + " [0, 0, 1, 2],\n", + " [0, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [2, 3, 5, 3],\n", + " [0, 1, 1, 2],\n", + " [0, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 4, 1, 6],\n", + " [2, 3, 5, 3],\n", + " [1, 2, 1, 2],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, False, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 2, 0, 0],\n", + " [3, 4, 1, 6],\n", + " [2, 3, 5, 3],\n", + " [1, 2, 1, 2]], dtype=int32), action_mask=Array([ True, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 1, 6],\n", + " [2, 4, 5, 3],\n", + " [1, 3, 1, 2],\n", + " [0, 2, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 1, 6],\n", + " [2, 4, 5, 3],\n", + " [1, 3, 1, 2],\n", + " [1, 0, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + 
"TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 1, 6],\n", + " [2, 4, 5, 3],\n", + " [2, 3, 1, 2],\n", + " [0, 1, 2, 1]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 2, 1, 6],\n", + " [3, 4, 5, 3],\n", + " [1, 3, 1, 2],\n", + " [0, 1, 2, 1]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [1, 4, 5, 3],\n", + " [0, 3, 1, 2],\n", + " [2, 1, 2, 1]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [1, 4, 5, 3],\n", + " [2, 3, 1, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [1, 4, 5, 3],\n", + " [2, 3, 1, 2],\n", + " [2, 2, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [1, 4, 5, 3],\n", + " [3, 3, 2, 2],\n", + " 
[1, 2, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(24., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [1, 4, 5, 3],\n", + " [0, 1, 4, 3],\n", + " [0, 1, 2, 1]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [1, 4, 5, 4],\n", + " [1, 2, 4, 1],\n", + " [0, 0, 2, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 5, 4],\n", + " [0, 2, 4, 1],\n", + " [0, 0, 2, 1]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 5, 4],\n", + " [0, 2, 4, 2],\n", + " [0, 0, 2, 1]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 5, 4],\n", + " [2, 4, 2, 0],\n", + " [2, 1, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(40., 
dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 5, 5, 4],\n", + " [2, 1, 2, 0],\n", + " [0, 1, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 5, 5, 4],\n", + " [2, 2, 2, 0],\n", + " [0, 0, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 5, 5, 4],\n", + " [2, 2, 2, 1],\n", + " [0, 2, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 5, 5, 4],\n", + " [2, 3, 2, 1],\n", + " [0, 1, 1, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(68., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [0, 3, 6, 4],\n", + " [2, 3, 2, 1],\n", + " [0, 1, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 6, 4],\n", + " [1, 1, 2, 1],\n", + " [0, 0, 0, 2]], dtype=int32), action_mask=Array([False, True, 
True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 6, 4],\n", + " [0, 2, 2, 1],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 6, 4],\n", + " [1, 0, 3, 1],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 6, 4],\n", + " [0, 1, 3, 1],\n", + " [0, 1, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 6, 4],\n", + " [0, 2, 3, 1],\n", + " [0, 2, 1, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 6, 4],\n", + " [0, 3, 3, 1],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), 
observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 6, 4],\n", + " [1, 3, 3, 1],\n", + " [0, 1, 1, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [2, 4, 6, 4],\n", + " [2, 1, 4, 1],\n", + " [0, 0, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 4, 6, 4],\n", + " [1, 1, 4, 1],\n", + " [0, 0, 2, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 4, 6, 4],\n", + " [0, 2, 4, 1],\n", + " [0, 0, 1, 3]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 4, 6, 4],\n", + " [2, 4, 1, 0],\n", + " [1, 3, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 5, 6, 4],\n", + " [2, 3, 1, 1],\n", + " [1, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': 
Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 5, 6, 4],\n", + " [2, 3, 1, 2],\n", + " [1, 0, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 5, 6, 4],\n", + " [2, 3, 1, 2],\n", + " [0, 0, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 5, 6, 4],\n", + " [2, 3, 1, 3],\n", + " [0, 0, 2, 1]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 5, 6, 4],\n", + " [2, 3, 1, 3],\n", + " [2, 1, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [3, 5, 6, 4],\n", + " [3, 3, 2, 3],\n", + " [0, 1, 2, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(24., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 2, 1, 6],\n", + " [4, 5, 6, 
4],\n", + " [1, 3, 3, 3],\n", + " [0, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 2, 1, 6],\n", + " [1, 5, 6, 4],\n", + " [0, 3, 3, 3],\n", + " [2, 1, 0, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 2, 1, 6],\n", + " [1, 5, 6, 4],\n", + " [2, 3, 3, 3],\n", + " [0, 1, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 2, 1, 6],\n", + " [1, 5, 6, 4],\n", + " [0, 2, 3, 4],\n", + " [1, 0, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 2, 1, 6],\n", + " [2, 5, 6, 5],\n", + " [1, 2, 3, 2],\n", + " [0, 0, 0, 0]], dtype=int32), action_mask=Array([False, False, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 0, 1, 0],\n", + " [5, 2, 1, 6],\n", + " [2, 5, 6, 5],\n", + " [1, 2, 3, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, 
dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 2, 2, 6],\n", + " [2, 5, 6, 5],\n", + " [1, 2, 3, 2],\n", + " [0, 1, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[0, 5, 3, 6],\n", + " [2, 5, 6, 5],\n", + " [1, 2, 3, 2],\n", + " [0, 1, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(64., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 6, 3, 6],\n", + " [1, 2, 6, 5],\n", + " [0, 1, 3, 2],\n", + " [0, 1, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 6, 3, 6],\n", + " [1, 2, 6, 5],\n", + " [0, 2, 3, 2],\n", + " [0, 0, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 6, 3, 6],\n", + " [1, 3, 6, 5],\n", + " [0, 0, 3, 2],\n", + " [1, 0, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[2, 6, 3, 6],\n", + " [2, 3, 6, 5],\n", + " [1, 0, 3, 2],\n", + " [0, 0, 1, 1]], dtype=int32), 
action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 6, 3, 6],\n", + " [1, 3, 6, 5],\n", + " [0, 0, 3, 2],\n", + " [1, 0, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 6, 3, 6],\n", + " [2, 3, 6, 5],\n", + " [0, 1, 3, 2],\n", + " [0, 0, 1, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 6, 3, 6],\n", + " [2, 3, 6, 5],\n", + " [0, 1, 3, 2],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 6, 3, 6],\n", + " [2, 3, 6, 5],\n", + " [1, 1, 3, 3],\n", + " [0, 0, 1, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 6, 3, 6],\n", + " [2, 3, 6, 5],\n", + " [0, 0, 2, 4],\n", + " [0, 1, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., 
dtype=float32), observation=Observation(board=Array([[3, 6, 3, 6],\n", + " [2, 3, 6, 5],\n", + " [1, 1, 2, 4],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 6, 3, 6],\n", + " [2, 3, 6, 5],\n", + " [0, 2, 2, 4],\n", + " [1, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 6, 3, 6],\n", + " [2, 3, 6, 5],\n", + " [1, 2, 2, 4],\n", + " [0, 1, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 6, 3, 6],\n", + " [2, 3, 6, 5],\n", + " [2, 1, 3, 4],\n", + " [0, 0, 0, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[3, 6, 3, 6],\n", + " [3, 3, 6, 5],\n", + " [0, 1, 3, 4],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [0, 3, 6, 5],\n", + " [0, 1, 3, 4],\n", + " [0, 1, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), 
extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [0, 3, 6, 5],\n", + " [0, 2, 3, 4],\n", + " [0, 2, 1, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [0, 3, 6, 5],\n", + " [0, 3, 3, 4],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [1, 4, 6, 5],\n", + " [0, 2, 3, 4],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [1, 4, 6, 5],\n", + " [2, 3, 4, 0],\n", + " [1, 2, 2, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [1, 4, 6, 5],\n", + " [0, 2, 3, 4],\n", + " [1, 0, 1, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 
6],\n", + " [2, 4, 6, 5],\n", + " [1, 2, 3, 4],\n", + " [0, 0, 1, 3]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [2, 4, 6, 5],\n", + " [1, 2, 3, 4],\n", + " [1, 3, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [2, 4, 6, 5],\n", + " [2, 2, 3, 4],\n", + " [1, 3, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [3, 4, 6, 5],\n", + " [1, 2, 3, 4],\n", + " [1, 3, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, False], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [3, 4, 6, 5],\n", + " [2, 2, 3, 4],\n", + " [0, 3, 1, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [3, 4, 6, 5],\n", + " [0, 3, 3, 4],\n", + " [1, 0, 3, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + 
"TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [3, 4, 6, 5],\n", + " [1, 3, 4, 4],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [3, 4, 6, 5],\n", + " [0, 1, 3, 5],\n", + " [0, 1, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(68., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 6],\n", + " [3, 4, 6, 6],\n", + " [0, 2, 3, 2],\n", + " [2, 0, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(64, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(128., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 2, 3, 0],\n", + " [0, 1, 1, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [0, 2, 3, 3],\n", + " [0, 0, 0, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [0, 0, 2, 4],\n", + 
" [0, 1, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [0, 1, 2, 4],\n", + " [0, 1, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [0, 2, 2, 4],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [0, 0, 3, 4],\n", + " [0, 1, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [0, 1, 3, 4],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [1, 1, 3, 4],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., 
dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 1, 3, 4],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 1, 3, 4],\n", + " [0, 1, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 2, 3, 4],\n", + " [0, 1, 2, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [1, 3, 3, 4],\n", + " [0, 0, 1, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [0, 1, 4, 4],\n", + " [0, 1, 1, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [0, 2, 4, 4],\n", + " [1, 0, 1, 3]], dtype=int32), action_mask=Array([ True, 
True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [1, 2, 4, 4],\n", + " [0, 1, 1, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 1, 2, 5],\n", + " [0, 0, 2, 3]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 1, 3, 5],\n", + " [1, 0, 0, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 1, 3, 5],\n", + " [0, 1, 1, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 2, 3, 5],\n", + " [0, 1, 1, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), 
observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [1, 3, 3, 5],\n", + " [0, 0, 2, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [0, 1, 4, 5],\n", + " [0, 2, 2, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [0, 1, 4, 5],\n", + " [1, 0, 3, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [1, 1, 4, 5],\n", + " [0, 1, 3, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [1, 2, 4, 5],\n", + " [0, 1, 3, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [1, 2, 4, 5],\n", + " [0, 1, 1, 4]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), 
extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [1, 2, 4, 5],\n", + " [0, 2, 2, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [1, 3, 4, 5],\n", + " [1, 0, 2, 4]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 3, 4, 5],\n", + " [0, 1, 2, 4]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 3, 4, 5],\n", + " [1, 2, 4, 1]], dtype=int32), action_mask=Array([ True, False, True, False], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 3, 5, 5],\n", + " [1, 2, 1, 1]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(68., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 
6, 3, 7],\n", + " [3, 4, 6, 2],\n", + " [2, 3, 6, 1],\n", + " [1, 2, 2, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(128., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 7, 2],\n", + " [2, 3, 2, 1],\n", + " [1, 2, 2, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 7, 2],\n", + " [2, 3, 3, 1],\n", + " [1, 2, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 7, 2],\n", + " [2, 3, 3, 2],\n", + " [1, 2, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 7, 3],\n", + " [2, 3, 3, 1],\n", + " [1, 2, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 7, 3],\n", + " [2, 3, 3, 2],\n", + " [1, 2, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + 
"TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 7, 3],\n", + " [0, 2, 4, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 7, 3],\n", + " [1, 2, 4, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 7, 3],\n", + " [2, 2, 4, 2],\n", + " [2, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [3, 4, 7, 3],\n", + " [3, 2, 4, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[4, 6, 3, 7],\n", + " [4, 4, 7, 3],\n", + " [1, 2, 4, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [2, 4, 7, 3],\n", + " [0, 2, 4, 2],\n", + 
" [1, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [2, 4, 7, 3],\n", + " [1, 2, 4, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [2, 4, 7, 3],\n", + " [2, 2, 4, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [3, 4, 7, 3],\n", + " [1, 2, 4, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [3, 4, 7, 3],\n", + " [2, 2, 4, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [3, 4, 7, 3],\n", + " [3, 4, 2, 0],\n", + " [2, 2, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(48., 
dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [2, 2, 2, 1],\n", + " [1, 0, 1, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [1, 2, 3, 1],\n", + " [0, 0, 0, 2]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [1, 2, 3, 1],\n", + " [2, 0, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [1, 2, 3, 2],\n", + " [2, 0, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [1, 2, 3, 2],\n", + " [1, 0, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [2, 2, 3, 2],\n", + " [1, 0, 2, 1]], dtype=int32), action_mask=Array([False, 
True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [0, 3, 3, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [1, 3, 3, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [2, 3, 3, 2],\n", + " [1, 1, 2, 1]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(20., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [2, 4, 2, 1],\n", + " [2, 2, 1, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [3, 4, 2, 1],\n", + " [0, 2, 1, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), 
observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [3, 4, 2, 2],\n", + " [0, 2, 1, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 3],\n", + " [1, 3, 4, 3],\n", + " [0, 0, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 4],\n", + " [1, 3, 4, 2],\n", + " [0, 0, 2, 1]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 4],\n", + " [1, 3, 4, 2],\n", + " [2, 1, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 4],\n", + " [1, 3, 4, 2],\n", + " [1, 0, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 4],\n", + " [2, 3, 4, 3],\n", + " [1, 0, 2, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), 
extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 4],\n", + " [2, 3, 4, 3],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 4],\n", + " [2, 3, 4, 3],\n", + " [1, 0, 2, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 4],\n", + " [2, 3, 4, 3],\n", + " [0, 1, 1, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 4],\n", + " [2, 3, 4, 4],\n", + " [0, 1, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(32., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 5],\n", + " [2, 3, 4, 2],\n", + " [0, 1, 1, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 
7],\n", + " [4, 5, 7, 5],\n", + " [2, 3, 4, 2],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 5],\n", + " [2, 3, 4, 3],\n", + " [1, 0, 1, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 5],\n", + " [2, 3, 4, 3],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 5],\n", + " [2, 3, 4, 3],\n", + " [0, 1, 2, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 5],\n", + " [2, 3, 4, 3],\n", + " [1, 0, 1, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 5],\n", + " [2, 3, 4, 4],\n", + " [1, 0, 1, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + 
"TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 5],\n", + " [1, 2, 3, 5],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(64., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [1, 2, 3, 2],\n", + " [0, 1, 1, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [1, 2, 3, 2],\n", + " [1, 0, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 2, 3, 3],\n", + " [0, 0, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(24., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [0, 0, 3, 4],\n", + " [0, 1, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [0, 1, 3, 4],\n", 
+ " [0, 1, 0, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [0, 2, 3, 4],\n", + " [0, 1, 0, 1]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [0, 2, 3, 4],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 4, 1],\n", + " [1, 2, 0, 0]], dtype=int32), action_mask=Array([False, True, True, False], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 4, 1],\n", + " [0, 1, 1, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 4, 1],\n", + " [0, 1, 2, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., 
dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 4, 1],\n", + " [0, 1, 1, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 4, 1],\n", + " [1, 0, 2, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 4, 1],\n", + " [1, 1, 2, 3]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 4, 1],\n", + " [2, 2, 3, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [3, 3, 4, 2],\n", + " [1, 2, 3, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [0, 4, 4, 2],\n", + " [1, 1, 2, 3]], dtype=int32), action_mask=Array([ True, 
True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [1, 4, 4, 2],\n", + " [1, 1, 2, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 4, 4, 2],\n", + " [1, 1, 2, 3]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 5, 2, 0],\n", + " [2, 2, 3, 1]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(72., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 6, 3, 7],\n", + " [4, 6, 7, 6],\n", + " [3, 2, 2, 1],\n", + " [1, 0, 3, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(128., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 2, 7, 6],\n", + " [3, 0, 2, 1],\n", + " [1, 0, 3, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), 
observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 2, 7, 6],\n", + " [2, 3, 2, 1],\n", + " [0, 1, 3, 2]], dtype=int32), action_mask=Array([False, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 2, 7, 6],\n", + " [2, 3, 2, 1],\n", + " [1, 3, 2, 1]], dtype=int32), action_mask=Array([ True, False, True, False], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(28., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 2, 7, 6],\n", + " [2, 4, 3, 2],\n", + " [1, 1, 0, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 2, 7, 6],\n", + " [2, 4, 3, 2],\n", + " [0, 0, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 2, 7, 6],\n", + " [2, 4, 3, 3],\n", + " [1, 0, 2, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 2, 7, 6],\n", + " [2, 2, 4, 4],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), 
extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 3, 7, 6],\n", + " [2, 0, 4, 4],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 3, 7, 6],\n", + " [0, 0, 2, 5],\n", + " [0, 2, 2, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 3, 7, 6],\n", + " [0, 2, 3, 5],\n", + " [0, 2, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 3, 7, 6],\n", + " [0, 3, 3, 5],\n", + " [0, 1, 0, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 4, 7, 6],\n", + " [1, 1, 3, 5],\n", + " [0, 0, 0, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(36., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 
3, 7],\n", + " [0, 5, 7, 6],\n", + " [0, 2, 3, 5],\n", + " [0, 1, 0, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [0, 5, 7, 6],\n", + " [1, 2, 3, 5],\n", + " [0, 0, 1, 2]], dtype=int32), action_mask=Array([ True, False, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [1, 5, 7, 6],\n", + " [0, 2, 3, 5],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [2, 5, 7, 6],\n", + " [0, 2, 3, 5],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [2, 5, 7, 6],\n", + " [1, 2, 3, 5],\n", + " [1, 0, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [2, 5, 7, 6],\n", + " [2, 2, 3, 5],\n", + " [0, 1, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + 
"TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [3, 5, 7, 6],\n", + " [0, 2, 3, 5],\n", + " [1, 1, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [3, 5, 7, 6],\n", + " [1, 2, 3, 5],\n", + " [1, 1, 1, 2]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [3, 5, 7, 6],\n", + " [2, 2, 3, 5],\n", + " [1, 1, 1, 2]], dtype=int32), action_mask=Array([False, True, False, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(12., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [3, 5, 7, 6],\n", + " [3, 3, 5, 1],\n", + " [2, 1, 2, 0]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 5, 1],\n", + " [1, 1, 2, 0]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 5, 1],\n", + 
" [1, 0, 2, 2]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 5, 1],\n", + " [0, 1, 1, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(4., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [2, 3, 5, 1],\n", + " [2, 0, 2, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(8., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [3, 3, 5, 1],\n", + " [1, 0, 2, 3]], dtype=int32), action_mask=Array([False, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(1, dtype=int8), reward=Array(16., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [0, 4, 5, 1],\n", + " [1, 1, 2, 3]], dtype=int32), action_mask=Array([ True, True, True, True], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "TimeStep(step_type=Array(2, dtype=int8), reward=Array(0., dtype=float32), discount=Array(0., dtype=float32), observation=Observation(board=Array([[5, 7, 3, 7],\n", + " [4, 5, 7, 6],\n", + " [1, 4, 5, 1],\n", + " [2, 1, 2, 3]], dtype=int32), action_mask=Array([False, False, False, False], dtype=bool)), extras={'highest_tile': Array(128, dtype=int32)})\n", + "3004.0\n" + ] + } + ], + "source": [ + "import jax, jumanji\n", 
+ "\n", + "env = jumanji.make(\"Game2048-v1\")\n", + "key = jax.random.PRNGKey(48)\n", + "jit_reset = jax.jit(env.reset)\n", + "jit_step = jax.jit(env.step)\n", + "state, timestep = jax.jit(env.reset)(key)\n", + "jit_policy = jax.jit(policy)\n", + "total_reward = 0\n", + "while True:\n", + " board, action_mask = timestep[\"observation\"]\n", + " action = jit_policy(timestep[\"observation\"][0].reshape(-1))\n", + " score_with_mask = jnp.where(action_mask, action, -jnp.inf)\n", + " action = jnp.argmax(score_with_mask)\n", + " state, timestep = jit_step(state, action)\n", + " done = jnp.all(~timestep[\"observation\"][1])\n", + " print(timestep)\n", + " total_reward += timestep[\"reward\"]\n", + " if done:\n", + " break\n", + "print(total_reward)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-06-05T05:15:43.041491500Z", + "start_time": "2024-06-05T05:15:37.325953600Z" + } + }, + "id": "f166e09c5be1a8fb" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "187326d08ac1eeb4" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tensorneat/problem/func_fit/func_fit.py b/tensorneat/problem/func_fit/func_fit.py index 3d67dae..dd19ac6 100644 --- a/tensorneat/problem/func_fit/func_fit.py +++ b/tensorneat/problem/func_fit/func_fit.py @@ -46,7 +46,7 @@ class FuncFit(BaseProblem): def show(self, state, randkey, act_func, params, *args, **kwargs): predict = jax.vmap(act_func, in_axes=(None, None, 0))( - state, params, self.inputs, params + state, params, self.inputs ) inputs, 
target, predict = jax.device_get([self.inputs, self.targets, predict]) if self.return_data: diff --git a/tensorneat/problem/rl_env/brax_env.py b/tensorneat/problem/rl_env/brax_env.py index eff6338..7df8040 100644 --- a/tensorneat/problem/rl_env/brax_env.py +++ b/tensorneat/problem/rl_env/brax_env.py @@ -5,8 +5,8 @@ from .rl_jit import RLEnv class BraxEnv(RLEnv): - def __init__(self, max_step=1000, record_episode=False, env_name: str = "ant", backend: str = "generalized"): - super().__init__(max_step, record_episode) + def __init__(self, max_step=1000, repeat_times=1, record_episode=False, env_name: str = "ant", backend: str = "generalized"): + super().__init__(max_step, repeat_times, record_episode) self.env = envs.create(env_name=env_name, backend=backend) def env_step(self, randkey, env_state, action): diff --git a/tensorneat/problem/rl_env/gymnax_env.py b/tensorneat/problem/rl_env/gymnax_env.py index e0ef323..af75d60 100644 --- a/tensorneat/problem/rl_env/gymnax_env.py +++ b/tensorneat/problem/rl_env/gymnax_env.py @@ -4,8 +4,8 @@ from .rl_jit import RLEnv class GymNaxEnv(RLEnv): - def __init__(self, env_name, max_step=1000, record_episode=False): - super().__init__(max_step, record_episode) + def __init__(self, env_name, max_step=1000, repeat_times=1, record_episode=False): + super().__init__(max_step, repeat_times, record_episode) assert env_name in gymnax.registered_envs, f"Env {env_name} not registered" self.env, self.env_params = gymnax.make(env_name) diff --git a/tensorneat/problem/rl_env/jumanji/__init__.py b/tensorneat/problem/rl_env/jumanji/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tensorneat/problem/rl_env/jumanji/jumanji_2048.py b/tensorneat/problem/rl_env/jumanji/jumanji_2048.py new file mode 100644 index 0000000..e9b7274 --- /dev/null +++ b/tensorneat/problem/rl_env/jumanji/jumanji_2048.py @@ -0,0 +1,56 @@ +import jax, jax.numpy as jnp +import jumanji + +from utils import State +from ..rl_jit import RLEnv + + +class 
Jumanji_2048(RLEnv): + def __init__( + self, max_step=1000, repeat_times=1, record_episode=False, guarantee_invalid_action=True + ): + super().__init__(max_step, repeat_times, record_episode) + self.guarantee_invalid_action = guarantee_invalid_action + self.env = jumanji.make("Game2048-v1") + + def env_step(self, randkey, env_state, action): + action_mask = env_state["action_mask"] + if self.guarantee_invalid_action: + score_with_mask = jnp.where(action_mask, action, -jnp.inf) + action = jnp.argmax(score_with_mask) + else: + action = jnp.argmax(action) + + done = ~action_mask[action] + + env_state, timestep = self.env.step(env_state, action) + reward = timestep["reward"] + + board, action_mask = timestep["observation"] + extras = timestep["extras"] + + done = done | (jnp.sum(action_mask) == 0) # all actions are invalid + + return board.reshape(-1), env_state, reward, done, extras + + def env_reset(self, randkey): + env_state, timestep = self.env.reset(randkey) + step_type = timestep["step_type"] + reward = timestep["reward"] + discount = timestep["discount"] + observation = timestep["observation"] + extras = timestep["extras"] + board, action_mask = observation + + return board.reshape(-1), env_state + + @property + def input_shape(self): + return (16,) + + @property + def output_shape(self): + return (4,) + + def show(self, state, randkey, act_func, params, *args, **kwargs): + raise NotImplementedError("Rendering is not implemented for the Jumanji 2048 environment.") diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index 00dfcb3..285a9a6 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl_env/rl_jit.py @@ -1,20 +1,47 @@ from functools import partial +from typing import Callable import jax import jax.numpy as jnp +from utils import State from .. 
import BaseProblem class RLEnv(BaseProblem): jitable = True - def __init__(self, max_step=1000, record_episode=False): + def __init__(self, max_step=1000, repeat_times=1, record_episode=False): super().__init__() self.max_step = max_step self.record_episode = record_episode + self.repeat_times = repeat_times - def evaluate(self, state, randkey, act_func, params): + def evaluate(self, state: State, randkey, act_func: Callable, params): + keys = jax.random.split(randkey, self.repeat_times) + if self.record_episode: + rewards, episodes = jax.vmap( + self.evaluate_once, in_axes=(None, 0, None, None) + )(state, keys, act_func, params) + episodes["obs"] = episodes["obs"].reshape( + self.max_step * self.repeat_times, *self.input_shape + ) + episodes["action"] = episodes["action"].reshape( + self.max_step * self.repeat_times, *self.output_shape + ) + episodes["reward"] = episodes["reward"].reshape( + self.max_step * self.repeat_times, + ) + + return rewards.mean(), episodes + + else: + rewards = jax.vmap(self.evaluate_once, in_axes=(None, 0, None, None))( + state, keys, act_func, params + ) + return rewards.mean() + + def evaluate_once(self, state, randkey, act_func, params): rng_reset, rng_episode = jax.random.split(randkey) init_obs, init_env_state = self.reset(rng_reset)