Skip to content

Commit

Permalink
[ADD][FEAT](metrics) Add OKS metric
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Aug 24, 2022
1 parent 2013c7c commit d9c8ade
Showing 1 changed file with 112 additions and 4 deletions.
116 changes: 112 additions & 4 deletions hourglass_tensorflow/metrics/correct_keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class PercentageOfCorrectKeypoints(keras.metrics.Metric):
Args:
reference (tuple[int, int], optional): Joint ID tuple to consider as reference.
Defaults to (8, 9).
threshold (float, optional): Threshold in percentage of the considered reference limb size.
ratio (float, optional): Threshold in percentage of the considered reference limb size.
Defaults to 0.5/50%.
name (str, optional): Tensor name. Defaults to None.
dtype (tf.dtypes, optional): Tensor data type. Defaults to None.
Expand All @@ -102,15 +102,15 @@ class PercentageOfCorrectKeypoints(keras.metrics.Metric):
def __init__(
self,
reference: tuple[int, int] = (8, 9),
threshold: float = 0.5,
ratio: float = 0.5,
name=None,
dtype=None,
intermediate_supervision: bool = True,
**kwargs
) -> None:
"""See help(PercentageOfCorrectKeypoints)"""
super().__init__(name, dtype, **kwargs)
self.threshold = threshold
self.ratio = ratio
self.reference = reference
self.correct_keypoints = self.add_weight(
name="correct_keypoints", initializer="zeros"
Expand Down Expand Up @@ -151,7 +151,7 @@ def update_state(self, y_true, y_pred, *args, **kwargs):
# We apply the thresholding condition
correct_keypoints = tf.reduce_sum(
tf.cast(
distance < (tf.expand_dims(reference_distance, -1) * self.threshold),
distance < (tf.expand_dims(reference_distance, -1) * self.ratio),
dtype=tf.dtypes.int32,
)
)
Expand All @@ -165,3 +165,111 @@ def result(self, *args, **kwargs):
def reset_states(self) -> None:
    """Reset the metric accumulators to zero.

    Both the ``correct_keypoints`` and the ``total_keypoints`` weights are
    reassigned the value ``0``.
    """
    # Numerator first, then denominator -- same order as the assignments
    # were originally written in.
    for accumulator in (self.correct_keypoints, self.total_keypoints):
        accumulator.assign(0)


class ObjectKeypointSimilarity(keras.metrics.Metric):
    """ObjectKeypointSimilarity (OKS) metric: mean keypoint similarity per sample.

    OKS is commonly used in the COCO keypoint challenge as an evaluation metric.
    It is calculated from the distance between predicted points and ground truth
    points normalized by the scale of the person. Scale and keypoint constants
    are needed to equalize the importance of each keypoint: neck location is
    more precise than hip location.

    Args:
        name (str, optional): Tensor name. Defaults to None.
        dtype (tf.dtypes, optional): Tensor data type. Defaults to None.
        keypoints_constants (list, optional): Per-keypoint constants (`κi`).
            Defaults to None.
        intermediate_supervision (bool, optional): Whether or not the
            intermediate supervision is activated.
            Defaults to True.
        compute_visibility_flags (bool, optional): Compute visibility flags
            from ground truth.
            Defaults to True.

    Notes:
        For each object, ground truth keypoints have the form
        [x1,y1,v1,...,xk,yk,vk], where x,y are the keypoint locations and v is
        a visibility flag defined as v=0: not labeled, v=1: labeled but not
        visible, and v=2: labeled and visible.

        We define the object keypoint similarity (OKS) as:
        `OKS = Σi[exp(-di^2/ {2 s^2 κi^2} )δ(vi>0)] / Σi[δ(vi>0)]`

        The `di` are the Euclidean distances between each corresponding ground
        truth and detected keypoint and the `vi` are the visibility flags of
        the ground truth (the detector's predicted `vi` are not used).
        To compute OKS, we pass the `di` through an unnormalized Gaussian with
        standard deviation `sκi`, where `s` is the object scale and `κi` is a
        per-keypoint constant that controls falloff.
        For each keypoint this yields a keypoint similarity that ranges between
        0 and 1. These similarities are averaged over all labeled keypoints
        (keypoints for which `vi>0`). Predicted keypoints that are not labeled
        (`vi=0`) do not affect the OKS. Perfect predictions will have OKS=1 and
        predictions for which all keypoints are off by more than a few standard
        deviations `sκi` will have OKS~0.

        `get_visibility`, `get_scale` and `oks` are not implemented yet, so
        `update_state` currently raises `NotImplementedError`.
    """

    def __init__(
        self,
        name=None,
        dtype=None,
        keypoints_constants: list = None,
        intermediate_supervision: bool = True,
        compute_visibility_flags: bool = True,
        **kwargs
    ) -> None:
        """See help(ObjectKeypointSimilarity)"""
        # Keywords rather than positionals: the call must not depend on the
        # order in which the base class declares `name` and `dtype`.
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Running sum of OKS values and number of samples seen so far.
        self.oks_sum = self.add_weight(name="oks_sum", initializer="zeros")
        self.samples = self.add_weight(name="samples", initializer="zeros")
        self.intermediate_supervision = intermediate_supervision
        self.compute_visibility_flags = compute_visibility_flags
        # TODO Set a default value when `keypoints_constants` is None
        self.keypoints_constants = keypoints_constants

    def reset_states(self) -> None:
        """Reset the accumulated OKS sum and the sample counter to zero."""
        self.oks_sum.assign(0.0)
        self.samples.assign(0.0)

    def argmax_tensor(self, tensor):
        """Return `tf_dynamic_matrix_argmax` of `tensor` with `keepdims=True`.

        The `intermediate_supervision` setting of the metric is forwarded.
        """
        return tf_dynamic_matrix_argmax(
            tensor,
            intermediate_supervision=self.intermediate_supervision,
            keepdims=True,
        )

    def get_visibility(self, y_true):
        """Return the visibility flags for the ground truth (not implemented).

        Raises:
            NotImplementedError: Always.
        """
        # TODO Implement Get Visibility Flags HERE
        raise NotImplementedError

    def get_scale(self, y_true):
        """Return the object scale for the ground truth (not implemented).

        Raises:
            NotImplementedError: Always.
        """
        # TODO Implement Get Object Scale HERE
        raise NotImplementedError

    def oks(self, distance, visibility_flags, scale):
        """Compute the OKS from distances, visibility and scale (not implemented).

        Intended computation, following the formula in the class docstring
        (NumPy-style pseudo-code, `kappa` being the keypoints constants):

            labeled = (visibility_flags > 0)
            similarity = exp(-distance**2 / (2 * scale**2 * kappa**2))
            oks = sum(similarity * labeled) / sum(labeled)

        Args:
            distance: Euclidean distances between ground truth and predicted
                keypoints.
            visibility_flags: Ground truth visibility flags.
            scale: Object scale.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def update_state(self, y_true, y_pred, *args, **kwargs):
        """Accumulate the OKS of a batch into the metric weights.

        Args:
            y_true: Ground truth tensor, passed to `argmax_tensor`.
            y_pred: Prediction tensor, passed to `argmax_tensor`.

        Raises:
            NotImplementedError: As long as `get_visibility`, `get_scale` or
                `oks` are not implemented.
        """
        ground_truth_joints = self.argmax_tensor(y_true)
        predicted_joints = self.argmax_tensor(y_pred)
        # We compute distance between ground truth and prediction
        distance = tf.norm(
            tf.cast(ground_truth_joints - predicted_joints, dtype=tf.dtypes.float32),
            ord=2,
            axis=-1,
        )
        # We generate visibility tensor and scale scalar
        # NOTE(review): both helpers receive the argmax coordinates, not the
        # raw `y_true` -- confirm this is intended once they are implemented.
        visibility = self.get_visibility(ground_truth_joints)
        scales = self.get_scale(ground_truth_joints)
        # We compute value to add to weights
        oks = self.oks(distance, visibility_flags=visibility, scale=scales)
        # Size of the first axis of `distance`
        # (presumably the batch dimension -- TODO confirm).
        sample_count = tf.cast(tf.shape(distance)[0], dtype=tf.dtypes.float32)
        oks_sum = tf.reduce_sum(oks)
        # Add to weight
        self.oks_sum.assign_add(oks_sum)
        self.samples.assign_add(sample_count)

    def result(self, *args, **kwargs):
        """Return the accumulated OKS sum divided by the number of samples.

        Returns 0 when no sample has been accumulated yet, instead of NaN.
        """
        return tf.math.divide_no_nan(self.oks_sum, self.samples)

0 comments on commit d9c8ade

Please sign in to comment.