Commit 62b7d97: typing n/m (#1879)
Co-authored-by: Jirka Borovec <[email protected]>
SkafteNicki and Borda authored Jul 4, 2023
1 parent 1bbda8d commit 62b7d97
Showing 25 changed files with 93 additions and 118 deletions.
11 changes: 0 additions & 11 deletions pyproject.toml
@@ -164,9 +164,6 @@ warn_no_return = "False"
# TODO: the goal is for this to be empty
[[tool.mypy.overrides]]
module = [
"torchmetrics.classification.accuracy",
"torchmetrics.classification.auroc",
"torchmetrics.classification.average_precision",
"torchmetrics.classification.calibration_error",
"torchmetrics.classification.cohen_kappa",
"torchmetrics.classification.confusion_matrix",
@@ -178,21 +175,13 @@ module = [
"torchmetrics.classification.jaccard",
"torchmetrics.classification.matthews_corrcoef",
"torchmetrics.classification.precision_recall",
"torchmetrics.classification.precision_recall_curve",
"torchmetrics.classification.ranking",
"torchmetrics.classification.recall_at_fixed_precision",
"torchmetrics.classification.roc",
"torchmetrics.classification.specificity",
"torchmetrics.classification.stat_scores",
"torchmetrics.detection._mean_ap",
"torchmetrics.detection.mean_ap",
"torchmetrics.functional.classification.calibration_error",
"torchmetrics.functional.classification.confusion_matrix",
"torchmetrics.functional.classification.f_beta",
"torchmetrics.functional.classification.group_fairness",
"torchmetrics.functional.classification.precision_recall_curve",
"torchmetrics.functional.classification.ranking",
"torchmetrics.functional.classification.recall_at_fixed_precision",
"torchmetrics.functional.image.psnr",
"torchmetrics.functional.image.ssim",
"torchmetrics.image.psnr",
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/accuracy.py
@@ -467,7 +467,7 @@ class Accuracy:
tensor(0.6667)
"""

def __new__(
def __new__( # type: ignore[misc]
cls,
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
@@ -503,4 +503,4 @@ def __new__(
f"Optional arg `num_labels` must be type `int` when task is {task}. Got {type(num_labels)}"
)
return MultilabelAccuracy(num_labels, threshold, average, **kwargs)
raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy
raise ValueError(f"Not handled value: {task}")
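
For context on the `# type: ignore[misc]` added to `__new__` above: mypy expects `__new__` to return an instance of the class (or a subtype), while these wrapper metrics use `__new__` as a factory that hands back a task-specific implementation, which mypy reports under the `misc` error code. A minimal sketch of the pattern, with illustrative class names that are not part of torchmetrics:

from typing import Literal


class _BinaryImpl:
    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold


class _MulticlassImpl:
    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes


class TaskWrapper:
    """Factory wrapper: __new__ returns a task-specific object, not a TaskWrapper."""

    def __new__(  # type: ignore[misc]  # mypy: __new__ should return a subtype of cls
        cls, task: Literal["binary", "multiclass"], num_classes: int = 2
    ):
        if task == "binary":
            return _BinaryImpl()
        if task == "multiclass":
            return _MulticlassImpl(num_classes)
        raise ValueError(f"Not handled value: {task}")


metric = TaskWrapper(task="binary")  # at runtime this is a _BinaryImpl instance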
22 changes: 11 additions & 11 deletions src/torchmetrics/classification/auroc.py
@@ -115,12 +115,12 @@ def __init__(
_binary_auroc_arg_validation(max_fpr, thresholds, ignore_index)
self.max_fpr = max_fpr

def compute(self) -> Tensor:
def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _binary_auroc_compute(state, self.thresholds, self.max_fpr)

def plot(
def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
@@ -260,12 +260,12 @@ def __init__(
self.average = average
self.validate_args = validate_args

def compute(self) -> Tensor:
def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds)

def plot(
def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
@@ -407,12 +407,12 @@ def __init__(
self.average = average
self.validate_args = validate_args

def compute(self) -> Tensor:
def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multilabel_auroc_compute(state, self.num_labels, self.average, self.thresholds, self.ignore_index)

def plot(
def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
@@ -485,7 +485,7 @@ class AUROC:
tensor(0.7778)
"""

def __new__(
def __new__( # type: ignore[misc]
cls,
task: Literal["binary", "multiclass", "multilabel"],
thresholds: Optional[Union[int, List[float], Tensor]] = None,
@@ -510,4 +510,4 @@ def __new__(
if not isinstance(num_labels, int):
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
return MultilabelAUROC(num_labels, average, **kwargs)
return None
raise ValueError(f"Task {task} not supported!")
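
Two recurring fixes in this file (and several below) are worth spelling out: `compute` now builds its state as a tuple instead of a list, which presumably matches a `Tuple[Tensor, Tensor]`-style parameter annotation on the functional `_*_compute` helpers and so removes the need for `# type: ignore` comments, and the trailing `return None` in `__new__` becomes an explicit `raise`, so every path returns a metric instance. A rough sketch under those assumptions (the signature below is illustrative, not the actual torchmetrics helper):

from typing import Tuple, Union

import torch
from torch import Tensor


def _compute_sketch(state: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tensor:
    # Either a pre-binned confusion-matrix tensor or the raw (preds, target) pair.
    if isinstance(state, Tensor):
        return state.float().sum()
    preds, target = state
    return (preds.round() == target).float().mean()


preds = torch.tensor([0.2, 0.8])
target = torch.tensor([0.0, 1.0])
_compute_sketch((preds, target))    # OK: a 2-tuple matches Tuple[Tensor, Tensor]
# _compute_sketch([preds, target])  # mypy would reject: List is not Tuple[Tensor, Tensor]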
22 changes: 11 additions & 11 deletions src/torchmetrics/classification/average_precision.py
@@ -109,12 +109,12 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def compute(self) -> Tensor:
def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _binary_average_precision_compute(state, self.thresholds)

def plot(
def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
@@ -259,12 +259,12 @@ def __init__(
self.average = average
self.validate_args = validate_args

def compute(self) -> Tensor:
def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds)

def plot(
def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
@@ -411,14 +411,14 @@ def __init__(
self.average = average
self.validate_args = validate_args

def compute(self) -> Tensor:
def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multilabel_average_precision_compute(
state, self.num_labels, self.average, self.thresholds, self.ignore_index
)

def plot(
def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
@@ -495,7 +495,7 @@ class AveragePrecision:
tensor([1.0000, 1.0000, 0.2500, 0.2500, nan])
"""

def __new__(
def __new__( # type: ignore[misc]
cls,
task: Literal["binary", "multiclass", "multilabel"],
thresholds: Optional[Union[int, List[float], Tensor]] = None,
@@ -519,4 +519,4 @@ def __new__(
if not isinstance(num_labels, int):
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
return MultilabelAveragePrecision(num_labels, average, **kwargs)
return None
raise ValueError(f"Task {task} not supported!")
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/calibration_error.py
@@ -373,4 +373,4 @@ def __new__(
if not isinstance(num_classes, int):
raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
return MulticlassCalibrationError(num_classes, **kwargs)
raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy
raise ValueError(f"Not handled value: {task}")
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/matthews_corrcoef.py
@@ -402,4 +402,4 @@ def __new__(
if not isinstance(num_labels, int):
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs)
raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy
raise ValueError(f"Not handled value: {task}")
18 changes: 3 additions & 15 deletions src/torchmetrics/classification/precision_fixed_recall.py
@@ -125,11 +125,7 @@ def __init__(

def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override]
"""Compute metric."""
state = (
(dim_zero_cat(self.preds), dim_zero_cat(self.target)) # type: ignore[arg-type]
if self.thresholds is None
else self.confmat
)
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _binary_recall_at_fixed_precision_compute(
state, self.thresholds, self.min_recall, reduce_fn=_precision_at_recall
)
@@ -271,11 +267,7 @@ def __init__(

def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override]
"""Compute metric."""
state = (
(dim_zero_cat(self.preds), dim_zero_cat(self.target)) # type: ignore[arg-type]
if self.thresholds is None
else self.confmat
)
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_recall_at_fixed_precision_arg_compute(
state, self.num_classes, self.thresholds, self.min_recall, reduce_fn=_precision_at_recall
)
@@ -418,11 +410,7 @@ def __init__(

def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override]
"""Compute metric."""
state = (
(dim_zero_cat(self.preds), dim_zero_cat(self.target)) # type: ignore[arg-type]
if self.thresholds is None
else self.confmat
)
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multilabel_recall_at_fixed_precision_arg_compute(
state, self.num_labels, self.thresholds, self.ignore_index, self.min_recall, reduce_fn=_precision_at_recall
)
22 changes: 17 additions & 5 deletions src/torchmetrics/classification/precision_recall_curve.py
@@ -125,6 +125,10 @@ class BinaryPrecisionRecallCurve(Metric):
higher_is_better: Optional[bool] = None
full_state_update: bool = False

preds: List[Tensor]
target: List[Tensor]
confmat: Tensor

def __init__(
self,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
@@ -164,7 +168,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:

def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _binary_precision_recall_curve_compute(state, self.thresholds)

def plot(
@@ -289,6 +293,10 @@ class MulticlassPrecisionRecallCurve(Metric):
higher_is_better: Optional[bool] = None
full_state_update: bool = False

preds: List[Tensor]
target: List[Tensor]
confmat: Tensor

def __init__(
self,
num_classes: int,
@@ -334,7 +342,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:

def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds)

def plot(
@@ -468,6 +476,10 @@ class MultilabelPrecisionRecallCurve(Metric):
higher_is_better: Optional[bool] = None
full_state_update: bool = False

preds: List[Tensor]
target: List[Tensor]
confmat: Tensor

def __init__(
self,
num_labels: int,
@@ -513,7 +525,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:

def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multilabel_precision_recall_curve_compute(state, self.num_labels, self.thresholds, self.ignore_index)

def plot(
@@ -595,7 +607,7 @@ class PrecisionRecallCurve:
tensor(0.0500)]
"""

def __new__(
def __new__( # type: ignore[misc]
cls,
task: Literal["binary", "multiclass", "multilabel"],
thresholds: Optional[Union[int, List[float], Tensor]] = None,
@@ -618,4 +630,4 @@ def __new__(
if not isinstance(num_labels, int):
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
return MultilabelPrecisionRecallCurve(num_labels, **kwargs)
return None
raise ValueError(f"Task {task} not supported!")
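
The new class-level annotations in this file (`preds: List[Tensor]`, `target: List[Tensor]`, `confmat: Tensor`) exist because these attributes are created dynamically at runtime, via `add_state`, so mypy cannot otherwise see them. A small self-contained sketch of the idea, using `setattr` as a stand-in for torchmetrics' state registration:

from typing import Any, List

import torch
from torch import Tensor


class _CurveStateSketch:
    # Declared at class scope so mypy knows the types of attributes that only
    # come into existence at runtime (in torchmetrics, via Metric.add_state).
    preds: List[Tensor]
    target: List[Tensor]
    confmat: Tensor

    def __init__(self) -> None:
        self._add_state("preds", [])
        self._add_state("target", [])

    def _add_state(self, name: str, default: Any) -> None:
        # Stand-in for the dynamic state registration done by the real Metric base class.
        setattr(self, name, default)

    def update(self, preds: Tensor, target: Tensor) -> None:
        self.preds.append(preds)    # type-checks because of the class-level annotation
        self.target.append(target)


m = _CurveStateSketch()
m.update(torch.tensor([0.3]), torch.tensor([1]))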
18 changes: 3 additions & 15 deletions src/torchmetrics/classification/recall_fixed_precision.py
@@ -124,11 +124,7 @@ def __init__(

def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override]
"""Compute metric."""
state = (
(dim_zero_cat(self.preds), dim_zero_cat(self.target)) # type: ignore[arg-type]
if self.thresholds is None
else self.confmat
)
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _binary_recall_at_fixed_precision_compute(state, self.thresholds, self.min_precision)

def plot( # type: ignore[override]
@@ -266,11 +262,7 @@ def __init__(

def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override]
"""Compute metric."""
state = (
(dim_zero_cat(self.preds), dim_zero_cat(self.target)) # type: ignore[arg-type]
if self.thresholds is None
else self.confmat
)
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_recall_at_fixed_precision_arg_compute(
state, self.num_classes, self.thresholds, self.min_precision
)
@@ -413,11 +405,7 @@ def __init__(

def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override]
"""Compute metric."""
state = (
(dim_zero_cat(self.preds), dim_zero_cat(self.target)) # type: ignore[arg-type]
if self.thresholds is None
else self.confmat
)
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multilabel_recall_at_fixed_precision_arg_compute(
state, self.num_labels, self.thresholds, self.ignore_index, self.min_precision
)
12 changes: 6 additions & 6 deletions src/torchmetrics/classification/specificity_sensitivity.py
@@ -122,8 +122,8 @@ def __init__(

def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override]
"""Compute metric."""
state = [_cat(self.preds), _cat(self.target)] if self.thresholds is None else self.confmat # type: ignore
return _binary_specificity_at_sensitivity_compute(state, self.thresholds, self.min_sensitivity) # type: ignore
state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
return _binary_specificity_at_sensitivity_compute(state, self.thresholds, self.min_sensitivity)


class MulticlassSpecificityAtSensitivity(MulticlassPrecisionRecallCurve):
@@ -217,9 +217,9 @@ def __init__(

def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override]
"""Compute metric."""
state = [_cat(self.preds), _cat(self.target)] if self.thresholds is None else self.confmat # type: ignore
state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_specificity_at_sensitivity_compute(
state, self.num_classes, self.thresholds, self.min_sensitivity # type: ignore
state, self.num_classes, self.thresholds, self.min_sensitivity
)


@@ -314,9 +314,9 @@ def __init__(

def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override]
"""Compute metric."""
state = [_cat(self.preds), _cat(self.target)] if self.thresholds is None else self.confmat # type: ignore
state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
return _multilabel_specificity_at_sensitivity_compute(
state, self.num_labels, self.thresholds, self.ignore_index, self.min_sensitivity # type: ignore
state, self.num_labels, self.thresholds, self.ignore_index, self.min_sensitivity
)


2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/accuracy.py
@@ -427,4 +427,4 @@ def accuracy(
return multilabel_accuracy(
preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args
)
raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy
raise ValueError(f"Not handled value: {task}")
src/torchmetrics/functional/classification/calibration_error.py
@@ -131,7 +131,7 @@ def _binary_calibration_error_tensor_validation(
)


def _binary_calibration_error_update(preds: Tensor, target: Tensor) -> Tensor:
def _binary_calibration_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
confidences, accuracies = preds, target
return confidences, accuracies

@@ -235,7 +235,7 @@ def _multiclass_calibration_error_update(
def _multiclass_calibration_error_update(
preds: Tensor,
target: Tensor,
) -> Tensor:
) -> Tuple[Tensor, Tensor]:
if not torch.all((preds >= 0) * (preds <= 1)):
preds = preds.softmax(1)
confidences, predictions = preds.max(dim=1)
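
The annotation change in this last hunk is a straightforward correction: the update helpers return a `(confidences, accuracies)` pair, so `Tuple[Tensor, Tensor]` describes them accurately where the previous `Tensor` did not. A hypothetical standalone version for illustration (not the torchmetrics implementation):

from typing import Tuple

import torch
from torch import Tensor


def _calibration_update_sketch(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    # Confidence of the predicted class and whether that prediction was correct.
    confidences, predictions = preds.max(dim=1)
    accuracies = predictions.eq(target)
    return confidences, accuracies


probs = torch.tensor([[0.7, 0.3], [0.2, 0.8]])
labels = torch.tensor([0, 1])
conf, acc = _calibration_update_sketch(probs, labels)  # a pair, as the annotation states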