diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 0669a13232d..204b164c74a 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -37,7 +37,7 @@ jobs: run: sudo apt-get install pandoc - name: Install pip dependencies run: | - pip install .[train] + pip install . pip install -r docs/requirements.txt - name: Run sphinx checks run: cd docs && make html diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index a322036dee3..6f0d21c0860 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -25,7 +25,7 @@ jobs: - name: Install pip dependencies run: | pip install gdal tqdm # TODO: these deps shouldn't be needed - pip install .[datasets,tests,train] + pip install .[datasets,tests] pip install -r docs/requirements.txt - name: Run notebook checks env: @@ -42,6 +42,6 @@ jobs: with: python-version: 3.9 - name: Install pip dependencies - run: pip install .[datasets,tests,train] + run: pip install .[datasets,tests] - name: Run integration checks run: pytest -m slow diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index ff44a7c7754..147a915a1bb 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -20,9 +20,25 @@ jobs: - name: Install pip dependencies run: | pip install cython numpy # needed for pycocotools - pip install .[datasets,tests,train] + pip install .[datasets,tests] - name: Run mypy checks run: mypy . + datasets: + name: datasets + runs-on: ubuntu-latest + steps: + - name: Clone repo + uses: actions/checkout@v2 + - name: Set up python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Install pip dependencies + run: | + pip install cython numpy # needed for pycocotools + pip install .[tests] + - name: Run pytest checks + run: pytest --cov=torchgeo --cov-report=xml pytest: name: pytest runs-on: ${{ matrix.os }} @@ -64,7 +80,7 @@ jobs: - name: Install pip dependencies run: | pip install cython numpy # needed for pycocotools - pip install .[datasets,tests,train] + pip install .[datasets,tests] - name: Run pytest checks run: pytest --cov=torchgeo --cov-report=xml - name: Report coverage diff --git a/setup.cfg b/setup.cfg index 9f506832162..4ac1ecdbe56 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,6 +37,8 @@ install_requires = kornia>=0.5.4 matplotlib numpy + # omegaconf 2.1+ required for to_object method + omegaconf>=2.1 # pillow 2.9+ required for height attribute pillow>=2.9 # pyproj 2.2+ required for CRS object @@ -51,8 +53,13 @@ install_requires = scikit-learn>=0.18 # shapely 1.3+ required for Python 3 support shapely>=1.3 + # segmentation-models-pytorch 0.2+ required for smp.losses module + segmentation-models-pytorch>=0.2 + # timm 0.2.1+ required for `features_only` option in create_model + timm>=0.2.1 # torch 1.7+ required for typing torch>=1.7 + torchmetrics # torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks torchvision>=0.10 python_requires = >= 3.6 @@ -65,9 +72,9 @@ include = torchgeo* # Optional dataset requirements datasets = h5py - # loading .las point clouds (idtrees) laspy 2+ required for Python 3.6+ support - laspy>=2.0.0 - # open3d v0.11.2 last version for tests to pass + # laspy 2+ required for Python 3.6+ support + laspy>=2 + # open3d 0.11.2+ required to avoid GLFW error: # https://github.com/isl-org/Open3D/issues/1550 open3d>=0.11.2 opencv-python @@ -81,15 +88,6 @@ datasets = rarfile>=3 # scipy 0.9+ required for scipy.io.wavfile.read scipy>=0.9 -# Optional trainer requirements -train = - # omegaconf 2.1+ required for to_object method - omegaconf>=2.1 - # segmentation-models-pytorch 0.2+ required for smp.losses module - segmentation-models-pytorch>=0.2 - # timm 0.2.1+ required for `features_only` option in create_model - timm>=0.2.1 - torchmetrics # Optional developer requirements style = # black 21+ required for Python 3.9 support @@ -101,6 +99,7 @@ style = isort[colors]>=5.8 # pydocstyle 6.1+ required for pyproject.toml support pydocstyle[toml]>=6.1 +# Optional testing requirements tests = # mypy 0.900+ required for pyproject.toml support mypy>=0.900 diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py index 4fcd1b16753..30387142b53 100644 --- a/tests/datasets/test_advance.py +++ b/tests/datasets/test_advance.py @@ -25,7 +25,6 @@ class TestADVANCE: def dataset( self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path ) -> ADVANCE: - pytest.importorskip("scipy", minversion="0.9.0") monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.utils, "download_url", download_url ) @@ -57,6 +56,7 @@ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: ) def test_getitem(self, dataset: ADVANCE) -> None: + pytest.importorskip("scipy", minversion="0.9.0") x = dataset[0] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py index 9f39e47e619..55e26af648d 100644 --- a/tests/datasets/test_idtrees.py +++ b/tests/datasets/test_idtrees.py @@ -19,6 +19,9 @@ import torchgeo.datasets.utils from torchgeo.datasets import IDTReeS +pytest.importorskip("pandas", minversion="0.19.1") +pytest.importorskip("laspy", minversion="2") + def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -32,8 +35,6 @@ def dataset( tmp_path: Path, request: SubRequest, ) -> IDTReeS: - pytest.importorskip("pandas") - pytest.importorskip("laspy") monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.idtrees, "download_url", download_url ) @@ -124,11 +125,11 @@ def test_mock_missing_module( ImportError, match=f"{package} is not installed and is required to use this dataset", ): - IDTReeS(dataset.root, download=True, checksum=True) - else: + IDTReeS(dataset.root, dataset.split, dataset.task) + elif package in ["open3d"]: with pytest.raises( ImportError, - match=f"{package} is not installed and is required to use this dataset", + match=f"{package} is not installed and is required to plot point cloud", ): dataset.plot_las(0) @@ -153,7 +154,7 @@ def test_plot(self, dataset: IDTReeS) -> None: reason="segmentation fault on macOS and windows", ) def test_plot_las(self, dataset: IDTReeS) -> None: - pytest.importorskip("open3d") + pytest.importorskip("open3d", minversion="0.11.2") vis = dataset.plot_las(index=0, colormap="BrBG") vis.close() vis = dataset.plot_las(index=0, colormap=None) diff --git a/tests/datasets/test_nwpu.py b/tests/datasets/test_nwpu.py index 9007132c32c..43c9e260c8f 100644 --- a/tests/datasets/test_nwpu.py +++ b/tests/datasets/test_nwpu.py @@ -1,11 +1,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import builtins import os import shutil import sys from pathlib import Path -from typing import Generator +from typing import Any, Generator import pytest import torch @@ -17,7 +18,6 @@ import torchgeo.datasets.utils from torchgeo.datasets import VHR10 -pytest.importorskip("rarfile") pytest.importorskip("pycocotools") @@ -35,6 +35,7 @@ def dataset( tmp_path: Path, request: SubRequest, ) -> VHR10: + pytest.importorskip("rarfile", minversion="3") monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.nwpu, "download_url", download_url ) @@ -54,6 +55,21 @@ def dataset( transforms = nn.Identity() # type: ignore[attr-defined] return VHR10(root, split, transforms, download=True, checksum=True) + @pytest.fixture + def mock_missing_module( + self, monkeypatch: Generator[MonkeyPatch, None, None] + ) -> None: + import_orig = builtins.__import__ + + def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == "pycocotools.coco": + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr( # type: ignore[attr-defined] + builtins, "__import__", mocked_import + ) + def test_getitem(self, dataset: VHR10) -> None: x = dataset[0] assert isinstance(x, dict) @@ -84,3 +100,13 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): VHR10(str(tmp_path)) + + def test_mock_missing_module( + self, dataset: VHR10, mock_missing_module: None + ) -> None: + if dataset.split == "positive": + with pytest.raises( + ImportError, + match="pycocotools is not installed and is required to use this datase", + ): + VHR10(dataset.root, dataset.split) diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py index 1b68dfae89c..75ed6ee2d58 100644 --- a/tests/datasets/test_resisc45.py +++ b/tests/datasets/test_resisc45.py @@ -17,8 +17,6 @@ import torchgeo.datasets.utils from torchgeo.datasets import RESISC45, RESISC45DataModule -pytest.importorskip("rarfile") - def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -33,6 +31,8 @@ def dataset( tmp_path: Path, request: SubRequest, ) -> RESISC45: + pytest.importorskip("rarfile", minversion="3") + monkeypatch.setattr( # type: ignore[attr-defined] torchgeo.datasets.resisc45, "download_url", download_url ) diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index 79aa62d7f2b..c96eb7b8c22 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import builtins import os from pathlib import Path -from typing import Generator +from typing import Any, Generator import matplotlib.pyplot as plt import pytest @@ -14,6 +15,8 @@ from torchgeo.datasets import So2Sat, So2SatDataModule +pytest.importorskip("h5py") + class TestSo2Sat: @pytest.fixture(params=["train", "validation", "test"]) @@ -32,6 +35,21 @@ def dataset( transforms = nn.Identity() # type: ignore[attr-defined] return So2Sat(root, split, transforms, checksum=True) + @pytest.fixture + def mock_missing_module( + self, monkeypatch: Generator[MonkeyPatch, None, None] + ) -> None: + import_orig = builtins.__import__ + + def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == "h5py": + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr( # type: ignore[attr-defined] + builtins, "__import__", mocked_import + ) + def test_getitem(self, dataset: So2Sat) -> None: x = dataset[0] assert isinstance(x, dict) @@ -65,6 +83,15 @@ def test_plot(self, dataset: So2Sat) -> None: dataset.plot(x) plt.close() + def test_mock_missing_module( + self, dataset: So2Sat, mock_missing_module: None + ) -> None: + with pytest.raises( + ImportError, + match="h5py is not installed and is required to use this dataset", + ): + So2Sat(dataset.root) + class TestSo2SatDataModule: @pytest.fixture(scope="class", params=zip([True, False], ["rgb", "s2"])) diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 1cf13189943..9e8732deac4 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -98,7 +98,7 @@ def test_mock_missing_module(mock_missing_module: None) -> None: ], ) def test_extract_archive(src: str, tmp_path: Path) -> None: - pytest.importorskip("rarfile") + pytest.importorskip("rarfile", minversion="3") extract_archive(os.path.join("tests", "data", src), str(tmp_path)) diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index a96b90867ec..54fbe74683b 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -522,7 +522,7 @@ def plot_las(self, index: int, colormap: Optional[str] = None) -> Any: import open3d # noqa: F401 except ImportError: raise ImportError( - "open3d is not installed and is required to use this dataset" + "open3d is not installed and is required to plot point clouds" ) import laspy diff --git a/torchgeo/datasets/nwpu.py b/torchgeo/datasets/nwpu.py index 03c0a44dfb2..30f75343e37 100644 --- a/torchgeo/datasets/nwpu.py +++ b/torchgeo/datasets/nwpu.py @@ -117,7 +117,12 @@ def __init__( if split == "positive": # Must be installed to parse annotations file - from pycocotools.coco import COCO + try: + from pycocotools.coco import COCO # noqa: F401 + except ImportError: + raise ImportError( + "pycocotools is not installed and is required to use this dataset" + ) self.coco = COCO( os.path.join( diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 0244bab2b12..606a73e62c0 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -132,7 +132,12 @@ def __init__( AssertionError: if ``split`` argument is invalid RuntimeError: if data is not found in ``root``, or checksums don't match """ - import h5py + try: + import h5py # noqa: F401 + except ImportError: + raise ImportError( + "h5py is not installed and is required to use this dataset" + ) assert split in ["train", "validation", "test"]