Skip to content

Commit

Permalink
Using _stac_extensions to denote non-standard field in assets
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinius committed Oct 12, 2023
1 parent 99ea6aa commit 765a05f
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 32 deletions.
5 changes: 3 additions & 2 deletions pystac/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class Asset:
"""Optional, additional fields for this asset. This is used by extensions as a
way to serialize and deserialize properties on asset object JSON."""

stac_extensions: List[str]
_stac_extensions: List[str]
"""A list of schema URIs for STAC Extensions implemented by this STAC Asset."""

def __init__(
Expand All @@ -77,14 +77,15 @@ def __init__(
media_type: Optional[str] = None,
roles: Optional[List[str]] = None,
extra_fields: Optional[Dict[str, Any]] = None,
stac_extensions: Optional[List[str]] = None,
) -> None:
self.href = utils.make_posix_style(href)
self.title = title
self.description = description
self.media_type = media_type
self.roles = roles
self.extra_fields = extra_fields or {}
self.stac_extensions = None
self._stac_extensions = stac_extensions

# The Item which owns this Asset.
self.owner = None
Expand Down
8 changes: 4 additions & 4 deletions pystac/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,10 +585,10 @@ def to_dict(
stac_extensions = dict.fromkeys(self.stac_extensions)

for asset in self.assets.values():
if stac_extensions and asset.stac_extensions:
stac_extensions.update(dict.fromkeys(asset.stac_extensions))
elif asset.stac_extensions:
stac_extensions = dict.fromkeys(asset.stac_extensions)
if stac_extensions and asset._stac_extensions:
stac_extensions.update(dict.fromkeys(asset._stac_extensions))
elif asset._stac_extensions:
stac_extensions = dict.fromkeys(asset._stac_extensions)

if stac_extensions is not None:
d["stac_extensions"] = list(stac_extensions.keys())
Expand Down
53 changes: 31 additions & 22 deletions pystac/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,7 @@ def _set_property(
self.properties[prop_name] = v


class STACExtendable(Protocol):
stac_extensions: List[str]


S = TypeVar("S", bound=STACExtendable)
S = TypeVar("S", bound=Union[pystac.STACObject, pystac.Asset])


class ExtensionManagementMixin(Generic[S], ABC):
Expand Down Expand Up @@ -134,19 +130,31 @@ def add_to(cls, obj: S) -> None:
"""Add the schema URI for this extension to the
:attr:`~pystac.STACObject.stac_extensions` list for the given object, if it is
not already present."""
if obj.stac_extensions is None:
obj.stac_extensions = [cls.get_schema_uri()]
elif not cls.has_extension(obj):
obj.stac_extensions.append(cls.get_schema_uri())
if isinstance(obj, pystac.Asset):
if obj._stac_extensions is None:
obj._stac_extensions = [cls.get_schema_uri()]
elif not cls.has_extension(obj):
obj._stac_extensions.append(cls.get_schema_uri())
else:
if obj.stac_extensions is None:
obj.stac_extensions = [cls.get_schema_uri()]
elif not cls.has_extension(obj):
obj.stac_extensions.append(cls.get_schema_uri())

@classmethod
def remove_from(cls, obj: S) -> None:
"""Remove the schema URI for this extension from the
:attr:`pystac.STACObject.stac_extensions` list for the given object."""

if obj.stac_extensions is not None:
obj.stac_extensions = [
uri for uri in obj.stac_extensions if uri != cls.get_schema_uri()
]
if isinstance(obj, pystac.Asset):
obj._stac_extensions = [
uri for uri in obj._stac_extensions if uri != cls.get_schema_uri()
]
else:
obj.stac_extensions = [
uri for uri in obj.stac_extensions if uri != cls.get_schema_uri()
]

@classmethod
def has_extension(cls, obj: S) -> bool:
Expand All @@ -156,9 +164,9 @@ def has_extension(cls, obj: S) -> bool:

if isinstance(obj, (pystac.Item, pystac.Collection)):
for asset in obj.assets.values():
if asset.stac_extensions is not None and any(
if asset._stac_extensions is not None and any(
uri.startswith(schema_startswith)
for uri in asset.stac_extensions
for uri in asset._stac_extensions
):
return True

Expand All @@ -168,6 +176,11 @@ def has_extension(cls, obj: S) -> bool:
for uri in obj.owner.stac_extensions
):
return True
else:
return obj._stac_extensions is not None and any(
uri.startswith(schema_startswith)
for uri in obj._stac_extensions
)

return obj.stac_extensions is not None and any(
uri.startswith(schema_startswith) for uri in obj.stac_extensions
Expand Down Expand Up @@ -239,15 +252,11 @@ def ensure_has_extension(cls, obj: S, add_if_missing: bool = False) -> None:
if add_if_missing:
cls.add_to(obj)

if isinstance(obj, pystac.Asset):
cls.ensure_has_extension(obj.owner)

if not cls.has_extension(obj):
if not obj.owner or not cls.has_extension(obj.owner):
raise pystac.ExtensionNotImplemented(
f"Could not find extension schema URI {cls.get_schema_uri()} "
"in object."
)
raise pystac.ExtensionNotImplemented(
f"Could not find extension schema URI {cls.get_schema_uri()} "
"in object."
)

@classmethod
def _ext_error_message(cls, obj: Any) -> str:
Expand Down
8 changes: 4 additions & 4 deletions pystac/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,10 +440,10 @@ def to_dict(
stac_extensions = dict.fromkeys(self.stac_extensions)

for asset in self.assets.values():
if stac_extensions and asset.stac_extensions:
stac_extensions.update(dict.fromkeys(asset.stac_extensions))
elif asset.stac_extensions:
stac_extensions = dict.fromkeys(asset.stac_extensions)
if stac_extensions and asset._stac_extensions:
stac_extensions.update(dict.fromkeys(asset._stac_extensions))
elif asset._stac_extensions:
stac_extensions = dict.fromkeys(asset._stac_extensions)

if stac_extensions is not None:
d["stac_extensions"] = list(stac_extensions.keys())
Expand Down

0 comments on commit 765a05f

Please sign in to comment.