Skip to content

Commit

Permalink
move read_stac() to stac.py
Browse files Browse the repository at this point in the history
- add an extract() function to pipeline.file_system.utils that uses shutil.unpack_archive().
- add an item_limit arg to parse_stac()
  • Loading branch information
AdeelH committed Sep 19, 2022
1 parent 8aa4808 commit 949c2fb
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 94 deletions.
4 changes: 4 additions & 0 deletions rastervision_core/rastervision/core/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# flake8: noqa

from rastervision.core.utils.stac import *
from rastervision.core.utils.misc import *
48 changes: 46 additions & 2 deletions rastervision_core/rastervision/core/utils/stac.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional
from urllib.parse import urlparse
import logging
from itertools import islice

import boto3
from pystac import StacIO, Catalog, Item
Expand Down Expand Up @@ -62,7 +63,7 @@ def get_linked_image_item(label_item: Item) -> Optional[Item]:
return image_item


def parse_stac(stac_uri: str) -> List[dict]:
def parse_stac(stac_uri: str, item_limit: Optional[int] = None) -> List[dict]:
"""Parse a STAC catalog JSON file to extract label URIs, images URIs,
and AOIs.
Expand All @@ -88,7 +89,8 @@ def parse_stac(stac_uri: str) -> List[dict]:

cat.make_all_asset_hrefs_absolute()

label_items = [item for item in cat.get_all_items() if is_label_item(item)]
label_items = list(
islice(filter(is_label_item, cat.get_all_items()), item_limit))
image_items = [get_linked_image_item(item) for item in label_items]

if len(label_items) == 0:
Expand Down Expand Up @@ -122,3 +124,45 @@ def parse_stac(stac_uri: str) -> List[dict]:
'aoi_geometry': aoi_geometry
})
return out


def read_stac(uri: str, extract_dir: Optional[str] = None,
              **kwargs) -> List[dict]:
    """Read a STAC catalog from a JSON file or from an archive holding one.

    The URI is first passed through ``download_if_needed()``. If the
    resulting path ends in ".json" (case-insensitive), it is parsed directly.
    Otherwise it must be an archive (as judged by ``is_archive()``); the
    archive is extracted and the single "catalog.json" found under the
    extraction directory is parsed.

    Args:
        uri (str): Either a URI to a STAC catalog JSON file or a URI to an
            archive containing a STAC catalog JSON file.
        extract_dir (Optional[str]): Directory to extract the archive into.
            Passed as-is to ``extract()``. Defaults to None.
        **kwargs: Extra keyword args forwarded to ``parse_stac()``.

    Raises:
        ValueError: If the file is neither a JSON file nor an archive.
        FileNotFoundError: If catalog.json is not found after extraction.
        Exception: If multiple catalog.json's are found after extraction.

    Returns:
        List[dict]: A list of dicts with keys: "label_uri", "image_uris",
            "label_bbox", "image_bbox", "bboxes_intersect", and
            "aoi_geometry". Each dict corresponds to one label item and its
            associated image assets in the STAC catalog.
    """
    from pathlib import Path
    from rastervision.pipeline.file_system.utils import (download_if_needed,
                                                         is_archive, extract)

    local_path = download_if_needed(uri)

    # Plain JSON catalog: nothing to unpack.
    if local_path.lower().endswith('.json'):
        return parse_stac(local_path, **kwargs)

    if not is_archive(local_path):
        raise ValueError(f'Unsupported file format: ("{uri}"). '
                         'URIS must be a JSON file or compressed archive.')

    extract_dir = extract(local_path, extract_dir)

    # NOTE(review): the search covers everything under extract_dir, not only
    # the files that came out of this archive -- presumably callers supply a
    # directory without other catalogs in it; confirm.
    found = list(Path(extract_dir).glob('**/catalog.json'))
    if len(found) == 0:
        raise FileNotFoundError(f'Unable to find "catalog.json" in {uri}.')
    if len(found) > 1:
        raise Exception(f'More than one "catalog.json" found in ' f'{uri}.')

    (catalog_file, ) = found
    return parse_stac(str(catalog_file), **kwargs)
18 changes: 18 additions & 0 deletions rastervision_pipeline/rastervision/pipeline/file_system/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,21 @@ def unzip(zip_path: str, target_dir: str):

def is_local(uri: str) -> bool:
return FileSystem.get_file_system(uri) == LocalFileSystem


def is_archive(uri: str) -> bool:
    """Tell whether a URI ends with a known archive extension.

    The known extensions are those of the formats currently registered with
    ``shutil`` (see ``shutil.get_unpack_formats()``). Matching is
    case-sensitive.

    Args:
        uri (str): URI or path to check.

    Returns:
        bool: True if the URI ends with one of the registered extensions.
    """
    # Each registered format is a (name, extensions, description) tuple;
    # flatten the extension lists into one tuple of suffixes.
    suffixes = tuple(
        ext for _, exts, _ in shutil.get_unpack_formats() for ext in exts)
    return uri.endswith(suffixes)


def extract(uri: str,
            target_dir: Optional[str] = None,
            download_dir: Optional[str] = None) -> str:
    """Unpack an archive into a directory and return that directory.

    Args:
        uri (str): URI of the archive. It is passed through
            ``download_if_needed()`` before unpacking.
        target_dir (Optional[str]): Directory to unpack into. If None, the
            directory returned by ``rv_config.get_cache_dir()`` is used.
            Defaults to None.
        download_dir (Optional[str]): Forwarded to ``download_if_needed()``.
            Defaults to None.

    Returns:
        str: The directory the archive was unpacked into.
    """
    out_dir = rv_config.get_cache_dir() if target_dir is None else target_dir
    make_dir(out_dir)
    archive_path = download_if_needed(uri, download_dir)
    # The archive format is inferred by shutil from the file extension.
    shutil.unpack_archive(archive_path, out_dir)
    return out_dir
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import List, Optional
import csv
from io import StringIO
import os
from pathlib import Path

import rasterio
from shapely.strtree import STRtree
Expand All @@ -14,10 +12,9 @@
from rastervision.core.data import (RasterioCRSTransformer,
GeoJSONVectorSourceConfig,
ClassInferenceTransformerConfig)
from rastervision.core.utils.stac import parse_stac
from rastervision.pipeline.file_system import (
file_to_str, file_exists, get_local_path, upload_or_copy, make_dir,
json_to_file, download_if_needed, unzip)
from rastervision.pipeline.file_system import (file_to_str, file_exists,
get_local_path, upload_or_copy,
make_dir, json_to_file)
from rastervision.aws_s3 import S3FileSystem


Expand Down Expand Up @@ -147,42 +144,3 @@ def p2m(x, y, z=None):
finally:
os.environ.clear()
os.environ.update(old_environ)


def read_stac(uri: str, unzip_dir: Optional[str] = None) -> List[dict]:
    """Read a STAC catalog from a JSON file or from a zip file holding one.

    The URI is passed through ``download_if_needed()``. If its suffix is not
    ".zip" (case-insensitive), the result is parsed directly. Otherwise the
    zip file is unzipped into ``unzip_dir`` and the single "catalog.json"
    found under that directory is parsed.

    Args:
        uri (str): Either a URI to a STAC catalog JSON file or a URI to a zip
            file containing a STAC catalog JSON file.
        unzip_dir (Optional[str]): Directory to unzip into. Required when
            the URI is a zip file. Defaults to None.

    Raises:
        ValueError: If the URI is a zip file and unzip_dir is None.
        FileNotFoundError: If catalog.json is not found inside the zip file.
        Exception: If multiple catalog.json's are found inside the zip file.

    Returns:
        List[dict]: A list of dicts with keys: "label_uri", "image_uris",
            "label_bbox", "image_bbox", "bboxes_intersect", and
            "aoi_geometry". Each dict corresponds to one label item and its
            associated image assets in the STAC catalog.
    """
    # The decision is based on the original URI, not the downloaded path.
    uri_is_zip = Path(uri).suffix.lower() == '.zip'

    local_path = download_if_needed(uri)
    if not uri_is_zip:
        return parse_stac(local_path)

    if unzip_dir is None:
        raise ValueError(
            f'uri ("{uri}") is a zip file, but no unzip_dir provided.')

    unzip(local_path, target_dir=unzip_dir)

    found = list(Path(unzip_dir).glob('**/catalog.json'))
    if len(found) == 0:
        raise FileNotFoundError(f'Unable to find "catalog.json" in {uri}.')
    if len(found) > 1:
        raise Exception(f'More than one "catalog.json" found in ' f'{uri}.')

    (catalog_file, ) = found
    return parse_stac(str(catalog_file))
41 changes: 40 additions & 1 deletion tests/core/utils/test_stac.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from typing import Callable
import os
import unittest

from rastervision.core.utils.stac import setup_stac_io
from shapely.geometry import Polygon

from rastervision.pipeline import rv_config
from rastervision.core.utils.stac import setup_stac_io, read_stac

from tests import data_file_path


class TestStac(unittest.TestCase):
Expand All @@ -14,6 +20,39 @@ def assertNoError(self, fn: Callable, msg: str = ''):
def test_setup_stac_io(self):
    """Run setup_stac_io through the assertNoError() helper."""
    self.assertNoError(setup_stac_io)

def test_read_stac(self):
    """Read the zipped test catalog and check the format and content of
    the output of read_stac()."""
    archive_path = data_file_path('catalog.zip')
    # Allowed value type(s) for each key of an output dict.
    value_types = {
        'label_uri': str,
        'image_uris': list,
        'label_bbox': Polygon,
        'image_bbox': (type(None), Polygon),
        'bboxes_intersect': bool,
        'aoi_geometry': (type(None), dict)
    }

    with rv_config.get_tmp_dir() as tmp_dir:
        items = read_stac(archive_path, tmp_dir)

        # format: a list of dicts with known keys and value types
        self.assertIsInstance(items, list)
        for item in items:
            self.assertIsInstance(item, dict)
            for key, value in item.items():
                self.assertIn(key, value_types)
                self.assertIsInstance(value, value_types[key])
            for image_uri in item['image_uris']:
                self.assertIsInstance(image_uri, str)

        # content: the test catalog has one label item with one image
        self.assertEqual(len(items), 1)
        self.assertEqual(len(items[0]['image_uris']), 1)

        # label files must exist inside the directory that was passed in
        for item in items:
            label_uri = item['label_uri']
            self.assertTrue(os.path.exists(label_uri))
            self.assertTrue(label_uri.startswith(tmp_dir))


if __name__ == '__main__':
unittest.main()
46 changes: 0 additions & 46 deletions tests/pytorch_backend/examples/test_utils.py

This file was deleted.

0 comments on commit 949c2fb

Please sign in to comment.