From 340d52c4dd84c7f31c3f03f05c85c4db8b494162 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Fri, 3 Jun 2022 12:56:43 -0700 Subject: [PATCH] Change dtype in inner_reshape to be int32 instead of int64. PiperOrigin-RevId: 452825469 Change-Id: Ie4602a0a024fc4c03399223e934996baeb086130 --- tf_agents/keras_layers/inner_reshape.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tf_agents/keras_layers/inner_reshape.py b/tf_agents/keras_layers/inner_reshape.py index 36176a25c..cacfb01c1 100644 --- a/tf_agents/keras_layers/inner_reshape.py +++ b/tf_agents/keras_layers/inner_reshape.py @@ -105,9 +105,9 @@ def _reshape_inner_dims( ndims = shape.rank tensor.shape[-ndims:].assert_is_compatible_with(shape) new_shape_inner_tensor = tf.cast( - [-1 if d is None else d for d in new_shape.as_list()], tf.int64) + [-1 if d is None else d for d in new_shape.as_list()], tf.int32) new_shape_outer_tensor = tf.cast( - tensor_shape[:-ndims], tf.int64) + tensor_shape[:-ndims], tf.int32) full_new_shape = tf.concat( (new_shape_outer_tensor, new_shape_inner_tensor), axis=0) new_tensor = tf.reshape(tensor, full_new_shape)