Refactor: typing #330

Merged
merged 29 commits on Jun 30, 2021
Changes from 22 commits
4 changes: 0 additions & 4 deletions setup.cfg
@@ -83,10 +83,6 @@ ignore_errors = True
[mypy-torchmetrics.classification.*]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-torchmetrics.functional.*]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-torchmetrics.regression.*]
ignore_errors = True
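
With the override gone, mypy now type-checks torchmetrics.functional.* like the rest of the package. A minimal sketch of how that could be verified locally, assuming mypy is installed and the command is run from the repository root (this invocation is illustrative, not part of the PR):

from mypy import api  # mypy's programmatic entry point

# Type-check the functional package against the repository's own mypy settings.
stdout, stderr, exit_code = api.run(["--config-file", "setup.cfg", "torchmetrics/functional"])
print(stdout)  # any remaining typing errors in the now-checked package show up here
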
2 changes: 1 addition & 1 deletion tests/classification/test_auroc.py
@@ -192,6 +192,6 @@ def test_error_on_different_mode():

def test_error_multiclass_no_num_classes():
with pytest.raises(
ValueError, match="Detected input to ``multiclass`` but you did not provide ``num_classes`` argument"
ValueError, match="Detected input to `multiclass` but you did not provide `num_classes` argument"
):
_ = auroc(torch.randn(20, 3).softmax(dim=-1), torch.randint(3, (20, )))
6 changes: 3 additions & 3 deletions tests/classification/test_iou.py
@@ -103,7 +103,7 @@ class TestIoU(MetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_confusion_matrix(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
def test_iou(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
average = 'macro' if reduction == 'elementwise_mean' else None # convert tags
self.run_class_metric_test(
ddp=ddp,
@@ -119,7 +119,7 @@ def test_confusion_matrix(self, reduction, preds, target, sk_metric, num_classes
}
)

def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric, num_classes):
def test_iou_functional(self, reduction, preds, target, sk_metric, num_classes):
average = 'macro' if reduction == 'elementwise_mean' else None # convert tags
self.run_functional_metric_test(
preds,
@@ -133,7 +133,7 @@ def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric,
}
)

def test_confusion_matrix_differentiability(self, reduction, preds, target, sk_metric, num_classes):
def test_iou_differentiability(self, reduction, preds, target, sk_metric, num_classes):
self.run_differentiability_test(
preds=preds,
target=target,
2 changes: 1 addition & 1 deletion tests/classification/test_kldivergence.py
@@ -83,7 +83,7 @@ def test_kldivergence_functional(self, reduction, p, q, log_prob):
metric_args=dict(log_prob=log_prob, reduction=reduction),
)

def test_kldivergence_differentiabilit(self, reduction, p, q, log_prob):
def test_kldivergence_differentiability(self, reduction, p, q, log_prob):
self.run_differentiability_test(
p,
q,
2 changes: 1 addition & 1 deletion tests/regression/test_r2score.py
@@ -93,7 +93,7 @@ def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, nu
metric_args=dict(adjusted=adjusted, multioutput=multioutput),
)

def test_r2_differentiabilit(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
def test_r2_differentiability(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
self.run_differentiability_test(
preds=preds,
target=target,
14 changes: 10 additions & 4 deletions torchmetrics/functional/classification/accuracy.py
@@ -21,7 +21,7 @@
from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod


def _check_subset_validity(mode):
def _check_subset_validity(mode: DataType) -> bool:
return mode in (DataType.MULTILABEL, DataType.MULTIDIM_MULTICLASS)


@@ -42,8 +42,8 @@ def _mode(
def _accuracy_update(
preds: Tensor,
target: Tensor,
reduce: str,
mdmc_reduce: str,
reduce: Optional[str],
mdmc_reduce: Optional[str],
threshold: float,
num_classes: Optional[int],
top_k: Optional[int],
@@ -70,7 +70,13 @@ def _accuracy_update(


def _accuracy_compute(
tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str, mdmc_average: str, mode: DataType
tp: Tensor,
fp: Tensor,
tn: Tensor,
fn: Tensor,
average: Optional[str],
mdmc_average: Optional[str],
mode: DataType,
) -> Tensor:
simple_average = [AverageMethod.MICRO, AverageMethod.SAMPLES]
if (mode == DataType.BINARY and average in simple_average) or mode == DataType.MULTILABEL:
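
`reduce`, `mdmc_reduce`, `average` and `mdmc_average` are annotated as Optional above because None is a legal value for them. A short sketch of the newly annotated helper, with illustrative calls, assuming the DataType enum from torchmetrics.utilities.enums:

from torchmetrics.utilities.enums import DataType

def _check_subset_validity(mode: DataType) -> bool:
    # subset accuracy is only defined for multilabel and multidim-multiclass inputs
    return mode in (DataType.MULTILABEL, DataType.MULTIDIM_MULTICLASS)

_check_subset_validity(DataType.BINARY)               # False
_check_subset_validity(DataType.MULTIDIM_MULTICLASS)  # True
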
37 changes: 19 additions & 18 deletions torchmetrics/functional/classification/auc.py
@@ -35,29 +35,30 @@ def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
return x, y


@torch.no_grad()
def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float) -> Tensor:
with torch.no_grad():
return direction * torch.trapz(y, x)
auc: Tensor = torch.trapz(y, x) * direction
return auc


@torch.no_grad()
def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
with torch.no_grad():
if reorder:
# TODO: include stable=True arg when pytorch v1.9 is released
x, x_idx = torch.sort(x)
y = y[x_idx]

dx = x[1:] - x[:-1]
if (dx < 0).any():
if (dx <= 0).all():
direction = -1.
else:
raise ValueError(
"The `x` tensor is neither increasing or decreasing. Try setting the reorder argument to `True`."
)
if reorder:
# TODO: include stable=True arg when pytorch v1.9 is released
x, x_idx = torch.sort(x)
y = y[x_idx]

dx = x[1:] - x[:-1]
if (dx < 0).any():
if (dx <= 0).all():
direction = -1.
else:
direction = 1.
return _auc_compute_without_check(x, y, direction)
raise ValueError(
"The `x` tensor is neither increasing or decreasing. Try setting the reorder argument to `True`."
)
else:
direction = 1.
return _auc_compute_without_check(x, y, direction)


def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
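
The rewrite keeps the same trapezoidal computation but applies @torch.no_grad() as a decorator instead of a context manager. A minimal usage sketch, assuming the public auc export in torchmetrics.functional at this version (values are illustrative):

import torch
from torchmetrics.functional import auc

x = torch.tensor([0.0, 0.5, 1.0])
y = torch.tensor([0.0, 0.5, 1.0])
auc(x, y)                                 # trapezoidal area under y(x): tensor(0.5000)
auc(x.flip(0), y.flip(0), reorder=True)   # decreasing x is sorted first, same result
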
23 changes: 11 additions & 12 deletions torchmetrics/functional/classification/auroc.py
@@ -76,7 +76,7 @@ def _auroc_compute(
if mode == 'multi-label':
if average == AverageMethod.MICRO:
fpr, tpr, _ = roc(preds.flatten(), target.flatten(), 1, pos_label, sample_weights)
else:
elif num_classes:
# for multilabel we iteratively evaluate roc in a binary fashion
output = [
roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights)
@@ -86,7 +86,7 @@
tpr = [o[1] for o in output]
else:
if mode != 'binary' and num_classes is None:
raise ValueError('Detected input to ``multiclass`` but you did not provide ``num_classes`` argument')
raise ValueError('Detected input to `multiclass` but you did not provide `num_classes` argument')
fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights)

# calculate standard roc auc score
@@ -99,7 +99,7 @@

# calculate average
if average == AverageMethod.NONE:
return auc_scores
return tensor(auc_scores)
if average == AverageMethod.MACRO:
return torch.mean(torch.stack(auc_scores))
if average == AverageMethod.WEIGHTED:
@@ -117,21 +117,20 @@

return _auc_compute_without_check(fpr, tpr, 1.0)

max_fpr = tensor(max_fpr, device=fpr.device)
_device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device
max_area: Tensor = tensor(max_fpr, device=_device)
# Add a single point at max_fpr and interpolate its tpr value
stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True)
weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight)
stop = torch.bucketize(max_area, fpr, out_int32=True, right=True)
weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight)
tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])
fpr = torch.cat([fpr[:stop], max_fpr.view(1)])
fpr = torch.cat([fpr[:stop], max_area.view(1)])

# Compute partial AUC
partial_auc = _auc_compute_without_check(fpr, tpr, 1.0)

# McClish correction: standardize result to be 0.5 if non-discriminant
# and 1 if maximal
min_area = 0.5 * max_fpr**2
max_area = max_fpr
# McClish correction: standardize result to be 0.5 if non-discriminant and 1 if maximal
min_area: Tensor = 0.5 * max_area**2
return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))


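
In the max_fpr branch, the raw partial area is rescaled by the McClish correction shown above so that a non-discriminant classifier scores 0.5 and a perfect one scores 1.0. A small worked example with illustrative numbers:

max_fpr = 0.5
partial_auc = 0.20                       # raw trapezoidal area over [0, max_fpr]
min_area = 0.5 * max_fpr ** 2            # 0.125, area of a chance-level classifier
max_area = max_fpr                       # 0.5, area of a perfect classifier
0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))   # 0.6
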
5 changes: 3 additions & 2 deletions torchmetrics/functional/classification/average_precision.py
@@ -27,15 +27,15 @@ def _average_precision_update(
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
) -> Tuple[Tensor, Tensor, int, int]:
) -> Tuple[Tensor, Tensor, int, Optional[int]]:
return _precision_recall_curve_update(preds, target, num_classes, pos_label)


def _average_precision_compute(
preds: Tensor,
target: Tensor,
num_classes: int,
pos_label: int,
pos_label: Optional[int],
sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
# todo: `sample_weights` is unused
@@ -102,5 +102,6 @@ def average_precision(
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]

"""
# fixme: `sample_weights` is unused
preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes, pos_label)
return _average_precision_compute(preds, target, num_classes, pos_label, sample_weights)
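
pos_label is now carried through the update/compute helpers as Optional[int], matching the public signature. A minimal binary usage sketch, assuming the functional average_precision export at this version (inputs and the shown value are illustrative):

import torch
from torchmetrics.functional import average_precision

preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])
average_precision(preds, target, pos_label=1)   # tensor(0.8333)
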
13 changes: 6 additions & 7 deletions torchmetrics/functional/classification/confusion_matrix.py
@@ -49,18 +49,17 @@ def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None)
raise ValueError(f"Argument average needs to one of the following: {allowed_normalize}")
if normalize is not None and normalize != 'none':
confmat = confmat.float() if not confmat.is_floating_point() else confmat
cm = None
if normalize == 'true':
cm = confmat / confmat.sum(axis=1, keepdim=True)
confmat = confmat / confmat.sum(axis=1, keepdim=True)
elif normalize == 'pred':
cm = confmat / confmat.sum(axis=0, keepdim=True)
confmat = confmat / confmat.sum(axis=0, keepdim=True)
elif normalize == 'all':
cm = confmat / confmat.sum()
nan_elements = cm[torch.isnan(cm)].nelement()
confmat = confmat / confmat.sum()

nan_elements = confmat[torch.isnan(confmat)].nelement()
if nan_elements != 0:
cm[torch.isnan(cm)] = 0
confmat[torch.isnan(confmat)] = 0
rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.')
return cm
return confmat


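
The compute path now reuses confmat for the normalized matrix instead of a separate cm variable. The three normalization modes divide by row sums, column sums, or the grand total; a small sketch with an illustrative 2x2 matrix:

import torch

confmat = torch.tensor([[3., 1.], [2., 4.]])
confmat / confmat.sum(dim=1, keepdim=True)   # 'true': each row (true class) sums to 1
confmat / confmat.sum(dim=0, keepdim=True)   # 'pred': each column (prediction) sums to 1
confmat / confmat.sum()                      # 'all': the whole matrix sums to 1
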
10 changes: 5 additions & 5 deletions torchmetrics/functional/classification/dice.py
@@ -98,12 +98,12 @@ def dice_score(

"""
num_classes = preds.shape[1]
bg = (1 - int(bool(bg)))
scores = torch.zeros(num_classes - bg, device=preds.device, dtype=torch.float32)
for i in range(bg, num_classes):
bg_inv = (1 - int(bool(bg)))
scores = torch.zeros(num_classes - bg_inv, device=preds.device, dtype=torch.float32)
for i in range(bg_inv, num_classes):
if not (target == i).any():
# no foreground class
scores[i - bg] += no_fg_score
scores[i - bg_inv] += no_fg_score
continue

# TODO: rewrite to use general `stat_scores`
Expand All @@ -112,5 +112,5 @@ def dice_score(
# nan result
score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score

scores[i - bg] += score_cls
scores[i - bg_inv] += score_cls
return reduce(scores, reduction=reduction)
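
The rename from bg to bg_inv avoids reusing the bg argument's name: the new variable is the index of the first class that is scored (0 when the background class is included, 1 when it is skipped). The per-class value itself is the usual Dice coefficient; a small worked sketch with illustrative counts (the exact denominator used by the code is not shown in this hunk):

# Dice coefficient for a single class with 3 true positives, 1 false positive, 2 false negatives:
tp, fp, fn = 3, 1, 2
2 * tp / (2 * tp + fp + fn)   # 6 / 9 = 0.666...
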
7 changes: 3 additions & 4 deletions torchmetrics/functional/classification/f_beta.py
@@ -21,7 +21,7 @@
from torchmetrics.utilities.enums import MDMCAverageMethod


def _safe_divide(num: Tensor, denom: Tensor):
def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:
""" prevent zero division """
denom[denom == 0.] = 1
return num / denom
@@ -186,9 +186,8 @@ def fbeta(
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

allowed_mdmc_average = list(MDMCAverageMethod) + [None]
if mdmc_average not in allowed_mdmc_average:
raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.")
if mdmc_average is not None and MDMCAverageMethod.from_str(mdmc_average) is None:
raise ValueError(f"The `mdmc_average` has to be one of {list(MDMCAverageMethod)}, got {mdmc_average}.")

if average in [AvgMethod.MACRO, AvgMethod.WEIGHTED, AvgMethod.NONE] and (not num_classes or num_classes < 1):
raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.")
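
_safe_divide keeps its behaviour and only gains a Tensor return annotation; the mdmc_average check now goes through MDMCAverageMethod.from_str, which yields None for unrecognised strings. A short sketch of the division guard with an illustrative call:

import torch
from torch import Tensor

def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:
    """ prevent zero division """
    denom[denom == 0.] = 1   # zero denominators produce 0 instead of NaN/inf
    return num / denom

_safe_divide(torch.tensor([2., 0.]), torch.tensor([4., 0.]))   # tensor([0.5000, 0.0000])
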
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/iou.py
@@ -27,7 +27,7 @@ def _iou_from_confmat(
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
reduction: str = 'elementwise_mean',
):
) -> Tensor:
intersection = torch.diag(confmat)
union = confmat.sum(0) + confmat.sum(1) - intersection

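
_iou_from_confmat only gains a Tensor return annotation; the computation still takes per-class intersections from the diagonal and unions from the row and column sums. A small sketch with an illustrative 2x2 confusion matrix:

import torch

confmat = torch.tensor([[2., 1.], [0., 3.]])
intersection = torch.diag(confmat)                        # tensor([2., 3.])
union = confmat.sum(0) + confmat.sum(1) - intersection    # tensor([3., 4.])
intersection / union                                      # tensor([0.6667, 0.7500])
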