Skip to content

Commit

Permalink
Change dtype in inner_reshape to be int32 instead of int64.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 452825469
Change-Id: Ie4602a0a024fc4c03399223e934996baeb086130
  • Loading branch information
Yao Lu authored and copybara-github committed Jun 3, 2022
1 parent 4b2203f commit 340d52c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tf_agents/keras_layers/inner_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ def _reshape_inner_dims(
ndims = shape.rank
tensor.shape[-ndims:].assert_is_compatible_with(shape)
new_shape_inner_tensor = tf.cast(
[-1 if d is None else d for d in new_shape.as_list()], tf.int64)
[-1 if d is None else d for d in new_shape.as_list()], tf.int32)
new_shape_outer_tensor = tf.cast(
tensor_shape[:-ndims], tf.int64)
tensor_shape[:-ndims], tf.int32)
full_new_shape = tf.concat(
(new_shape_outer_tensor, new_shape_inner_tensor), axis=0)
new_tensor = tf.reshape(tensor, full_new_shape)
Expand Down

0 comments on commit 340d52c

Please sign in to comment.