Skip to content

Commit

Permalink
Rework list of required dependencies (#287)
Browse files Browse the repository at this point in the history
* Rework list of required dependencies

* Update open3d import error msg

* Style fixes

* Remove extra empty line

* Increase test coverage

* Fix idtrees tests
  • Loading branch information
adamjstewart committed Dec 19, 2021
1 parent ae11f10 commit 01ae2db
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
run: sudo apt-get install pandoc
- name: Install pip dependencies
run: |
pip install .[train]
pip install .
pip install -r docs/requirements.txt
- name: Run sphinx checks
run: cd docs && make html
4 changes: 2 additions & 2 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- name: Install pip dependencies
run: |
pip install gdal tqdm # TODO: these deps shouldn't be needed
pip install .[datasets,tests,train]
pip install .[datasets,tests]
pip install -r docs/requirements.txt
- name: Run notebook checks
env:
Expand All @@ -42,6 +42,6 @@ jobs:
with:
python-version: 3.9
- name: Install pip dependencies
run: pip install .[datasets,tests,train]
run: pip install .[datasets,tests]
- name: Run integration checks
run: pytest -m slow
20 changes: 18 additions & 2 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,25 @@ jobs:
- name: Install pip dependencies
run: |
pip install cython numpy # needed for pycocotools
pip install .[datasets,tests,train]
pip install .[datasets,tests]
- name: Run mypy checks
run: mypy .
datasets:
name: datasets
runs-on: ubuntu-latest
steps:
- name: Clone repo
uses: actions/checkout@v2
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install pip dependencies
run: |
pip install cython numpy # needed for pycocotools
pip install .[tests]
- name: Run pytest checks
run: pytest --cov=torchgeo --cov-report=xml
pytest:
name: pytest
runs-on: ${{ matrix.os }}
Expand Down Expand Up @@ -64,7 +80,7 @@ jobs:
- name: Install pip dependencies
run: |
pip install cython numpy # needed for pycocotools
pip install .[datasets,tests,train]
pip install .[datasets,tests]
- name: Run pytest checks
run: pytest --cov=torchgeo --cov-report=xml
- name: Report coverage
Expand Down
17 changes: 8 additions & 9 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ install_requires =
kornia>=0.5.4
matplotlib
numpy
# omegaconf 2.1+ required for to_object method
omegaconf>=2.1
# pillow 2.9+ required for height attribute
pillow>=2.9
# pyproj 2.2+ required for CRS object
Expand All @@ -49,8 +51,13 @@ install_requires =
scikit-learn>=0.18
# shapely 1.3+ required for Python 3 support
shapely>=1.3
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2
# timm 0.2.1+ required for `features_only` option in create_model
timm>=0.2.1
# torch 1.7+ required for typing
torch>=1.7
torchmetrics
# torchvision 0.3+ required for download_file_from_google_drive
torchvision>=0.3
python_requires = >= 3.6
Expand All @@ -72,15 +79,6 @@ datasets =
rarfile>=3
# scipy 0.9+ required for scipy.io.wavfile.read
scipy>=0.9
# Optional trainer requirements
train =
# omegaconf 2.1+ required for to_object method
omegaconf>=2.1
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2
# timm 0.2.1+ required for `features_only` option in create_model
timm>=0.2.1
torchmetrics
# Optional developer requirements
style =
# black 21+ required for Python 3.9 support
Expand All @@ -92,6 +90,7 @@ style =
isort[colors]>=5.8
# pydocstyle 6.1+ required for pyproject.toml support
pydocstyle[toml]>=6.1
# Optional testing requirements
tests =
# mypy 0.900+ required for pyproject.toml support
mypy>=0.900
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_advance.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class TestADVANCE:
def dataset(
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
) -> ADVANCE:
pytest.importorskip("scipy", minversion="0.9.0")
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.utils, "download_url", download_url
)
Expand Down Expand Up @@ -57,6 +56,7 @@ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
)

def test_getitem(self, dataset: ADVANCE) -> None:
pytest.importorskip("scipy", minversion="0.9.0")
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
Expand Down
30 changes: 28 additions & 2 deletions tests/datasets/test_nwpu.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
import sys
from pathlib import Path
from typing import Generator
from typing import Any, Generator

