Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing: regression #333

Merged
merged 8 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Extend typing ([#330](https://github.com/PyTorchLightning/metrics/pull/330),
[#332](https://github.com/PyTorchLightning/metrics/pull/332))
[#332](https://github.com/PyTorchLightning/metrics/pull/332),
[#333](https://github.com/PyTorchLightning/metrics/pull/333))


### Deprecated
Expand Down
10 changes: 0 additions & 10 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,6 @@ files = torchmetrics
disallow_untyped_defs = True
ignore_missing_imports = True

# todo: add proper typing to this module...
[mypy-torchmetrics.image.ssim]
ignore_errors = True
[mypy-torchmetrics.image.psnr]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-torchmetrics.classification.*]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-torchmetrics.regression.*]
ignore_errors = True
6 changes: 4 additions & 2 deletions torchmetrics/image/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class PSNR(Metric):
Half precision is only support on GPU for this metric

"""
min_target: Tensor
max_target: Tensor

def __init__(
self,
Expand Down Expand Up @@ -110,7 +112,7 @@ def __init__(
self.reduction = reduction
self.dim = tuple(dim) if isinstance(dim, Sequence) else dim

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.

Expand All @@ -131,7 +133,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
self.sum_squared_error.append(sum_squared_error)
self.total.append(n_obs)

def compute(self):
def compute(self) -> Tensor:
"""
Compute peak signal-to-noise ratio over state.
"""
Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
self.k2 = k2
self.reduction = reduction

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.

Expand All @@ -96,7 +96,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
self.y_pred.append(preds)
self.y.append(target)

def compute(self):
def compute(self) -> Tensor:
"""
Computes explained variance over state.
"""
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/regression/cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
self.add_state("target", [], dist_reduce_fx="cat")
self.reduction = reduction

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update metric states with predictions and targets.

Expand Down
9 changes: 7 additions & 2 deletions torchmetrics/regression/explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ class ExplainedVariance(Metric):
>>> explained_variance(preds, target)
tensor([0.9677, 1.0000])
"""
n_obs: Tensor
sum_error: Tensor
sum_squared_error: Tensor
sum_target: Tensor
sum_squared_target: Tensor

def __init__(
self,
Expand All @@ -97,14 +102,14 @@ def __init__(
raise ValueError(
f"Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}"
)
self.multioutput = multioutput
self.multioutput: str = multioutput
self.add_state("sum_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_target", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_squared_target", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("n_obs", default=tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.

Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/regression/mean_absolute_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self.add_state("sum_abs_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
self.add_state("sum_abs_per_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.

Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/regression/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
self.squared = squared

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.

Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/regression/mean_squared_log_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(
self.add_state("sum_squared_log_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.

Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.

Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/regression/r2score.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.

Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/regression/spearman.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.

Expand Down