Skip to content

Commit

Permalink
[ADD][FEAT](func) Add 2d argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Aug 23, 2022
1 parent 7384c59 commit 3ca77bc
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions hourglass_tensorflow/utils/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,36 @@ def tf_bivariate_normal_pdf(
factor = tf.cast(1.0 / (2.0 * m.pi * tf.reduce_prod(stddev)), precision)
Z = factor * tf.exp(-0.5 * R)
return Z


@tf.function
def tf_matrix_argmax(tensor: tf.Tensor) -> tf.Tensor:
"""Apply a 2D argmax to a tensor
Args:
tensor (tf.Tensor): 3D Tensor with data format HWC
Returns:
tf.Tensor: tf.dtypes.int32 Tensor of dimension Cx2
"""
flat_tensor = tf.reshape(tensor, (-1, tf.shape(tensor)[-1]))
argmax = tf.cast(tf.argmax(flat_tensor, axis=0), tf.int32)
argmax_x = argmax // tf.shape(tensor)[2]
argmax_y = argmax % tf.shape(tensor)[2]
# stack and return 2D coordinates
return tf.transpose(tf.stack((argmax_x, argmax_y), axis=0), [1, 0])


@tf.function
def tf_batch_matrix_argmax(tensor: tf.Tensor) -> tf.Tensor:
"""Apply 2D argmax along a batch
Args:
tensor (tf.Tensor): 4D Tensor with data format NHWC
Returns:
tf.Tensor: tf.dtypes.int32 Tensor of dimension NxCx2
"""
return tf.map_fn(
fn=tf_matrix_argmax, elems=tensor, fn_output_signature=tf.dtypes.int32
)

0 comments on commit 3ca77bc

Please sign in to comment.