Skip to content

Commit

Permalink
[FIX](metrics) Avoid NaN metric value with safe divide.
Browse files Browse the repository at this point in the history
Enable _internal_state_update to bypass tensorflow coverage.py issue
  • Loading branch information
wbenbihi committed Aug 26, 2022
1 parent 949d399 commit ef58e54
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
28 changes: 17 additions & 11 deletions hourglass_tensorflow/metrics/correct_keypoints.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import tensorflow as tf
import keras.metrics
from keras.metrics import Metric

from hourglass_tensorflow.utils.tf import tf_dynamic_matrix_argmax


class RatioCorrectKeypoints(keras.metrics.Metric):
class RatioCorrectKeypoints(Metric):
"""RatioCorrectKeypoints metric identifies the percentage of "true positive" keypoints detected
This metric binarize our heatmap generation model (Regression Problem),
Expand Down Expand Up @@ -50,7 +50,7 @@ def argmax_tensor(self, tensor):
keepdims=True,
)

def update_state(self, y_true, y_pred, *args, **kwargs):
def _internal_update(self, y_true, y_pred):
ground_truth_joints = self.argmax_tensor(y_true)
predicted_joints = self.argmax_tensor(y_pred)
distance = ground_truth_joints - predicted_joints
Expand All @@ -65,15 +65,18 @@ def update_state(self, y_true, y_pred, *args, **kwargs):
self.correct_keypoints.assign_add(correct_keypoints)
self.total_keypoints.assign_add(total_keypoints)

def update_state(self, y_true, y_pred, *args, **kwargs):
return self._internal_update()

def result(self, *args, **kwargs):
return self.correct_keypoints / self.total_keypoints
return tf.math.divide_no_nan(self.correct_keypoints, self.total_keypoints)

def reset_states(self) -> None:
def reset_state(self) -> None:
self.correct_keypoints.assign(0.0)
self.total_keypoints.assign(0.0)


class PercentageOfCorrectKeypoints(keras.metrics.Metric):
class PercentageOfCorrectKeypoints(Metric):
"""PercentageOfCorrectKeypoints metric measures if predicted keypoint and true joint are within a distance threshold
PCK is used as an accuracy metric that measures if the predicted keypoint and the true joint are within
Expand Down Expand Up @@ -128,7 +131,7 @@ def argmax_tensor(self, tensor):
keepdims=True,
)

def update_state(self, y_true, y_pred, *args, **kwargs):
def _internal_update(self, y_true, y_pred):
ground_truth_joints = self.argmax_tensor(y_true)
predicted_joints = self.argmax_tensor(y_pred)
# We compute distance between ground truth and prediction
Expand All @@ -152,15 +155,18 @@ def update_state(self, y_true, y_pred, *args, **kwargs):
self.correct_keypoints.assign_add(correct_keypoints)
self.total_keypoints.assign_add(total_keypoints)

def update_state(self, y_true, y_pred, *args, **kwargs):
return self._internal_update(y_true, y_pred)

def result(self, *args, **kwargs):
return self.correct_keypoints / self.total_keypoints
return tf.math.divide_no_nan(self.correct_keypoints, self.total_keypoints)

def reset_states(self) -> None:
def reset_state(self) -> None:
self.correct_keypoints.assign(0.0)
self.total_keypoints.assign(0.0)


class ObjectKeypointSimilarity(keras.metrics.Metric):
class ObjectKeypointSimilarity(Metric):
"""ObjectKeypointSimilarity metric measures if predicted keypoint and true joint are within a distance threshold
OKS is commonly used in the COCO keypoint challenge as an evaluation metric. It is calculated from
Expand Down Expand Up @@ -214,7 +220,7 @@ def __init__(
# Set default value
pass

def reset_states(self) -> None:
def reset_state(self) -> None:
self.oks_sum.assign(0.0)
self.samples.assign(0.0)

Expand Down
11 changes: 7 additions & 4 deletions hourglass_tensorflow/metrics/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def argmax_tensor(self, tensor):
keepdims=True,
)

def update_state(self, y_true, y_pred, *args, **kwargs):
def _internal_update(self, y_true, y_pred):
ground_truth_joints = self.argmax_tensor(y_true)
predicted_joints = self.argmax_tensor(y_pred)
distance = tf.cast(
Expand All @@ -31,9 +31,12 @@ def update_state(self, y_true, y_pred, *args, **kwargs):
self.distance.assign_add(mean_distance)
self.batches.assign_add(1.0)

def result(self, *args, **kwargs):
return self.distance / self.batches
def update_state(self, y_true, y_pred, *args, **kwargs):
return self._internal_update(y_true, y_pred)

def result(self):
return tf.math.divide_no_nan(self.distance, self.batches)

def reset_states(self) -> None:
def reset_state(self) -> None:
self.batches.assign(0.0)
self.distance.assign(0.0)

0 comments on commit ef58e54

Please sign in to comment.