diff --git a/openeo/extra/spectral_indices/spectral_indices.py b/openeo/extra/spectral_indices/spectral_indices.py index cd2abfc75..2068086bf 100644 --- a/openeo/extra/spectral_indices/spectral_indices.py +++ b/openeo/extra/spectral_indices/spectral_indices.py @@ -1,7 +1,8 @@ +import functools import json +import re from typing import Dict, List, Optional -import numpy as np from openeo.processes import ProcessBuilder, array_create, array_modify from openeo.rest.datacube import DataCube @@ -11,56 +12,13 @@ except ImportError: import importlib.resources as importlib_resources - -BAND_MAPPING_LANDSAT457 = { - "B1": "B", - "B2": "G", - "B3": "R", - "B4": "N", - "B5": "S1", - "B6": "T1", - "B7": "S2", -} -BAND_MAPPING_LANDSAT8 = { - "B1": "A", - "B2": "B", - "B3": "G", - "B4": "R", - "B5": "N", - "B6": "S1", - "B7": "S2", - "B10": "T1", - "B11": "T2", -} -BAND_MAPPING_MODIS = { - "B3": "B", - "B4": "G", - "B1": "R", - "B2": "N", - "B5": np.nan, - "B6": "S1", - "B7": "S2", -} +# TODO BAND_MAPPING_PROBAV = { "BLUE": "B", "RED": "R", "NIR": "N", "SWIR": "S1", } -BAND_MAPPING_SENTINEL2 = { - "B1": "A", - "B2": "B", - "B3": "G", - "B4": "R", - "B5": "RE1", - "B6": "RE2", - "B7": "RE3", - "B8": "N", - "B8A": "RE4", - "B9": "WV", - "B11": "S1", - "B12": "S2", -} BAND_MAPPING_SENTINEL1 = { "HH": "HH", "HV": "HV", @@ -69,32 +27,9 @@ } -def _get_expression_map(x: ProcessBuilder, platform: str, band_names: List[str]) -> Dict[str, ProcessBuilder]: - """Build mapping of ASI formula variable names to `array_element` nodes.""" - platform = platform.upper() - # TODO: See if we can use common band names from collections instead of hardcoded mapping - if "LANDSAT8" in platform: - band_mapping = BAND_MAPPING_LANDSAT8 - elif "LANDSAT" in platform: - band_mapping = BAND_MAPPING_LANDSAT457 - elif "MODIS" in platform: - band_mapping = BAND_MAPPING_MODIS - elif "PROBAV" in platform: - band_mapping = BAND_MAPPING_PROBAV - elif "TERRASCOPE_S2" in platform or "SENTINEL2" in platform: - band_mapping = BAND_MAPPING_SENTINEL2 - elif "SENTINEL1" in platform: - band_mapping = BAND_MAPPING_SENTINEL1 - else: - # TODO: better error message: provide options? - raise ValueError(f"Unknown satellite platform {platform!r} (to determine band name mapping)") - - # TODO: get rid of this ugly "0" replace hack (looks like asking for trouble) - cube_bands = [band.replace("0", "").upper() for band in band_names] - # TODO: use `label` parameter from `array_element` to avoid index based band references. - return {band_mapping[b]: x.array_element(i) for i, b in enumerate(cube_bands) if b in band_mapping} +@functools.lru_cache(maxsize=1) def load_indices() -> Dict[str, dict]: """Load set of supported spectral indices.""" specs = {} @@ -115,6 +50,7 @@ def load_indices() -> Dict[str, dict]: return specs +@functools.lru_cache(maxsize=1) def load_constants() -> Dict[str, float]: """Load constants defined by Awesome Spectral Indices.""" # TODO: encapsulate all this json loading in a single registry class? @@ -126,6 +62,65 @@ def load_constants() -> Dict[str, float]: return {k: v["default"] for k, v in data.items() if isinstance(v["default"], (int, float))} +@functools.lru_cache(maxsize=1) +def _load_bands() -> Dict[str, dict]: + """Load band name mapping defined by Awesome Spectral Indices.""" + with importlib_resources.files( + "openeo.extra.spectral_indices" + ) / "resources/awesome-spectral-indices/bands.json" as resource_path: + data = json.loads(resource_path.read_text(encoding="utf8")) + return data + + +class _BandMapping: + """ + Helper class to extract mappings between band names and variable names used in Awesome Spectral Indices formulas. + """ + + def __init__(self): + # Load bands.json from Awesome Spectral Indices + self._band_data = _load_bands() + + @staticmethod + def _normalize_platform(platform: str) -> str: + platform = platform.lower().replace("-", "").replace(" ", "") + if platform in {"sentinel2a", "sentinel2b"}: + platform = "sentinel2" + return platform + + @staticmethod + def _normalize_band_name(band_name: str) -> str: + band_name = band_name.upper() + # Normalize band names like "B01" to "B1" + band_name = re.sub(r"^B0+(\d+)$", r"B\1", band_name) + return band_name + + def variable_to_band_name_map(self, platform: str) -> Dict[str, str]: + """ + Build mapping from Awesome Spectral Indices variable names to (normalized) band names for given satellite platform. + """ + var_to_band = { + var: pf_data["band"] + for var, var_data in self._band_data.items() + for pf, pf_data in var_data.get("platforms", {}).items() + if self._normalize_platform(pf) == self._normalize_platform(platform) + } + if not var_to_band: + raise ValueError(f"Empty band mapping derived for satellite platform {platform!r}") + return var_to_band + + def actual_band_name_to_variable_map(self, platform: str, band_names: List[str]) -> Dict[str, str]: + """Build mapping from actual band names (as given) to Awesome Spectral Indices variable names.""" + var_to_band = self.variable_to_band_name_map(platform=platform) + band_to_var = { + band_name: var + for var, normalized_band_name in var_to_band.items() + for band_name in band_names + if self._normalize_band_name(band_name) == normalized_band_name + } + return band_to_var + + def list_indices() -> List[str]: """List names of supported spectral indices""" specs = load_indices() @@ -172,14 +167,26 @@ def _check_validity_index_dict(index_dict: dict, index_specs: dict): def _callback( - x: ProcessBuilder, index_dict: dict, index_specs: dict, append: bool, band_names: List[str], platform: str + x: ProcessBuilder, + index_dict: dict, + index_specs: dict, + append: bool, + band_names: List[str], + platform: str, + band_to_var: Optional[Dict[str, str]] = None, ) -> ProcessBuilder: index_values = [] x_res = x + if not band_to_var: + # TODO: move this outside of callback? + band_to_var = _BandMapping().actual_band_name_to_variable_map(platform=platform, band_names=band_names) + + vars = {band_to_var[bn]: x.array_element(i) for i, bn in enumerate(band_names) if bn in band_to_var} + eval_globals = { **load_constants(), - **_get_expression_map(x, band_names=band_names, platform=platform), + **vars, } # TODO: user might want to control order of indices, which is tricky through a dictionary. for index, params in index_dict["indices"].items(): diff --git a/tests/extra/spectral_indices/test_spectral_indices.py b/tests/extra/spectral_indices/test_spectral_indices.py index eaa0b3516..8b0759ffc 100644 --- a/tests/extra/spectral_indices/test_spectral_indices.py +++ b/tests/extra/spectral_indices/test_spectral_indices.py @@ -13,6 +13,7 @@ load_indices, load_constants, ) +from openeo.extra.spectral_indices.spectral_indices import _BandMapping from openeo.rest.datacube import DataCube @@ -43,6 +44,26 @@ def test_load_constants(): assert constants["g"] == 2.5 +class TestBandMapping: + def test_variable_to_band_map(self): + band_mapping = _BandMapping() + assert band_mapping.variable_to_band_name_map("modis") == { + "B": "B3", + "G": "B4", + "G1": "B11", + "N": "B2", + "R": "B1", + "S1": "B6", + "S2": "B7", + } + + def test_actual_band_name_to_variable_map(self): + band_mapping = _BandMapping() + assert band_mapping.actual_band_name_to_variable_map( + platform="sentinel2", band_names=["B02", "B03", "B04"] + ) == {"B02": "B", "B03": "G", "B04": "R"} + + def test_compute_and_rescale_indices(con): cube = con.load_collection("SENTINEL2") @@ -451,7 +472,7 @@ def test_compute_ndvi(con): @pytest.mark.parametrize("platform", ["Sentinel2", "SENTINEL2"]) def test_compute_ndvi_explicit_platform(con, platform): cube = con.load_collection("NELITENS2") - with pytest.raises(ValueError, match="Unknown satellite platform"): + with pytest.raises(ValueError, match="Empty band mapping derived for satellite platform 'NELITENS2'"): _ = compute_index(cube, index="NDVI") indices = compute_index(cube, index="NDVI", platform=platform)