Improve speed and memory consumption of binned PrecisionRecallCurve (#1493)

Co-authored-by: Björn Barz <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
5 people authored Feb 23, 2023
1 parent 596cc4c commit 3df4e1b
Showing 2 changed files with 88 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -43,8 +43,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Extend `EnumStr` raising `ValueError` for invalid value ([#1479](https://github.com/Lightning-AI/metrics/pull/1479))


- Improve speed and memory consumption of binned `PrecisionRecallCurve` with large number of samples ([#1493](https://github.com/Lightning-AI/metrics/pull/1493))


- Changed `__iter__` method from raising `NotImplementedError` to `TypeError` by setting to `None` ([#1538](https://github.com/Lightning-AI/metrics/pull/1538))


### Deprecated

-
@@ -194,13 +194,54 @@ def _binary_precision_recall_curve_update(
"""
if thresholds is None:
return preds, target
if preds.numel() <= 50_000:
update_fn = _binary_precision_recall_curve_update_vectorized
else:
update_fn = _binary_precision_recall_curve_update_loop
return update_fn(preds, target, thresholds)
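
For context, a minimal usage sketch (not part of this diff; the import path and sizes are assumptions) showing how the binned path and the new dispatch are reached from the functional API:

# Hypothetical example: passing `thresholds` selects the binned code path; the
# dispatcher above then picks the vectorized kernel for small inputs and the
# loop-based kernel once preds.numel() exceeds 50_000.
import torch
from torchmetrics.functional.classification import binary_precision_recall_curve

preds = torch.rand(100_000)                      # > 50_000 samples -> loop kernel
target = torch.randint(0, 2, (100_000,))
precision, recall, thresholds = binary_precision_recall_curve(preds, target, thresholds=100)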


def _binary_precision_recall_curve_update_vectorized(
    preds: Tensor,
    target: Tensor,
    thresholds: Tensor,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Returns the multi-threshold confusion matrix to calculate the pr-curve with.
    This implementation is vectorized and faster than `_binary_precision_recall_curve_update_loop` for small
    numbers of samples (up to 50k) but less memory- and time-efficient for more samples.
    """
    len_t = len(thresholds)
    preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long()  # num_samples x num_thresholds
    unique_mapping = preds_t + 2 * target.unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device)
    bins = _bincount(unique_mapping.flatten(), minlength=4 * len_t)
    return bins.reshape(len_t, 2, 2)
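
To see why the bincount trick works, here is a small worked example (a sketch only; it uses `torch.bincount` in place of the internal `_bincount` helper):

# Each (sample, threshold) pair is encoded as pred_bit + 2 * target_bit + 4 * threshold_index,
# so slots 0..3 of the histogram belong to threshold 0, slots 4..7 to threshold 1, etc.,
# and reshaping to (len_t, 2, 2) yields one [[TN, FP], [FN, TP]] matrix per threshold.
import torch

preds = torch.tensor([0.2, 0.8])
target = torch.tensor([0, 1])
thresholds = torch.tensor([0.5, 0.9])
preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long()
unique_mapping = preds_t + 2 * target.unsqueeze(-1) + 4 * torch.arange(len(thresholds))
bins = torch.bincount(unique_mapping.flatten(), minlength=4 * len(thresholds))
print(bins.reshape(len(thresholds), 2, 2))
# threshold 0.5 -> [[1, 0], [0, 1]] (one TN, one TP)
# threshold 0.9 -> [[1, 0], [1, 0]] (one TN, one FN)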


def _binary_precision_recall_curve_update_loop(
    preds: Tensor,
    target: Tensor,
    thresholds: Tensor,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Returns the multi-threshold confusion matrix to calculate the pr-curve with.
    This implementation loops over thresholds and is more memory-efficient than
    `_binary_precision_recall_curve_update_vectorized`. However, it is slower for small
    numbers of samples (up to 50k).
    """
    len_t = len(thresholds)
    target = target == 1
    confmat = thresholds.new_empty((len_t, 2, 2), dtype=torch.int64)
    # Iterate one threshold at a time to conserve memory
    for i in range(len_t):
        preds_t = preds >= thresholds[i]
        confmat[i, 1, 1] = (target & preds_t).sum()
        confmat[i, 0, 1] = ((~target) & preds_t).sum()
        confmat[i, 1, 0] = (target & (~preds_t)).sum()
    confmat[:, 0, 0] = len(preds_t) - confmat[:, 0, 1] - confmat[:, 1, 0] - confmat[:, 1, 1]
    return confmat
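
Both kernels are expected to produce identical states, differing only in speed and peak memory; a quick sanity check could look like the following (the module path is an assumption, since the diff does not show the file name):

# Hypothetical check that the vectorized and loop-based updates agree.
import torch
from torchmetrics.functional.classification.precision_recall_curve import (
    _binary_precision_recall_curve_update_loop,
    _binary_precision_recall_curve_update_vectorized,
)

preds = torch.rand(1_000)
target = torch.randint(0, 2, (1_000,))
thresholds = torch.linspace(0, 1, steps=100)
assert torch.equal(
    _binary_precision_recall_curve_update_vectorized(preds, target, thresholds),
    _binary_precision_recall_curve_update_loop(preds, target, thresholds),
)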


def _binary_precision_recall_curve_compute(
    state: Union[Tensor, Tuple[Tensor, Tensor]],
    thresholds: Optional[Tensor],
@@ -409,8 +450,25 @@ def _multiclass_precision_recall_curve_update(
"""
if thresholds is None:
return preds, target
if preds.numel() * num_classes <= 1_000_000:
update_fn = _multiclass_precision_recall_curve_update_vectorized
else:
update_fn = _multiclass_precision_recall_curve_update_loop
return update_fn(preds, target, num_classes, thresholds)
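
The cutoff here bounds the total work rather than the sample count alone; a quick illustrative calculation (the concrete numbers are examples, not from the diff):

# For preds of shape (num_samples, num_classes), preds.numel() * num_classes
# equals num_samples * num_classes ** 2, so the vectorized kernel is used up to
# 10_000 samples with 10 classes but only up to 100 samples with 100 classes.
for num_samples, num_classes in [(10_000, 10), (101, 100)]:
    print(num_samples * num_classes * num_classes <= 1_000_000)  # True, then False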


def _multiclass_precision_recall_curve_update_vectorized(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    thresholds: Tensor,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Returns the multi-threshold confusion matrix to calculate the pr-curve with.
    This implementation is vectorized and faster than `_multiclass_precision_recall_curve_update_loop` for small
    numbers of samples but less memory- and time-efficient for more samples.
    """
    len_t = len(thresholds)
    # num_samples x num_classes x num_thresholds
    preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
    target_t = torch.nn.functional.one_hot(target, num_classes=num_classes)
    unique_mapping = preds_t + 2 * target_t.unsqueeze(-1)
@@ -420,6 +478,31 @@ def _multiclass_precision_recall_curve_update(
    return bins.reshape(len_t, num_classes, 2, 2)
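
The resulting state holds one one-vs-rest confusion matrix per (threshold, class) pair; a small shape check under the same module-path assumption as above:

# Hypothetical shape check: 5 thresholds and 3 classes give a (5, 3, 2, 2) state.
import torch
from torchmetrics.functional.classification.precision_recall_curve import (
    _multiclass_precision_recall_curve_update_vectorized,
)

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(0, 3, (10,))
thresholds = torch.linspace(0, 1, steps=5)
state = _multiclass_precision_recall_curve_update_vectorized(preds, target, 3, thresholds)
print(state.shape)  # torch.Size([5, 3, 2, 2])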


def _multiclass_precision_recall_curve_update_loop(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    thresholds: Tensor,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Returns the state to calculate the pr-curve with.
    This implementation loops over thresholds and is more memory-efficient than
    `_multiclass_precision_recall_curve_update_vectorized`. However, it is slower for small
    numbers of samples.
    """
    len_t = len(thresholds)
    target_t = torch.nn.functional.one_hot(target, num_classes=num_classes)
    confmat = thresholds.new_empty((len_t, num_classes, 2, 2), dtype=torch.int64)
    # Iterate one threshold at a time to conserve memory
    for i in range(len_t):
        preds_t = preds >= thresholds[i]
        confmat[i, :, 1, 1] = (target_t & preds_t).sum(dim=0)
        confmat[i, :, 0, 1] = ((~target_t) & preds_t).sum(dim=0)
        confmat[i, :, 1, 0] = (target_t & (~preds_t)).sum(dim=0)
    confmat[:, :, 0, 0] = len(preds_t) - confmat[:, :, 0, 1] - confmat[:, :, 1, 0] - confmat[:, :, 1, 1]
    return confmat
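
An end-to-end sketch of the binned multiclass curve at the user-facing API (import path and sizes are assumptions):

# Hypothetical usage: an integer `thresholds` enables the binned computation;
# any input where preds.numel() * num_classes exceeds 1_000_000 takes the loop kernel.
import torch
from torchmetrics.functional.classification import multiclass_precision_recall_curve

preds = torch.randn(50_000, 10).softmax(dim=-1)   # 50_000 * 10 * 10 elements of work -> loop kernel
target = torch.randint(0, 10, (50_000,))
precision, recall, thresholds = multiclass_precision_recall_curve(
    preds, target, num_classes=10, thresholds=100
)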


def _multiclass_precision_recall_curve_compute(
    state: Union[Tensor, Tuple[Tensor, Tensor]],
    num_classes: int,
