diff --git a/hourglass_tensorflow/metrics/correct_keypoints.py b/hourglass_tensorflow/metrics/correct_keypoints.py index d19195c..726da44 100644 --- a/hourglass_tensorflow/metrics/correct_keypoints.py +++ b/hourglass_tensorflow/metrics/correct_keypoints.py @@ -1,10 +1,10 @@ import tensorflow as tf -import keras.metrics +from keras.metrics import Metric from hourglass_tensorflow.utils.tf import tf_dynamic_matrix_argmax -class RatioCorrectKeypoints(keras.metrics.Metric): +class RatioCorrectKeypoints(Metric): """RatioCorrectKeypoints metric identifies the percentage of "true positive" keypoints detected This metric binarize our heatmap generation model (Regression Problem), @@ -50,7 +50,7 @@ def argmax_tensor(self, tensor): keepdims=True, ) - def update_state(self, y_true, y_pred, *args, **kwargs): + def _internal_update(self, y_true, y_pred): ground_truth_joints = self.argmax_tensor(y_true) predicted_joints = self.argmax_tensor(y_pred) distance = ground_truth_joints - predicted_joints @@ -65,15 +65,18 @@ def update_state(self, y_true, y_pred, *args, **kwargs): self.correct_keypoints.assign_add(correct_keypoints) self.total_keypoints.assign_add(total_keypoints) + def update_state(self, y_true, y_pred, *args, **kwargs): + return self._internal_update() + def result(self, *args, **kwargs): - return self.correct_keypoints / self.total_keypoints + return tf.math.divide_no_nan(self.correct_keypoints, self.total_keypoints) - def reset_states(self) -> None: + def reset_state(self) -> None: self.correct_keypoints.assign(0.0) self.total_keypoints.assign(0.0) -class PercentageOfCorrectKeypoints(keras.metrics.Metric): +class PercentageOfCorrectKeypoints(Metric): """PercentageOfCorrectKeypoints metric measures if predicted keypoint and true joint are within a distance threshold PCK is used as an accuracy metric that measures if the predicted keypoint and the true joint are within @@ -128,7 +131,7 @@ def argmax_tensor(self, tensor): keepdims=True, ) - def update_state(self, y_true, y_pred, *args, **kwargs): + def _internal_update(self, y_true, y_pred): ground_truth_joints = self.argmax_tensor(y_true) predicted_joints = self.argmax_tensor(y_pred) # We compute distance between ground truth and prediction @@ -152,15 +155,18 @@ def update_state(self, y_true, y_pred, *args, **kwargs): self.correct_keypoints.assign_add(correct_keypoints) self.total_keypoints.assign_add(total_keypoints) + def update_state(self, y_true, y_pred, *args, **kwargs): + return self._internal_update(y_true, y_pred) + def result(self, *args, **kwargs): - return self.correct_keypoints / self.total_keypoints + return tf.math.divide_no_nan(self.correct_keypoints, self.total_keypoints) - def reset_states(self) -> None: + def reset_state(self) -> None: self.correct_keypoints.assign(0.0) self.total_keypoints.assign(0.0) -class ObjectKeypointSimilarity(keras.metrics.Metric): +class ObjectKeypointSimilarity(Metric): """ObjectKeypointSimilarity metric measures if predicted keypoint and true joint are within a distance threshold OKS is commonly used in the COCO keypoint challenge as an evaluation metric. It is calculated from @@ -214,7 +220,7 @@ def __init__( # Set default value pass - def reset_states(self) -> None: + def reset_state(self) -> None: self.oks_sum.assign(0.0) self.samples.assign(0.0) diff --git a/hourglass_tensorflow/metrics/distance.py b/hourglass_tensorflow/metrics/distance.py index 45a96a0..34eb2a8 100644 --- a/hourglass_tensorflow/metrics/distance.py +++ b/hourglass_tensorflow/metrics/distance.py @@ -21,7 +21,7 @@ def argmax_tensor(self, tensor): keepdims=True, ) - def update_state(self, y_true, y_pred, *args, **kwargs): + def _internal_update(self, y_true, y_pred): ground_truth_joints = self.argmax_tensor(y_true) predicted_joints = self.argmax_tensor(y_pred) distance = tf.cast( @@ -31,9 +31,12 @@ def update_state(self, y_true, y_pred, *args, **kwargs): self.distance.assign_add(mean_distance) self.batches.assign_add(1.0) - def result(self, *args, **kwargs): - return self.distance / self.batches + def update_state(self, y_true, y_pred, *args, **kwargs): + return self._internal_update(y_true, y_pred) + + def result(self): + return tf.math.divide_no_nan(self.distance, self.batches) - def reset_states(self) -> None: + def reset_state(self) -> None: self.batches.assign(0.0) self.distance.assign(0.0)