diff --git a/torchgeo/transforms/indices.py b/torchgeo/transforms/indices.py index 7f7212b6e9e..dae8f55b5c3 100644 --- a/torchgeo/transforms/indices.py +++ b/torchgeo/transforms/indices.py @@ -12,7 +12,7 @@ import torch from torch import Tensor -from torch.nn import Module # type: ignore[attr-defined] +from torch.nn.modules import Module # https://github.com/pytorch/pytorch/issues/60979 # https://github.com/pytorch/pytorch/pull/61045 @@ -90,7 +90,7 @@ def ndwi(green: Tensor, nir: Tensor) -> Tensor: return (green - nir) / ((green + nir) + _EPSILON) -class AppendNDBI(Module): # type: ignore[misc,name-defined] +class AppendNDBI(Module): """Normalized Difference Built-up Index (NDBI). If you use this dataset in your research, please cite the following paper: @@ -132,7 +132,7 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: return sample -class AppendNBR(Module): # type: ignore[misc,name-defined] +class AppendNBR(Module): """Normalized Burn Ratio (NBR). .. versionadded:: 0.2.0 @@ -172,7 +172,7 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: return sample -class AppendNDSI(Module): # type: ignore[misc,name-defined] +class AppendNDSI(Module): """Normalized Difference Snow Index (NDSI). If you use this dataset in your research, please cite the following paper: @@ -214,7 +214,7 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: return sample -class AppendNDVI(Module): # type: ignore[misc,name-defined] +class AppendNDVI(Module): """Normalized Difference Vegetation Index (NDVI). If you use this dataset in your research, please cite the following paper: @@ -256,7 +256,7 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: return sample -class AppendNDWI(Module): # type: ignore[misc,name-defined] +class AppendNDWI(Module): """Normalized Difference Water Index (NDWI). If you use this dataset in your research, please cite the following paper: diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index b1b8cda3af8..67be28c31bb 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -8,14 +8,14 @@ import kornia.augmentation as K import torch from torch import Tensor -from torch.nn import Module # type: ignore[attr-defined] +from torch.nn.modules import Module # https://github.com/pytorch/pytorch/issues/60979 # https://github.com/pytorch/pytorch/pull/61045 Module.__module__ = "torch.nn" -class AugmentationSequential(Module): # type: ignore[misc] +class AugmentationSequential(Module): """Wrapper around kornia AugmentationSequential to handle input dicts.""" def __init__(self, *args: Module, data_keys: List[str]) -> None: