diff --git a/hourglass_tensorflow/utils/tf.py b/hourglass_tensorflow/utils/tf.py index f83957d..2ef0751 100644 --- a/hourglass_tensorflow/utils/tf.py +++ b/hourglass_tensorflow/utils/tf.py @@ -198,3 +198,36 @@ def tf_bivariate_normal_pdf( factor = tf.cast(1.0 / (2.0 * m.pi * tf.reduce_prod(stddev)), precision) Z = factor * tf.exp(-0.5 * R) return Z + + +@tf.function +def tf_matrix_argmax(tensor: tf.Tensor) -> tf.Tensor: + """Apply a 2D argmax to a tensor + + Args: + tensor (tf.Tensor): 3D Tensor with data format HWC + + Returns: + tf.Tensor: tf.dtypes.int32 Tensor of dimension Cx2 + """ + flat_tensor = tf.reshape(tensor, (-1, tf.shape(tensor)[-1])) + argmax = tf.cast(tf.argmax(flat_tensor, axis=0), tf.int32) + argmax_x = argmax // tf.shape(tensor)[2] + argmax_y = argmax % tf.shape(tensor)[2] + # stack and return 2D coordinates + return tf.transpose(tf.stack((argmax_x, argmax_y), axis=0), [1, 0]) + + +@tf.function +def tf_batch_matrix_argmax(tensor: tf.Tensor) -> tf.Tensor: + """Apply 2D argmax along a batch + + Args: + tensor (tf.Tensor): 4D Tensor with data format NHWC + + Returns: + tf.Tensor: tf.dtypes.int32 Tensor of dimension NxCx2 + """ + return tf.map_fn( + fn=tf_matrix_argmax, elems=tensor, fn_output_signature=tf.dtypes.int32 + )