Skip to content

Commit

Permalink
Merge pull request #184 from NOAA-OWP/stac_dataframe
Browse files Browse the repository at this point in the history
STAC Catalog Functionality
  • Loading branch information
fernando-aristizabal authored Mar 27, 2024
2 parents 2eeb8e1 + 1e9c1db commit 9d5bc26
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 470 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ authors = [
requires-python = ">=3.8"
keywords = ["geospatial", "evaluations"]
license = {text = "MIT"}
version = "0.2.5"
version = "0.2.6"
dynamic = ["readme", "dependencies"]

[project.optional-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ flox==0.7.2
xskillscore==0.0.24
pyogrio==0.7.2
pystac-client==0.7.5
stackstac==0.5.0
s3fs<=2023.12.1
9 changes: 6 additions & 3 deletions src/gval/accessors/gval_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,18 @@ def __handle_attribute_tracking(
else:
del attribute_tracking_kwargs["agreement_map"]

results = candidate_map.gval.attribute_tracking_xarray(
results = _attribute_tracking_xarray(
candidate_map=candidate_map,
benchmark_map=benchmark_map,
agreement_map=agreement_map,
**attribute_tracking_kwargs,
)

else:
results = candidate_map.gval.attribute_tracking_xarray(
benchmark_map=benchmark_map, agreement_map=agreement_map
results = _attribute_tracking_xarray(
candidate_map=candidate_map,
benchmark_map=benchmark_map,
agreement_map=agreement_map,
)

return results
Expand Down
9 changes: 9 additions & 0 deletions src/gval/comparison/tabulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,15 @@ def _crosstab_Datasets(agreement_map: xr.DataArray) -> DataFrame[Crosstab_df]:
# loop variables
previous_crosstab_df = None # initializing to avoid having unset
for i, b in enumerate(agreement_variable_names):
# Pass pairing dictionary to variable if necessary
if (
agreement_map[b].attrs.get("pairing_dictionary") is None
and agreement_map.attrs.get("pairing_dictionary") is not None
):
agreement_map[b].attrs["pairing_dictionary"] = agreement_map.attrs[
"pairing_dictionary"
]

crosstab_df = _crosstab_2d_DataArrays(
agreement_map=agreement_map[b], band_value=b
)
Expand Down
259 changes: 78 additions & 181 deletions src/gval/utils/loading_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,19 @@
__author__ = "Fernando Aristizabal"


import warnings
from typing import Union, Optional, Tuple, Iterable
from numbers import Number
import ast
from collections import Counter

import pandas as pd
import rioxarray as rxr
import xarray as xr
import numpy as np
from tempfile import NamedTemporaryFile
from rio_cogeo.cogeo import cog_translate
from rio_cogeo.profiles import cog_profiles
import pystac_client

import stackstac
from pystac.item_collection import ItemCollection

_MEMORY_STRATEGY = "normal"

Expand Down Expand Up @@ -238,205 +237,103 @@ def _convert_to_dataset(xr_object=Union[xr.DataArray, xr.Dataset]) -> xr.Dataset
return xr_object


def stac_to_df(
    stac_items: "ItemCollection",
    assets: Optional[list] = None,
    attribute_allow_list: Optional[list] = None,
    attribute_block_list: Optional[list] = None,
) -> pd.DataFrame:
    """Convert STAC Items into a DataFrame

    Flattens each STAC Item with ``pd.json_normalize`` and emits one row per
    item/asset combination, tagging each row with a sequential ``compare_id``
    and the asset's ``href`` as ``map_id``.

    Parameters
    ----------
    stac_items : ItemCollection
        STAC Item Collection returned from pystac client
    assets : list, default = None
        Assets to keep, (keep all if None)
    attribute_allow_list : list, default = None
        List of columns to allow in the result DataFrame
    attribute_block_list : list, default = None
        List of columns to remove in the result DataFrame

    Returns
    -------
    pd.DataFrame
        A DataFrame with rows for each unique item/asset combination

    Raises
    ------
    ValueError
        Allow and block lists should be mutually exclusive
    ValueError
        No entries in DataFrame due to nonexistent asset
    ValueError
        There are no assets in this query to run a catalog comparison
    """

    item_dfs, compare_idx = [], 1

    # Check for mutually exclusive lists.  Counter(None) is an empty Counter,
    # so this check is safe when either list is None.
    if (
        len(
            list(
                (
                    Counter(attribute_allow_list) & Counter(attribute_block_list)
                ).elements()
            )
        )
        > 0
    ):
        # Fixed message: this branch previously raised the "no assets" text,
        # which belongs to the empty-result check further below and
        # contradicted the documented Raises contract.
        raise ValueError("Allow and block lists should be mutually exclusive")

    # Iterate through each STAC Item and make a unique row for each asset
    for item in stac_items:
        item_dict = item.to_dict()
        item_df = pd.json_normalize(item_dict)

        # Drop the flattened per-asset columns; each asset is re-added as its
        # own row below instead of as wide "assets.<name>.*" columns.
        mask = item_df.columns.str.contains("assets.*")
        og_df = item_df.loc[:, ~mask]

        if (
            assets is not None
            and np.sum([asset not in item_dict["assets"].keys() for asset in assets])
            > 0
        ):
            raise ValueError("Non existent asset in parameter assets")

        dfs = []

        # Make a unique row for each asset
        for key, val in item_dict["assets"].items():
            if assets is None or key in assets:
                df = pd.json_normalize(val)
                df["asset"] = key
                df["compare_id"] = compare_idx
                df["map_id"] = val["href"]
                compare_idx += 1

                # Combine item-level and asset-level columns, keeping the
                # first occurrence of any duplicated column name
                concat_df = pd.concat([og_df, df], axis=1)
                dfs.append(concat_df.loc[:, ~concat_df.columns.duplicated()])

        if len(dfs) < 1:
            raise ValueError(
                "There are no assets in this query to run a catalog comparison. "
                "Please revisit original query."
            )

        item_dfs.append(pd.concat(dfs, ignore_index=True))

    # Guard the empty-collection case explicitly; pd.concat([]) would raise
    # an opaque "No objects to concatenate" ValueError otherwise.
    if not item_dfs:
        raise ValueError(
            "There are no assets in this query to run a catalog comparison. "
            "Please revisit original query."
        )

    # Concatenate the DataFrames and remove unwanted columns if allow and
    # block lists exist
    catalog_df = pd.concat(item_dfs, ignore_index=True)

    if attribute_allow_list is not None:
        catalog_df = catalog_df[attribute_allow_list]

    if attribute_block_list is not None:
        catalog_df = catalog_df.drop(attribute_block_list, axis=1)

    return catalog_df


def _create_circle_mask(
Expand Down
Loading

0 comments on commit 9d5bc26

Please sign in to comment.