diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b6015ae5c3..be963e690e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Extend typing ([#330](https://github.com/PyTorchLightning/metrics/pull/330), - [#332](https://github.com/PyTorchLightning/metrics/pull/332)) + [#332](https://github.com/PyTorchLightning/metrics/pull/332), + [#333](https://github.com/PyTorchLightning/metrics/pull/333)) ### Deprecated diff --git a/setup.cfg b/setup.cfg index 2632b53c7a0..3f72b51364a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -73,16 +73,6 @@ files = torchmetrics disallow_untyped_defs = True ignore_missing_imports = True -# todo: add proper typing to this module... -[mypy-torchmetrics.image.ssim] -ignore_errors = True -[mypy-torchmetrics.image.psnr] -ignore_errors = True - # todo: add proper typing to this module... [mypy-torchmetrics.classification.*] ignore_errors = True - -# todo: add proper typing to this module... -[mypy-torchmetrics.regression.*] -ignore_errors = True diff --git a/torchmetrics/image/psnr.py b/torchmetrics/image/psnr.py index 6f8d0625c5f..1f3501034f8 100644 --- a/torchmetrics/image/psnr.py +++ b/torchmetrics/image/psnr.py @@ -68,6 +68,8 @@ class PSNR(Metric): Half precision is only support on GPU for this metric """ + min_target: Tensor + max_target: Tensor def __init__( self, @@ -110,7 +112,7 @@ def __init__( self.reduction = reduction self.dim = tuple(dim) if isinstance(dim, Sequence) else dim - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. @@ -131,7 +133,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.sum_squared_error.append(sum_squared_error) self.total.append(n_obs) - def compute(self): + def compute(self) -> Tensor: """ Compute peak signal-to-noise ratio over state. """ diff --git a/torchmetrics/image/ssim.py b/torchmetrics/image/ssim.py index 5f8daa15135..df9c1973453 100644 --- a/torchmetrics/image/ssim.py +++ b/torchmetrics/image/ssim.py @@ -84,7 +84,7 @@ def __init__( self.k2 = k2 self.reduction = reduction - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. @@ -96,7 +96,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.y_pred.append(preds) self.y.append(target) - def compute(self): + def compute(self) -> Tensor: """ Computes explained variance over state. """ diff --git a/torchmetrics/regression/cosine_similarity.py b/torchmetrics/regression/cosine_similarity.py index 43eddd715cb..3f2536694cf 100644 --- a/torchmetrics/regression/cosine_similarity.py +++ b/torchmetrics/regression/cosine_similarity.py @@ -80,7 +80,7 @@ def __init__( self.add_state("target", [], dist_reduce_fx="cat") self.reduction = reduction - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update metric states with predictions and targets. diff --git a/torchmetrics/regression/explained_variance.py b/torchmetrics/regression/explained_variance.py index 3edcfa96008..633035468f6 100644 --- a/torchmetrics/regression/explained_variance.py +++ b/torchmetrics/regression/explained_variance.py @@ -77,6 +77,11 @@ class ExplainedVariance(Metric): >>> explained_variance(preds, target) tensor([0.9677, 1.0000]) """ + n_obs: Tensor + sum_error: Tensor + sum_squared_error: Tensor + sum_target: Tensor + sum_squared_target: Tensor def __init__( self, @@ -97,14 +102,14 @@ def __init__( raise ValueError( f"Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}" ) - self.multioutput = multioutput + self.multioutput: str = multioutput self.add_state("sum_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("sum_target", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("sum_squared_target", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("n_obs", default=tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_absolute_error.py b/torchmetrics/regression/mean_absolute_error.py index 52c33016743..78b7cceddfa 100644 --- a/torchmetrics/regression/mean_absolute_error.py +++ b/torchmetrics/regression/mean_absolute_error.py @@ -66,7 +66,7 @@ def __init__( self.add_state("sum_abs_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_absolute_percentage_error.py b/torchmetrics/regression/mean_absolute_percentage_error.py index 381cac8d564..ade6c81100f 100644 --- a/torchmetrics/regression/mean_absolute_percentage_error.py +++ b/torchmetrics/regression/mean_absolute_percentage_error.py @@ -77,7 +77,7 @@ def __init__( self.add_state("sum_abs_per_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_squared_error.py b/torchmetrics/regression/mean_squared_error.py index 20b9eb475f1..94ead23e732 100644 --- a/torchmetrics/regression/mean_squared_error.py +++ b/torchmetrics/regression/mean_squared_error.py @@ -71,7 +71,7 @@ def __init__( self.add_state("total", default=tensor(0), dist_reduce_fx="sum") self.squared = squared - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_squared_log_error.py b/torchmetrics/regression/mean_squared_log_error.py index 83ce30dd856..69ef426b7b6 100644 --- a/torchmetrics/regression/mean_squared_log_error.py +++ b/torchmetrics/regression/mean_squared_log_error.py @@ -72,7 +72,7 @@ def __init__( self.add_state("sum_squared_log_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index 3706f54982e..6028e023e96 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -77,7 +77,7 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. diff --git a/torchmetrics/regression/r2score.py b/torchmetrics/regression/r2score.py index 86d8d9ec3e3..fbb6371beb4 100644 --- a/torchmetrics/regression/r2score.py +++ b/torchmetrics/regression/r2score.py @@ -123,7 +123,7 @@ def __init__( self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets. diff --git a/torchmetrics/regression/spearman.py b/torchmetrics/regression/spearman.py index 50178a42ead..3778e64125c 100644 --- a/torchmetrics/regression/spearman.py +++ b/torchmetrics/regression/spearman.py @@ -75,7 +75,7 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """ Update state with predictions and targets.