diff --git a/openeogeotrellis/GeotrellisImageCollection.py b/openeogeotrellis/GeotrellisImageCollection.py index a96464ab9..4c9867787 100644 --- a/openeogeotrellis/GeotrellisImageCollection.py +++ b/openeogeotrellis/GeotrellisImageCollection.py @@ -33,6 +33,7 @@ def __init__(self, pyramid: Pyramid, service_registry: InMemoryServiceRegistry, self.pyramid = pyramid self.tms = None self._service_registry = service_registry + # TODO get rid of this _band_index stuff. See https://github.com/Open-EO/openeo-geopyspark-driver/issues/29 self._band_index = 0 def apply_to_levels(self, func): @@ -69,10 +70,12 @@ def create_tilelayer(contextrdd, layer_type, zoom_level): return GeotrellisTimeSeriesImageCollection(pyramid, self._service_registry, metadata=self.metadata)._with_band_index(self._band_index) def _with_band_index(self, band_index): + # TODO get rid of this _band_index stuff. See https://github.com/Open-EO/openeo-geopyspark-driver/issues/29 self._band_index = band_index return self def band_filter(self, bands) -> 'ImageCollection': + # TODO get rid of this _band_index stuff. See https://github.com/Open-EO/openeo-geopyspark-driver/issues/29 if isinstance(bands, int): self._band_index = bands elif isinstance(bands, list) and len(bands) == 1: diff --git a/openeogeotrellis/layercatalog.py b/openeogeotrellis/layercatalog.py index 6f7714626..f9ea28bab 100644 --- a/openeogeotrellis/layercatalog.py +++ b/openeogeotrellis/layercatalog.py @@ -57,6 +57,9 @@ def load_collection(self, collection_id: str, viewing_parameters: dict) -> Image srs = viewing_parameters.get("srs", None) bands = viewing_parameters.get("bands", []) band_indices = [metadata.get_band_index(b) for b in bands] + # TODO: avoid this `still_needs_band_filter` ugliness. + # Also see https://github.com/Open-EO/openeo-geopyspark-driver/issues/29 + still_needs_band_filter = False pysc = gps.get_spark_context() extent = None @@ -69,13 +72,16 @@ def accumulo_pyramid(): pyramidFactory = jvm.org.openeo.geotrellisaccumulo.PyramidFactory("hdp-accumulo-instance", ','.join(ConfigParams().zookeepernodes)) accumulo_layer_name = layer_source_info['data_id'] + nonlocal still_needs_band_filter + still_needs_band_filter = bool(band_indices) return pyramidFactory.pyramid_seq(accumulo_layer_name, extent, srs, from_date, to_date) def s3_pyramid(): endpoint = layer_source_info['endpoint'] region = layer_source_info['region'] bucket_name = layer_source_info['bucket_name'] - + nonlocal still_needs_band_filter + still_needs_band_filter = bool(band_indices) return jvm.org.openeo.geotrelliss3.PyramidFactory(endpoint, region, bucket_name) \ .pyramid_seq(extent, srs, from_date, to_date) @@ -133,7 +139,13 @@ def sentinel_hub_l8_pyramid(): service_registry=self._service_registry, metadata=metadata ) - return image_collection.band_filter(band_indices) if band_indices else image_collection + + if still_needs_band_filter: + # TODO: avoid this `still_needs_band_filter` ugliness. + # Also see https://github.com/Open-EO/openeo-geopyspark-driver/issues/29 + image_collection = image_collection.band_filter(band_indices) + + return image_collection def get_layer_catalog(service_registry: InMemoryServiceRegistry = None) -> GeoPySparkLayerCatalog: