Skip to content

Commit

Permalink
[FIX](metric) PCK computation
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Aug 25, 2022
1 parent 31e45c0 commit 57623c2
Showing 1 changed file with 11 additions and 24 deletions.
35 changes: 11 additions & 24 deletions hourglass_tensorflow/metrics/correct_keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,35 +132,22 @@ def update_state(self, y_true, y_pred, *args, **kwargs):
ground_truth_joints = self.argmax_tensor(y_true)
predicted_joints = self.argmax_tensor(y_pred)
# We compute distance between ground truth and prediction
distance = tf.norm(
tf.cast(ground_truth_joints - predicted_joints, dtype=tf.dtypes.float32),
ord=2,
axis=-1,
)
error = tf.cast(ground_truth_joints - predicted_joints, dtype=tf.dtypes.float32)
distance = tf.norm(error, ord=2, axis=-1)
# We compute the norm of the reference limb from the ground truth
reference_distance = tf.expand_dims(
tf.norm(
tf.cast(
ground_truth_joints[:, self.reference[0], :]
- ground_truth_joints[:, self.reference[1], :],
tf.dtypes.float32,
),
axis=1,
),
axis=-1,
reference_limb_error = tf.cast(
ground_truth_joints[:, self.reference[0], :]
- ground_truth_joints[:, self.reference[1], :],
dtype=tf.float32,
)
reference_distance = tf.norm(reference_limb_error, ord=2, axis=-1)
# We apply the thresholding condition
correct_keypoints = tf.cast(
tf.reduce_sum(
tf.cast(
distance < (tf.expand_dims(reference_distance, -1) * self.ratio),
dtype=tf.dtypes.int32,
)
),
dtype=tf.dtypes.float32,
condition = tf.cast(
distance < (reference_distance * self.ratio), dtype=tf.float32
)
correct_keypoints = tf.reduce_sum(condition)
total_keypoints = tf.cast(
tf.reduce_prod(tf.shape(reference_distance)), dtype=tf.dtypes.float32
tf.reduce_prod(tf.shape(distance)), dtype=tf.dtypes.float32
)
self.correct_keypoints.assign_add(correct_keypoints)
self.total_keypoints.assign_add(total_keypoints)
Expand Down

0 comments on commit 57623c2

Please sign in to comment.