import pytest
import torch
Expand All @@ -17,7 +18,6 @@
import torchgeo.datasets.utils
from torchgeo.datasets import VHR10

pytest.importorskip("rarfile")
pytest.importorskip("pycocotools")


Expand All @@ -35,6 +35,7 @@ def dataset(
tmp_path: Path,
request: SubRequest,
) -> VHR10:
pytest.importorskip("rarfile", minversion="3")
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.nwpu, "download_url", download_url
)
Expand All @@ -54,6 +55,21 @@ def dataset(
transforms = nn.Identity() # type: ignore[attr-defined]
return VHR10(root, split, transforms, download=True, checksum=True)

@pytest.fixture
def mock_missing_module(
self, monkeypatch: Generator[MonkeyPatch, None, None]
) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == "pycocotools.coco":
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr( # type: ignore[attr-defined]
builtins, "__import__", mocked_import
)

def test_getitem(self, dataset: VHR10) -> None:
x = dataset[0]
assert isinstance(x, dict)
Expand Down Expand Up @@ -84,3 +100,13 @@ def test_invalid_split(self) -> None:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
VHR10(str(tmp_path))

def test_mock_missing_module(
self, dataset: VHR10, mock_missing_module: None
) -> None:
if dataset.split == "positive":
with pytest.raises(
ImportError,
match="pycocotools is not installed and is required to use this datase",
):
VHR10(dataset.root, dataset.split)
4 changes: 2 additions & 2 deletions tests/datasets/test_resisc45.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import torchgeo.datasets.utils
from torchgeo.datasets import RESISC45, RESISC45DataModule

pytest.importorskip("rarfile")


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
Expand All @@ -32,6 +30,8 @@ def dataset(
tmp_path: Path,
request: SubRequest,
) -> RESISC45:
pytest.importorskip("rarfile", minversion="3")

monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.resisc45, "download_url", download_url
)
Expand Down
20 changes: 19 additions & 1 deletion tests/datasets/test_so2sat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
from pathlib import Path
from typing import Generator
from typing import Any, Generator

import pytest
import torch
Expand All @@ -13,6 +14,8 @@

from torchgeo.datasets import So2Sat, So2SatDataModule

pytest.importorskip("h5py")


class TestSo2Sat:
@pytest.fixture(params=["train", "validation", "test"])
Expand All @@ -31,6 +34,21 @@ def dataset(
transforms = nn.Identity() # type: ignore[attr-defined]
return So2Sat(root, split, transforms, checksum=True)

@pytest.fixture
def mock_missing_module(
self, monkeypatch: Generator[MonkeyPatch, None, None]
) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == "h5py":
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr( # type: ignore[attr-defined]
builtins, "__import__", mocked_import
)

def test_getitem(self, dataset: So2Sat) -> None:
x = dataset[0]
assert isinstance(x, dict)
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_mock_missing_module(mock_missing_module: None) -> None:
],
)
def test_extract_archive(src: str, tmp_path: Path) -> None:
pytest.importorskip("rarfile")
pytest.importorskip("rarfile", minversion="3")
extract_archive(os.path.join("tests", "data", src), str(tmp_path))


Expand Down
7 changes: 6 additions & 1 deletion torchgeo/datasets/nwpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,12 @@ def __init__(

if split == "positive":
# Must be installed to parse annotations file
from pycocotools.coco import COCO
try:
from pycocotools.coco import COCO # noqa: F401
except ImportError:
raise ImportError(
"pycocotools is not installed and is required to use this dataset"
)

self.coco = COCO(
os.path.join(
Expand Down
7 changes: 6 additions & 1 deletion torchgeo/datasets/so2sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ def __init__(
AssertionError: if ``split`` argument is invalid
RuntimeError: if data is not found in ``root``, or checksums don't match
"""
import h5py
try:
import h5py # noqa: F401
except ImportError:
raise ImportError(
"h5py is not installed and is required to use this dataset"
)

assert split in ["train", "validation", "test"]

Expand Down

0 comments on commit 01ae2db

Please sign in to comment.