Skip to content

Commit

Permalink
[FIX](tf) 2D Argmax function tensor index
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Aug 24, 2022
1 parent 62498ab commit 93c9c44
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions hourglass_tensorflow/utils/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def tf_matrix_argmax(tensor: tf.Tensor) -> tf.Tensor:
"""
flat_tensor = tf.reshape(tensor, (-1, tf.shape(tensor)[-1]))
argmax = tf.cast(tf.argmax(flat_tensor, axis=0), tf.int32)
argmax_x = argmax // tf.shape(tensor)[2]
argmax_y = argmax % tf.shape(tensor)[2]
argmax_x = argmax // tf.shape(tensor)[1]
argmax_y = argmax % tf.shape(tensor)[1]
# stack and return 2D coordinates
return tf.transpose(tf.stack((argmax_x, argmax_y), axis=0), [1, 0])

Expand Down

0 comments on commit 93c9c44

Please sign in to comment.