diff --git a/tf_agents/keras_layers/inner_reshape.py b/tf_agents/keras_layers/inner_reshape.py index 36176a25c..cacfb01c1 100644 --- a/tf_agents/keras_layers/inner_reshape.py +++ b/tf_agents/keras_layers/inner_reshape.py @@ -105,9 +105,9 @@ def _reshape_inner_dims( ndims = shape.rank tensor.shape[-ndims:].assert_is_compatible_with(shape) new_shape_inner_tensor = tf.cast( - [-1 if d is None else d for d in new_shape.as_list()], tf.int64) + [-1 if d is None else d for d in new_shape.as_list()], tf.int32) new_shape_outer_tensor = tf.cast( - tensor_shape[:-ndims], tf.int64) + tensor_shape[:-ndims], tf.int32) full_new_shape = tf.concat( (new_shape_outer_tensor, new_shape_inner_tensor), axis=0) new_tensor = tf.reshape(tensor, full_new_shape)