From 949c2fb8fca54e85ed654bf85a7ca9c044cb5fb5 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Mon, 5 Sep 2022 19:34:16 +0500 Subject: [PATCH] move read_stac() to stac.py - add an extract() function to pipeline.file_system.utils that uses shutil.unpack_archive(). - add an item_limit arg to parse_stac() --- .../rastervision/core/utils/__init__.py | 4 ++ .../rastervision/core/utils/stac.py | 48 ++++++++++++++++++- .../pipeline/file_system/utils.py | 18 +++++++ .../pytorch_backend/examples/utils.py | 48 ++----------------- tests/core/utils/test_stac.py | 41 +++++++++++++++- tests/pytorch_backend/examples/test_utils.py | 46 ------------------ 6 files changed, 111 insertions(+), 94 deletions(-) delete mode 100644 tests/pytorch_backend/examples/test_utils.py diff --git a/rastervision_core/rastervision/core/utils/__init__.py b/rastervision_core/rastervision/core/utils/__init__.py index e69de29bb..fdc6d5d60 100644 --- a/rastervision_core/rastervision/core/utils/__init__.py +++ b/rastervision_core/rastervision/core/utils/__init__.py @@ -0,0 +1,4 @@ +# flake8: noqa + +from rastervision.core.utils.stac import * +from rastervision.core.utils.misc import * diff --git a/rastervision_core/rastervision/core/utils/stac.py b/rastervision_core/rastervision/core/utils/stac.py index b2f967705..e43dbbfbe 100644 --- a/rastervision_core/rastervision/core/utils/stac.py +++ b/rastervision_core/rastervision/core/utils/stac.py @@ -1,6 +1,7 @@ from typing import List, Optional from urllib.parse import urlparse import logging +from itertools import islice import boto3 from pystac import StacIO, Catalog, Item @@ -62,7 +63,7 @@ def get_linked_image_item(label_item: Item) -> Optional[Item]: return image_item -def parse_stac(stac_uri: str) -> List[dict]: +def parse_stac(stac_uri: str, item_limit: Optional[int] = None) -> List[dict]: """Parse a STAC catalog JSON file to extract label URIs, images URIs, and AOIs. 
@@ -88,7 +89,8 @@ def parse_stac(stac_uri: str) -> List[dict]: cat.make_all_asset_hrefs_absolute() - label_items = [item for item in cat.get_all_items() if is_label_item(item)] + label_items = list( + islice(filter(is_label_item, cat.get_all_items()), item_limit)) image_items = [get_linked_image_item(item) for item in label_items] if len(label_items) == 0: @@ -122,3 +124,45 @@ def parse_stac(stac_uri: str) -> List[dict]: 'aoi_geometry': aoi_geometry }) return out + + +def read_stac(uri: str, extract_dir: Optional[str] = None, + **kwargs) -> List[dict]: + """Parse the contents of a STAC catalog (downloading it first, if + remote). If the uri is a zip file, unzip it, find catalog.json inside it + and parse that. + + Args: + uri (str): Either a URI to a STAC catalog JSON file or a URI to a zip + file containing a STAC catalog JSON file. + + Raises: + FileNotFoundError: If catalog.json is not found inside the zip file. + Exception: If multiple catalog.json's are found inside the zip file. + + Returns: + List[dict]: A list of dicts with keys: "label_uri", "image_uris", + "label_bbox", "image_bbox", "bboxes_intersect", and "aoi_geometry". + Each dict corresponds to one label item and its associated image + assets in the STAC catalog. + """ + from pathlib import Path + from rastervision.pipeline.file_system.utils import (download_if_needed, + is_archive, extract) + + catalog_path = download_if_needed(uri) + if catalog_path.lower().endswith('.json'): + return parse_stac(catalog_path, **kwargs) + + if not is_archive(catalog_path): + raise ValueError(f'Unsupported file format: ("{uri}"). 
' + 'URIS must be a JSON file or compressed archive.') + + extract_dir = extract(catalog_path, extract_dir) + catalog_paths = list(Path(extract_dir).glob('**/catalog.json')) + if len(catalog_paths) == 0: + raise FileNotFoundError(f'Unable to find "catalog.json" in {uri}.') + elif len(catalog_paths) > 1: + raise Exception(f'More than one "catalog.json" found in ' f'{uri}.') + catalog_path = str(catalog_paths[0]) + return parse_stac(catalog_path, **kwargs) diff --git a/rastervision_pipeline/rastervision/pipeline/file_system/utils.py b/rastervision_pipeline/rastervision/pipeline/file_system/utils.py index e3d5ea90f..739dc3440 100644 --- a/rastervision_pipeline/rastervision/pipeline/file_system/utils.py +++ b/rastervision_pipeline/rastervision/pipeline/file_system/utils.py @@ -327,3 +327,21 @@ def unzip(zip_path: str, target_dir: str): def is_local(uri: str) -> bool: return FileSystem.get_file_system(uri) == LocalFileSystem + + +def is_archive(uri: str) -> bool: + """Check if the URI's extension represents an archived file.""" + formats = sum((fmts for _, fmts, _ in shutil.get_unpack_formats()), []) + return any(uri.endswith(fmt) for fmt in formats) + + +def extract(uri: str, + target_dir: Optional[str] = None, + download_dir: Optional[str] = None) -> str: + """Extract a compressed file.""" + if target_dir is None: + target_dir = rv_config.get_cache_dir() + make_dir(target_dir) + local_path = download_if_needed(uri, download_dir) + shutil.unpack_archive(local_path, target_dir) + return target_dir diff --git a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/utils.py b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/utils.py index a97666829..58cb4ecbe 100644 --- a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/utils.py +++ b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/utils.py @@ -1,8 +1,6 @@ -from typing import List, Optional import csv from io import StringIO import os -from pathlib import 
Path import rasterio from shapely.strtree import STRtree @@ -14,10 +12,9 @@ from rastervision.core.data import (RasterioCRSTransformer, GeoJSONVectorSourceConfig, ClassInferenceTransformerConfig) -from rastervision.core.utils.stac import parse_stac -from rastervision.pipeline.file_system import ( - file_to_str, file_exists, get_local_path, upload_or_copy, make_dir, - json_to_file, download_if_needed, unzip) +from rastervision.pipeline.file_system import (file_to_str, file_exists, + get_local_path, upload_or_copy, + make_dir, json_to_file) from rastervision.aws_s3 import S3FileSystem @@ -147,42 +144,3 @@ def p2m(x, y, z=None): finally: os.environ.clear() os.environ.update(old_environ) - - -def read_stac(uri: str, unzip_dir: Optional[str] = None) -> List[dict]: - """Parse the contents of a STAC catalog (downloading it first, if - remote). If the uri is a zip file, unzip it, find catalog.json inside it - and parse that. - - Args: - uri (str): Either a URI to a STAC catalog JSON file or a URI to a zip - file containing a STAC catalog JSON file. - - Raises: - FileNotFoundError: If catalog.json is not found inside the zip file. - Exception: If multiple catalog.json's are found inside the zip file. - - Returns: - List[dict]: A lsit of dicts with keys: "label_uri", "image_uris", - "label_bbox", "image_bbox", "bboxes_intersect", and "aoi_geometry". - Each dict corresponds to one label item and its associated image - assets in the STAC catalog. 
- """ - uri_path = Path(uri) - is_zip = uri_path.suffix.lower() == '.zip' - - catalog_path = download_if_needed(uri) - if not is_zip: - return parse_stac(catalog_path) - if unzip_dir is None: - raise ValueError( - f'uri ("{uri}") is a zip file, but no unzip_dir provided.') - zip_path = catalog_path - unzip(zip_path, target_dir=unzip_dir) - catalog_paths = list(Path(unzip_dir).glob('**/catalog.json')) - if len(catalog_paths) == 0: - raise FileNotFoundError(f'Unable to find "catalog.json" in {uri}.') - elif len(catalog_paths) > 1: - raise Exception(f'More than one "catalog.json" found in ' f'{uri}.') - catalog_path = str(catalog_paths[0]) - return parse_stac(catalog_path) diff --git a/tests/core/utils/test_stac.py b/tests/core/utils/test_stac.py index b57e0430e..6e5b19797 100644 --- a/tests/core/utils/test_stac.py +++ b/tests/core/utils/test_stac.py @@ -1,7 +1,13 @@ from typing import Callable +import os import unittest -from rastervision.core.utils.stac import setup_stac_io +from shapely.geometry import Polygon + +from rastervision.pipeline import rv_config +from rastervision.core.utils.stac import setup_stac_io, read_stac + +from tests import data_file_path class TestStac(unittest.TestCase): @@ -14,6 +20,39 @@ def assertNoError(self, fn: Callable, msg: str = ''): def test_setup_stac_io(self): self.assertNoError(setup_stac_io) + def test_read_stac(self): + zip_path = data_file_path('catalog.zip') + expected_keys = { + 'label_uri': str, + 'image_uris': list, + 'label_bbox': Polygon, + 'image_bbox': (type(None), Polygon), + 'bboxes_intersect': bool, + 'aoi_geometry': (type(None), dict) + } + + with rv_config.get_tmp_dir() as tmp_dir: + out = read_stac(zip_path, tmp_dir) + + # check for correctness of format + self.assertIsInstance(out, list) + for o in out: + self.assertIsInstance(o, dict) + for k, v in o.items(): + self.assertTrue(k in expected_keys) + self.assertIsInstance(v, expected_keys[k]) + for uri in o['image_uris']: + self.assertIsInstance(uri, str) + + # 
check for correctness of content (WRT the test catalog) + self.assertEqual(len(out), 1) + self.assertEqual(len(out[0]['image_uris']), 1) + + for o in out: + label_uri = o['label_uri'] + self.assertTrue(os.path.exists(label_uri)) + self.assertTrue(label_uri.startswith(tmp_dir)) + if __name__ == '__main__': unittest.main() diff --git a/tests/pytorch_backend/examples/test_utils.py b/tests/pytorch_backend/examples/test_utils.py deleted file mode 100644 index 6df6ab0ab..000000000 --- a/tests/pytorch_backend/examples/test_utils.py +++ /dev/null @@ -1,46 +0,0 @@ -import unittest -from tempfile import TemporaryDirectory -import os.path - -from shapely.geometry import Polygon -from rastervision.pytorch_backend.examples.utils import read_stac - -from tests import data_file_path - - -class TestUtils(unittest.TestCase): - def test_read_stac(self): - expected_keys = { - 'label_uri': str, - 'image_uris': list, - 'label_bbox': Polygon, - 'image_bbox': (type(None), Polygon), - 'bboxes_intersect': bool, - 'aoi_geometry': (type(None), dict) - } - zip_path = data_file_path('catalog.zip') - with TemporaryDirectory(dir='/opt/data/tmp') as tmp_dir: - out = read_stac(zip_path, tmp_dir) - - # check for correctness of format - self.assertIsInstance(out, list) - for o in out: - self.assertIsInstance(o, dict) - for k, v in o.items(): - self.assertTrue(k in expected_keys) - self.assertIsInstance(v, expected_keys[k]) - for uri in o['image_uris']: - self.assertIsInstance(uri, str) - - # check for correctness of content (WRT the test catalog) - self.assertEqual(len(out), 1) - self.assertEqual(len(out[0]['image_uris']), 1) - - for o in out: - label_uri = o['label_uri'] - self.assertTrue(os.path.exists(label_uri)) - self.assertTrue(label_uri.startswith(tmp_dir)) - - -if __name__ == '__main__': - unittest.main()