Skip to content

Commit

Permalink
[FIX](metrics) cast metric results to float32
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Aug 25, 2022
1 parent 5941db2 commit bf57506
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions hourglass_tensorflow/metrics/correct_keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,22 @@ def update_state(self, y_true, y_pred, *args, **kwargs):
predicted_joints = self.argmax_tensor(y_pred)
distance = ground_truth_joints - predicted_joints
norms = tf.norm(tf.cast(distance, dtype=tf.dtypes.float32), ord=2, axis=-1)
correct_keypoints = tf.reduce_sum(
tf.cast(norms < self.threshold, dtype=tf.dtypes.int32)
correct_keypoints = tf.cast(
tf.reduce_sum(tf.cast(norms < self.threshold, dtype=tf.dtypes.int32)),
dtype=tf.dtypes.float32,
)
total_keypoints = tf.cast(
tf.reduce_prod(tf.shape(norms)), dtype=tf.dtypes.float32
)
total_keypoints = tf.reduce_prod(tf.shape(norms))
self.correct_keypoints.assign_add(correct_keypoints)
self.total_keypoints.assign_add(total_keypoints)

def result(self, *args, **kwargs):
return self.correct_keypoints / self.total_keypoints

def reset_states(self) -> None:
self.correct_keypoints.assign(0)
self.total_keypoints.assign(0)
self.correct_keypoints.assign(0.0)
self.total_keypoints.assign(0.0)


class PercentageOfCorrectKeypoints(keras.metrics.Metric):
Expand Down Expand Up @@ -147,22 +150,27 @@ def update_state(self, y_true, y_pred, *args, **kwargs):
axis=-1,
)
# We apply the thresholding condition
correct_keypoints = tf.reduce_sum(
tf.cast(
distance < (tf.expand_dims(reference_distance, -1) * self.ratio),
dtype=tf.dtypes.int32,
)
correct_keypoints = tf.cast(
tf.reduce_sum(
tf.cast(
distance < (tf.expand_dims(reference_distance, -1) * self.ratio),
dtype=tf.dtypes.int32,
)
),
dtype=tf.dtypes.float32,
)
total_keypoints = tf.cast(
tf.reduce_prod(tf.shape(reference_distance)), dtype=tf.dtypes.float32
)
total_keypoints = tf.reduce_prod(tf.shape(reference_distance))
self.correct_keypoints.assign_add(correct_keypoints)
self.total_keypoints.assign_add(total_keypoints)

def result(self, *args, **kwargs):
return self.correct_keypoints / self.total_keypoints

def reset_states(self) -> None:
self.correct_keypoints.assign(0)
self.total_keypoints.assign(0)
self.correct_keypoints.assign(0.0)
self.total_keypoints.assign(0.0)


class ObjectKeypointSimilarity(keras.metrics.Metric):
Expand Down

0 comments on commit bf57506

Please sign in to comment.