Rework list of required dependencies (#287)
* Rework list of required dependencies

* Update open3d import error msg

* Style fixes

* Remove extra empty line

* Increase test coverage

* Fix idtrees tests
adamjstewart authored Dec 18, 2021
1 parent 9040e72 commit be36c1e
Showing 13 changed files with 112 additions and 33 deletions.
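The common thread in the diffs below: the trainer dependencies (omegaconf, segmentation-models-pytorch, timm, torchmetrics) move from the old train extra into the core install_requires, dataset-specific packages stay behind the optional datasets extra, and modules that need one of those packages guard the import and raise an actionable error. A minimal sketch of that guard pattern, using a hypothetical SomeDataset class rather than code from this commit:

class SomeDataset:
    """Illustrative placeholder showing the lazy-import guard used by the datasets."""

    def __init__(self, root: str) -> None:
        # Import the optional dependency only when the dataset is instantiated,
        # so a plain `pip install torchgeo` stays lightweight.
        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                "h5py is not installed and is required to use this dataset"
            )
        self.root = root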
2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
@@ -37,7 +37,7 @@ jobs:
run: sudo apt-get install pandoc
- name: Install pip dependencies
run: |
pip install .[train]
pip install .
pip install -r docs/requirements.txt
- name: Run sphinx checks
run: cd docs && make html
4 changes: 2 additions & 2 deletions .github/workflows/release.yaml
@@ -25,7 +25,7 @@ jobs:
- name: Install pip dependencies
run: |
pip install gdal tqdm # TODO: these deps shouldn't be needed
pip install .[datasets,tests,train]
pip install .[datasets,tests]
pip install -r docs/requirements.txt
- name: Run notebook checks
env:
@@ -42,6 +42,6 @@ jobs:
with:
python-version: 3.9
- name: Install pip dependencies
run: pip install .[datasets,tests,train]
run: pip install .[datasets,tests]
- name: Run integration checks
run: pytest -m slow
20 changes: 18 additions & 2 deletions .github/workflows/tests.yaml
@@ -20,9 +20,25 @@ jobs:
- name: Install pip dependencies
run: |
pip install cython numpy # needed for pycocotools
pip install .[datasets,tests,train]
pip install .[datasets,tests]
- name: Run mypy checks
run: mypy .
datasets:
name: datasets
runs-on: ubuntu-latest
steps:
- name: Clone repo
uses: actions/checkout@v2
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install pip dependencies
run: |
pip install cython numpy # needed for pycocotools
pip install .[tests]
- name: Run pytest checks
run: pytest --cov=torchgeo --cov-report=xml
pytest:
name: pytest
runs-on: ${{ matrix.os }}
@@ -64,7 +80,7 @@ jobs:
- name: Install pip dependencies
run: |
pip install cython numpy # needed for pycocotools
pip install .[datasets,tests,train]
pip install .[datasets,tests]
- name: Run pytest checks
run: pytest --cov=torchgeo --cov-report=xml
- name: Report coverage
23 changes: 11 additions & 12 deletions setup.cfg
@@ -37,6 +37,8 @@ install_requires =
kornia>=0.5.4
matplotlib
numpy
# omegaconf 2.1+ required for to_object method
omegaconf>=2.1
# pillow 2.9+ required for height attribute
pillow>=2.9
# pyproj 2.2+ required for CRS object
@@ -51,8 +53,13 @@ install_requires =
scikit-learn>=0.18
# shapely 1.3+ required for Python 3 support
shapely>=1.3
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2
# timm 0.2.1+ required for `features_only` option in create_model
timm>=0.2.1
# torch 1.7+ required for typing
torch>=1.7
torchmetrics
# torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks
torchvision>=0.10
python_requires = >= 3.6
@@ -65,9 +72,9 @@ include = torchgeo*
# Optional dataset requirements
datasets =
h5py
# loading .las point clouds (idtrees)
laspy>=2.0.0
# open3d v0.11.2 last version for tests to pass
# laspy 2+ required for Python 3.6+ support
laspy>=2
# open3d 0.11.2+ required to avoid GLFW error:
# https://github.com/isl-org/Open3D/issues/1550
open3d>=0.11.2
opencv-python
@@ -81,15 +88,6 @@ datasets =
rarfile>=3
# scipy 0.9+ required for scipy.io.wavfile.read
scipy>=0.9
# Optional trainer requirements
train =
# omegaconf 2.1+ required for to_object method
omegaconf>=2.1
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2
# timm 0.2.1+ required for `features_only` option in create_model
timm>=0.2.1
torchmetrics
# Optional developer requirements
style =
# black 21+ required for Python 3.9 support
@@ -101,6 +99,7 @@ style =
isort[colors]>=5.8
# pydocstyle 6.1+ required for pyproject.toml support
pydocstyle[toml]>=6.1
# Optional testing requirements
tests =
# mypy 0.900+ required for pyproject.toml support
mypy>=0.900
2 changes: 1 addition & 1 deletion tests/datasets/test_advance.py
@@ -25,7 +25,6 @@ class TestADVANCE:
def dataset(
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
) -> ADVANCE:
pytest.importorskip("scipy", minversion="0.9.0")
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.utils, "download_url", download_url
)
@@ -57,6 +56,7 @@ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
)

def test_getitem(self, dataset: ADVANCE) -> None:
pytest.importorskip("scipy", minversion="0.9.0")
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
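Moving the importorskip call out of the shared fixture and into the individual test that needs scipy (above) lets the rest of the ADVANCE tests run even when scipy is absent, which is part of how this commit increases coverage. A minimal sketch of the idiom, with the test name chosen purely for illustration:

import pytest


def test_needs_optional_dep() -> None:
    # Skip only this test, not the whole module, when the package is missing
    # or older than the stated minimum; other tests in the file still run.
    scipy = pytest.importorskip("scipy", minversion="0.9.0")
    assert scipy is not None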
13 changes: 7 additions & 6 deletions tests/datasets/test_idtrees.py
@@ -19,6 +19,9 @@
import torchgeo.datasets.utils
from torchgeo.datasets import IDTReeS

pytest.importorskip("pandas", minversion="0.19.1")
pytest.importorskip("laspy", minversion="2")


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@@ -32,8 +35,6 @@ def dataset(
tmp_path: Path,
request: SubRequest,
) -> IDTReeS:
pytest.importorskip("pandas")
pytest.importorskip("laspy")
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.idtrees, "download_url", download_url
)
@@ -124,11 +125,11 @@ def test_mock_missing_module(
ImportError,
match=f"{package} is not installed and is required to use this dataset",
):
IDTReeS(dataset.root, download=True, checksum=True)
else:
IDTReeS(dataset.root, dataset.split, dataset.task)
elif package in ["open3d"]:
with pytest.raises(
ImportError,
match=f"{package} is not installed and is required to use this dataset",
match=f"{package} is not installed and is required to plot point cloud",
):
dataset.plot_las(0)

@@ -153,7 +154,7 @@ def test_plot(self, dataset: IDTReeS) -> None:
reason="segmentation fault on macOS and windows",
)
def test_plot_las(self, dataset: IDTReeS) -> None:
pytest.importorskip("open3d")
pytest.importorskip("open3d", minversion="0.11.2")
vis = dataset.plot_las(index=0, colormap="BrBG")
vis.close()
vis = dataset.plot_las(index=0, colormap=None)
30 changes: 28 additions & 2 deletions tests/datasets/test_nwpu.py
@@ -1,11 +1,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
import sys
from pathlib import Path
from typing import Generator
from typing import Any, Generator

import pytest
import torch
@@ -17,7 +18,6 @@
import torchgeo.datasets.utils
from torchgeo.datasets import VHR10

pytest.importorskip("rarfile")
pytest.importorskip("pycocotools")


@@ -35,6 +35,7 @@ def dataset(
tmp_path: Path,
request: SubRequest,
) -> VHR10:
pytest.importorskip("rarfile", minversion="3")
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.nwpu, "download_url", download_url
)
@@ -54,6 +55,21 @@ def dataset(
transforms = nn.Identity() # type: ignore[attr-defined]
return VHR10(root, split, transforms, download=True, checksum=True)

@pytest.fixture
def mock_missing_module(
self, monkeypatch: Generator[MonkeyPatch, None, None]
) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == "pycocotools.coco":
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr( # type: ignore[attr-defined]
builtins, "__import__", mocked_import
)

def test_getitem(self, dataset: VHR10) -> None:
x = dataset[0]
assert isinstance(x, dict)
@@ -84,3 +100,13 @@ def test_invalid_split(self) -> None:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
VHR10(str(tmp_path))

def test_mock_missing_module(
self, dataset: VHR10, mock_missing_module: None
) -> None:
if dataset.split == "positive":
with pytest.raises(
ImportError,
match="pycocotools is not installed and is required to use this datase",
):
VHR10(dataset.root, dataset.split)
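The mock_missing_module fixture above simulates an uninstalled package by wrapping builtins.__import__: importing pycocotools.coco raises ImportError while every other import resolves normally, so the guard added to nwpu.py (further down) can be exercised even on machines where pycocotools is installed. A condensed sketch of the same idea, with fake_import as an illustrative name:

import builtins
from typing import Any

_real_import = builtins.__import__


def fake_import(name: str, *args: Any, **kwargs: Any) -> Any:
    # Pretend exactly one package is missing; delegate everything else.
    if name == "pycocotools.coco":
        raise ImportError(name)
    return _real_import(name, *args, **kwargs)


# In a test, install it with monkeypatch.setattr(builtins, "__import__", fake_import)
# so the real import machinery is restored automatically when the test ends.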
4 changes: 2 additions & 2 deletions tests/datasets/test_resisc45.py
@@ -17,8 +17,6 @@
import torchgeo.datasets.utils
from torchgeo.datasets import RESISC45, RESISC45DataModule

pytest.importorskip("rarfile")


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@@ -33,6 +31,8 @@ def dataset(
tmp_path: Path,
request: SubRequest,
) -> RESISC45:
pytest.importorskip("rarfile", minversion="3")

monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.resisc45, "download_url", download_url
)
29 changes: 28 additions & 1 deletion tests/datasets/test_so2sat.py
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
from pathlib import Path
from typing import Generator
from typing import Any, Generator

import matplotlib.pyplot as plt
import pytest
@@ -14,6 +15,8 @@

from torchgeo.datasets import So2Sat, So2SatDataModule

pytest.importorskip("h5py")


class TestSo2Sat:
@pytest.fixture(params=["train", "validation", "test"])
@@ -32,6 +35,21 @@ def dataset(
transforms = nn.Identity() # type: ignore[attr-defined]
return So2Sat(root, split, transforms, checksum=True)

@pytest.fixture
def mock_missing_module(
self, monkeypatch: Generator[MonkeyPatch, None, None]
) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == "h5py":
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr( # type: ignore[attr-defined]
builtins, "__import__", mocked_import
)

def test_getitem(self, dataset: So2Sat) -> None:
x = dataset[0]
assert isinstance(x, dict)
@@ -65,6 +83,15 @@ def test_plot(self, dataset: So2Sat) -> None:
dataset.plot(x)
plt.close()

def test_mock_missing_module(
self, dataset: So2Sat, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match="h5py is not installed and is required to use this dataset",
):
So2Sat(dataset.root)


class TestSo2SatDataModule:
@pytest.fixture(scope="class", params=zip([True, False], ["rgb", "s2"]))
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_utils.py
@@ -98,7 +98,7 @@ def test_mock_missing_module(mock_missing_module: None) -> None:
],
)
def test_extract_archive(src: str, tmp_path: Path) -> None:
pytest.importorskip("rarfile")
pytest.importorskip("rarfile", minversion="3")
extract_archive(os.path.join("tests", "data", src), str(tmp_path))


2 changes: 1 addition & 1 deletion torchgeo/datasets/idtrees.py
@@ -522,7 +522,7 @@ def plot_las(self, index: int, colormap: Optional[str] = None) -> Any:
import open3d # noqa: F401
except ImportError:
raise ImportError(
"open3d is not installed and is required to use this dataset"
"open3d is not installed and is required to plot point clouds"
)
import laspy

7 changes: 6 additions & 1 deletion torchgeo/datasets/nwpu.py
@@ -117,7 +117,12 @@ def __init__(

if split == "positive":
# Must be installed to parse annotations file
from pycocotools.coco import COCO
try:
from pycocotools.coco import COCO # noqa: F401
except ImportError:
raise ImportError(
"pycocotools is not installed and is required to use this dataset"
)

self.coco = COCO(
os.path.join(
7 changes: 6 additions & 1 deletion torchgeo/datasets/so2sat.py
@@ -132,7 +132,12 @@ def __init__(
AssertionError: if ``split`` argument is invalid
RuntimeError: if data is not found in ``root``, or checksums don't match
"""
import h5py
try:
import h5py # noqa: F401
except ImportError:
raise ImportError(
"h5py is not installed and is required to use this dataset"
)

assert split in ["train", "validation", "test"]

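With the guard in place, constructing So2Sat on a machine without h5py fails immediately with an actionable message rather than a bare ModuleNotFoundError from deep inside the loader. A hedged usage sketch, with an illustrative root path:

from torchgeo.datasets import So2Sat

try:
    dataset = So2Sat(root="data/so2sat")  # illustrative path
except ImportError as err:
    # Prints: h5py is not installed and is required to use this dataset
    print(err)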
