From c703ea828f8fcd9df69a2e8d083b14618bc95f0c Mon Sep 17 00:00:00 2001 From: Pete Gadomski Date: Wed, 27 Sep 2023 15:44:31 -0400 Subject: [PATCH] feat: add bands --- pystac/__init__.py | 2 ++ pystac/asset.py | 22 ++++++++++++++++++++ pystac/band.py | 45 ++++++++++++++++++++++++++++++++++++++++ pystac/collection.py | 26 +++++++++++++++++++++++ pystac/item.py | 26 +++++++++++++++++++++++ tests/test_collection.py | 30 +++++++++++++++++++++++++++ tests/test_item.py | 41 +++++++++++++++++++++++++++++++++++- 7 files changed, 191 insertions(+), 1 deletion(-) create mode 100644 pystac/band.py diff --git a/pystac/__init__.py b/pystac/__init__.py index 6f3036811..c3649e955 100644 --- a/pystac/__init__.py +++ b/pystac/__init__.py @@ -21,6 +21,7 @@ "STACObjectType", "Link", "HIERARCHICAL_LINKS", + "Band", "Catalog", "CatalogType", "Collection", @@ -75,6 +76,7 @@ SpatialExtent, TemporalExtent, ) +from pystac.band import Band from pystac.common_metadata import CommonMetadata from pystac.summaries import RangeSummary, Summaries from pystac.asset import Asset diff --git a/pystac/asset.py b/pystac/asset.py index cbcf2c039..1315cef22 100644 --- a/pystac/asset.py +++ b/pystac/asset.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union from pystac import common_metadata, utils +from pystac.band import Band from pystac.html.jinja_env import get_jinja_env if TYPE_CHECKING: @@ -71,6 +72,7 @@ def __init__( description: Optional[str] = None, media_type: Optional[str] = None, roles: Optional[List[str]] = None, + bands: Optional[List[Band]] = None, extra_fields: Optional[Dict[str, Any]] = None, ) -> None: self.href = utils.make_posix_style(href) @@ -78,6 +80,7 @@ def __init__( self.description = description self.media_type = media_type self.roles = roles + self._bands = bands self.extra_fields = extra_fields or {} # The Item which owns this Asset. @@ -113,6 +116,16 @@ def get_absolute_href(self) -> Optional[str]: return utils.make_absolute_href(self.href, item_self) return None + @property + def bands(self) -> Optional[List[Band]]: + if self._bands is None and self.owner is not None: + return self.owner.bands + return self._bands + + @bands.setter + def bands(self, bands: Optional[List[Band]]) -> None: + self._bands = bands + def to_dict(self) -> Dict[str, Any]: """Returns this Asset as a dictionary. @@ -138,6 +151,9 @@ def to_dict(self) -> Dict[str, Any]: if self.roles is not None: d["roles"] = self.roles + if self.bands is not None: + d["bands"] = [band.to_dict() for band in self.bands] + return d def clone(self) -> Asset: @@ -201,6 +217,11 @@ def from_dict(cls: Type[A], d: Dict[str, Any]) -> A: title = d.pop("title", None) description = d.pop("description", None) roles = d.pop("roles", None) + bands = d.pop("bands", None) + if bands is None: + deserialized_bands = None + else: + deserialized_bands = [Band.from_dict(band) for band in bands] properties = None if any(d): properties = d @@ -211,6 +232,7 @@ def from_dict(cls: Type[A], d: Dict[str, Any]) -> A: title=title, description=description, roles=roles, + bands=deserialized_bands, extra_fields=properties, ) diff --git a/pystac/band.py b/pystac/band.py new file mode 100644 index 000000000..d7962db98 --- /dev/null +++ b/pystac/band.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + + +@dataclass +class Band: + """A name and some properties that apply to a band (aka subasset).""" + + name: str + """The name of the band (e.g., "B01", "B8", "band2", "red"). + + This should be unique across all bands defined in the list of bands. This is + typically the name the data provider uses for the band. + """ + + description: Optional[str] = None + """Description to fully explain the band. + + CommonMark 0.29 syntax MAY be used for rich text representation. + """ + + properties: Dict[str, Any] = field(default_factory=dict) + """Other properties on the band.""" + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> Band: + """Creates a new band object from a dictionary.""" + try: + name = d.pop("name") + except KeyError: + raise ValueError("missing required field on band: name") + description = d.pop("description", None) + return Band(name=name, description=description, properties=d) + + def to_dict(self) -> Dict[str, Any]: + """Creates a dictionary from this band object.""" + d = { + "name": self.name, + } + if self.description is not None: + d["description"] = self.description + d.update(self.properties) + return d diff --git a/pystac/collection.py b/pystac/collection.py index 766bd27a8..e30b22acb 100644 --- a/pystac/collection.py +++ b/pystac/collection.py @@ -22,6 +22,7 @@ import pystac from pystac import CatalogType, STACError, STACObjectType from pystac.asset import Asset +from pystac.band import Band from pystac.catalog import Catalog from pystac.errors import DeprecatedWarning, ExtensionNotImplemented, STACTypeError from pystac.layout import HrefLayoutStrategy @@ -517,6 +518,8 @@ class Collection(Catalog): """Default file name that will be given to this STAC object in a canonical format.""" + _bands: Optional[List[Band]] + def __init__( self, id: str, @@ -532,6 +535,7 @@ def __init__( providers: Optional[List[Provider]] = None, summaries: Optional[Summaries] = None, assets: Optional[Dict[str, Asset]] = None, + bands: Optional[List[Band]] = None, ): super().__init__( id, @@ -555,6 +559,8 @@ def __init__( for k, asset in assets.items(): self.add_asset(k, asset) + self._bands = bands + def __repr__(self) -> str: return "".format(self.id) @@ -588,6 +594,9 @@ def to_dict( if any(self.assets): d["assets"] = {k: v.to_dict() for k, v in self.assets.items()} + if self.bands is not None: + d["bands"] = [band.to_dict() for band in self.bands] + return d def clone(self) -> Collection: @@ -664,6 +673,12 @@ def from_dict( assets = {k: Asset.from_dict(v) for k, v in assets.items()} links = d.pop("links") + bands = d.pop("bands", None) + if bands is not None: + deserialized_bands = [Band.from_dict(band) for band in bands] + else: + deserialized_bands = None + d.pop("stac_version") collection = cls( @@ -680,6 +695,7 @@ def from_dict( href=href, catalog_type=catalog_type, assets=assets, + bands=deserialized_bands, ) for link in links: @@ -830,3 +846,13 @@ def full_copy( @classmethod def matches_object_type(cls, d: Dict[str, Any]) -> bool: return identify_stac_object_type(d) == STACObjectType.COLLECTION + + @property + def bands(self) -> Optional[List[Band]]: + """Returns the bands set on this collection.""" + return self._bands + + @bands.setter + def bands(self, bands: Optional[List[Band]]) -> None: + """Sets the bands on this collection.""" + self._bands = bands diff --git a/pystac/item.py b/pystac/item.py index 4ef676b7f..4bc1ee293 100644 --- a/pystac/item.py +++ b/pystac/item.py @@ -7,6 +7,7 @@ import pystac from pystac import RelType, STACError, STACObjectType from pystac.asset import Asset +from pystac.band import Band from pystac.catalog import Catalog from pystac.collection import Collection from pystac.errors import DeprecatedWarning, ExtensionNotImplemented @@ -106,6 +107,8 @@ class Item(STACObject): stac_extensions: List[str] """List of extensions the Item implements.""" + _bands: Optional[List[Band]] + STAC_OBJECT_TYPE = STACObjectType.ITEM def __init__( @@ -122,6 +125,7 @@ def __init__( collection: Optional[Union[str, Collection]] = None, extra_fields: Optional[Dict[str, Any]] = None, assets: Optional[Dict[str, Asset]] = None, + bands: Optional[List[Band]] = None, ): super().__init__(stac_extensions or []) @@ -167,6 +171,8 @@ def __init__( for k, asset in assets.items(): self.add_asset(k, asset) + self._bands = bands + def __repr__(self) -> str: return "".format(self.id) @@ -406,6 +412,16 @@ def get_derived_from(self) -> List[Item]: "Link failed to resolve. Use get_links instead." ) from e + @property + def bands(self) -> Optional[List[Band]]: + """Returns the bands set on this item.""" + return self._bands + + @bands.setter + def bands(self, bands: Optional[List[Band]]) -> None: + """Sets the bands on this item.""" + self._bands = bands + def to_dict( self, include_self_link: bool = True, transform_hrefs: bool = True ) -> Dict[str, Any]: @@ -442,6 +458,9 @@ def to_dict( for key in self.extra_fields: d[key] = self.extra_fields[key] + if self.bands is not None: + d["properties"]["bands"] = [band.to_dict() for band in self.bands] + return d def clone(self) -> Item: @@ -516,6 +535,12 @@ def from_dict( if k not in [*pass_through_fields, *parse_fields, *exclude_fields] } + bands = properties.pop("bands", None) + if bands is not None: + deserialized_bands = [Band.from_dict(d) for d in bands] + else: + deserialized_bands = None + item = cls( **{k: d.get(k) for k in pass_through_fields}, # type: ignore datetime=datetime, @@ -523,6 +548,7 @@ def from_dict( extra_fields=extra_fields, href=href, assets={k: Asset.from_dict(v) for k, v in assets.items()}, + bands=deserialized_bands, ) for link in links: diff --git a/tests/test_collection.py b/tests/test_collection.py index cb193cfdc..79fa5f104 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -14,6 +14,7 @@ import pystac from pystac import ( Asset, + Band, Catalog, CatalogType, Collection, @@ -670,3 +671,32 @@ def test_permissive_temporal_extent_deserialization(collection: Collection) -> N ]["interval"][0] with pytest.warns(UserWarning): Collection.from_dict(collection_dict) + + +def test_set_bands_on_collection(collection: Collection) -> None: + collection.add_asset("data", Asset(href="example.tif")) + collection.bands = [Band(name="analytic")] + assert collection.assets["data"].bands + assert collection.assets["data"].bands[0].name == "analytic" + + +def test_bands_roundtrip_on_asset(collection: Collection) -> None: + collection.add_asset("data", Asset(href="example.tif")) + collection_dict = collection.to_dict(include_self_link=False, transform_hrefs=False) + collection_dict["assets"]["data"]["bands"] = [{"name": "data"}] + collection = Collection.from_dict(collection_dict) + assert collection.assets["data"].bands + assert collection.assets["data"].bands[0].name == "data" + collection_dict = collection.to_dict(include_self_link=False, transform_hrefs=False) + assert collection_dict["assets"]["data"]["bands"][0]["name"] == "data" + + +def test_bands_roundtrip_on_collection(collection: Collection) -> None: + collection.add_asset("data", Asset(href="example.tif")) + collection_dict = collection.to_dict(include_self_link=False, transform_hrefs=False) + collection_dict["bands"] = [{"name": "data"}] + collection = Collection.from_dict(collection_dict) + assert collection.assets["data"].bands + assert collection.assets["data"].bands[0].name == "data" + collection_dict = collection.to_dict(include_self_link=False, transform_hrefs=False) + assert collection_dict["bands"][0]["name"] == "data" diff --git a/tests/test_item.py b/tests/test_item.py index 97ec3e956..01c2f3960 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -13,7 +13,7 @@ import pystac import pystac.serialization.common_properties -from pystac import Asset, Catalog, Collection, Item, Link +from pystac import Asset, Band, Catalog, Collection, Item, Link from pystac.utils import ( datetime_to_str, get_opt, @@ -636,3 +636,42 @@ def test_pathlib() -> None: # This works, but breaks mypy until we fix # https://github.com/stac-utils/pystac/issues/1216 Item.from_file(Path(TestCases.get_path("data-files/item/sample-item.json"))) + + +def test_bands_do_not_exist(sample_item: Item) -> None: + sample_item.assets["analytic"].bands is None + + +def test_set_bands(sample_item: Item) -> None: + sample_item.assets["analytic"].bands = [Band(name="analytic")] + assert sample_item.assets["analytic"].bands[0].name == "analytic" + + +def test_set_bands_on_item(sample_item: Item) -> None: + sample_item.bands = [Band(name="analytic")] + assert sample_item.assets["analytic"].bands + assert sample_item.assets["analytic"].bands[0].name == "analytic" + + +def test_bands_roundtrip_on_asset(sample_item: Item) -> None: + sample_item_dict = sample_item.to_dict( + include_self_link=False, transform_hrefs=False + ) + sample_item_dict["assets"]["analytic"]["bands"] = [{"name": "analytic"}] + item = Item.from_dict(sample_item_dict) + assert item.assets["analytic"].bands + assert item.assets["analytic"].bands[0].name == "analytic" + item_dict = item.to_dict(include_self_link=False, transform_hrefs=False) + assert item_dict["assets"]["analytic"]["bands"][0]["name"] == "analytic" + + +def test_bands_roundtrip_on_item(sample_item: Item) -> None: + sample_item_dict = sample_item.to_dict( + include_self_link=False, transform_hrefs=False + ) + sample_item_dict["properties"]["bands"] = [{"name": "analytic"}] + item = Item.from_dict(sample_item_dict) + assert item.assets["analytic"].bands + assert item.assets["analytic"].bands[0].name == "analytic" + item_dict = item.to_dict(include_self_link=False, transform_hrefs=False) + assert item_dict["properties"]["bands"][0]["name"] == "analytic"