diff --git a/hourglass_tensorflow/metrics/correct_keypoints.py b/hourglass_tensorflow/metrics/correct_keypoints.py index 90eb80c..13acb8a 100644 --- a/hourglass_tensorflow/metrics/correct_keypoints.py +++ b/hourglass_tensorflow/metrics/correct_keypoints.py @@ -55,10 +55,13 @@ def update_state(self, y_true, y_pred, *args, **kwargs): predicted_joints = self.argmax_tensor(y_pred) distance = ground_truth_joints - predicted_joints norms = tf.norm(tf.cast(distance, dtype=tf.dtypes.float32), ord=2, axis=-1) - correct_keypoints = tf.reduce_sum( - tf.cast(norms < self.threshold, dtype=tf.dtypes.int32) + correct_keypoints = tf.cast( + tf.reduce_sum(tf.cast(norms < self.threshold, dtype=tf.dtypes.int32)), + dtype=tf.dtypes.float32, + ) + total_keypoints = tf.cast( + tf.reduce_prod(tf.shape(norms)), dtype=tf.dtypes.float32 ) - total_keypoints = tf.reduce_prod(tf.shape(norms)) self.correct_keypoints.assign_add(correct_keypoints) self.total_keypoints.assign_add(total_keypoints) @@ -66,8 +69,8 @@ def result(self, *args, **kwargs): return self.correct_keypoints / self.total_keypoints def reset_states(self) -> None: - self.correct_keypoints.assign(0) - self.total_keypoints.assign(0) + self.correct_keypoints.assign(0.0) + self.total_keypoints.assign(0.0) class PercentageOfCorrectKeypoints(keras.metrics.Metric): @@ -147,13 +150,18 @@ def update_state(self, y_true, y_pred, *args, **kwargs): axis=-1, ) # We apply the thresholding condition - correct_keypoints = tf.reduce_sum( - tf.cast( - distance < (tf.expand_dims(reference_distance, -1) * self.ratio), - dtype=tf.dtypes.int32, - ) + correct_keypoints = tf.cast( + tf.reduce_sum( + tf.cast( + distance < (tf.expand_dims(reference_distance, -1) * self.ratio), + dtype=tf.dtypes.int32, + ) + ), + dtype=tf.dtypes.float32, + ) + total_keypoints = tf.cast( + tf.reduce_prod(tf.shape(reference_distance)), dtype=tf.dtypes.float32 ) - total_keypoints = tf.reduce_prod(tf.shape(reference_distance)) self.correct_keypoints.assign_add(correct_keypoints) self.total_keypoints.assign_add(total_keypoints) @@ -161,8 +169,8 @@ def result(self, *args, **kwargs): return self.correct_keypoints / self.total_keypoints def reset_states(self) -> None: - self.correct_keypoints.assign(0) - self.total_keypoints.assign(0) + self.correct_keypoints.assign(0.0) + self.total_keypoints.assign(0.0) class ObjectKeypointSimilarity(keras.metrics.Metric):