From c88779bd231eb452575f3274ad024a08aee49516 Mon Sep 17 00:00:00 2001 From: dfguerrerom Date: Fri, 20 Sep 2024 13:13:16 +0200 Subject: [PATCH] avoid gee calls in the threads --- component/scripts/export.py | 4 +- component/scripts/gee.py | 278 ++++++++++++++++++++------------- test.ipynb | 140 +++++++++++++++++ test/test.ipynb | 301 ++++++++++++++++++++++++++++++++++++ test/test_gee.py | 72 ++++----- ui.ipynb | 3 + 6 files changed, 656 insertions(+), 142 deletions(-) create mode 100644 test.ipynb create mode 100644 test/test.ipynb diff --git a/component/scripts/export.py b/component/scripts/export.py index 4c86d52..f16350f 100644 --- a/component/scripts/export.py +++ b/component/scripts/export.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.backends.backend_pdf import PdfPages -from pypdf import PdfMerger +from pypdf import PdfWriter from sepal_ui.scripts.utils import init_ee from unidecode import unidecode @@ -171,7 +171,7 @@ def get_pdf( # merge all the pdf files output.add_live_msg("merge all pdf files") - merger = PdfMerger() + merger = PdfWriter() for pdf in pdf_tmps: merger.append(pdf) merger.write(str(pdf_filepath)) diff --git a/component/scripts/gee.py b/component/scripts/gee.py index 0737460..040559b 100644 --- a/component/scripts/gee.py +++ b/component/scripts/gee.py @@ -1,24 +1,67 @@ import concurrent.futures import threading import zipfile -from functools import partial from pathlib import Path -from typing import Literal, Tuple +from typing import List, Literal, Tuple from urllib.request import urlretrieve import ee from osgeo import gdal -from sepal_ui import sepalwidgets as sw from sepal_ui.scripts.utils import init_ee from component import parameter as cp from component import widget as cw -from component.message import cm from .utils import get_buffers, get_vrt_filename init_ee() +from typing import TypedDict + + +class Params(TypedDict): + link: str # The URL link to download the image from + description: str # A description of the image + tmp_dir: str # The temporary directory to store the downloaded image + + +def get_gee_vrt( + geometry, + mosaics, + image_size, + filename: str, + bands: str, + sources, + output: cw.CustomAlert, + tmp_dir: Path, +): + filename = get_vrt_filename(filename, sources, bands, image_size) + ee_buffers = get_buffers(gdf=geometry, size=image_size, gee=True) + + # Create a filename list + descriptions = {year: f"{filename}_{year}" for year in mosaics} + + nb_points = max(1, len(ee_buffers)) + total_images = len(mosaics) * nb_points + output.reset_progress(total_images, "Progress") + + # Collect EE API results + ee_results, satellites = collect_ee_results( + mosaics, ee_buffers, descriptions, sources, bands, tmp_dir + ) + + # Download images in parallel and get the downloaded file paths + downloaded_files = download_images_in_parallel(ee_results, output) + + # Create VRT files per year using the downloaded file paths and descriptions + vrt_list = create_vrt_per_year(downloaded_files, descriptions, tmp_dir) + + # Generate title list + title_list = generate_title_list(mosaics, satellites, ee_buffers) + + # Return the file + return vrt_list, title_list + def get_ee_image( satellites: dict, @@ -79,7 +122,7 @@ def visible_pixel(ee_image: ee.Image, aoi: ee.geometry.Geometry, scale: int) -> return ee.Number(pixel_masked).divide(ee.Number(pixel)).multiply(100).getInfo() -def getImage( +def get_image( sources: Literal["sentinel", "landsat"], bands: str, aoi: ee.geometry.Geometry, @@ -114,91 +157,138 @@ def getImage( return (ee_image, satellite_id) -def get_gee_vrt( - geometry, +def collect_ee_results( mosaics, - image_size, - filename: str, - bands: str, + ee_buffers, + descriptions, sources, - output: cw.CustomAlert, - tmp_dir: Path, -): - - filename = get_vrt_filename(filename, sources, bands, image_size) - ee_buffers = get_buffers(gdf=geometry, size=image_size, gee=True) - - # create a filename list - descriptions = {} + bands, + tmp_dir, +) -> Tuple[dict[int, List[Params]], dict]: + """ + Collect Earth Engine API results for each buffer and year. + + Returns: + ee_results: A dictionary containing download parameters per year. + satellites: A dictionary tracking the satellites used per year and buffer. + """ + satellites = {} + ee_results = {} for year in mosaics: - descriptions[year] = f"{filename}_{year}" - # load the data directly in SEPAL - satellites = {} # contain the names of the used satellites - - nb_points = max(1, len(ee_buffers)) - total_images = len(mosaics) * nb_points - output.reset_progress(total_images, "Progress") - for year in mosaics: satellites[year] = [None] * len(ee_buffers) + ee_results[year] = [] - download_params = { - "sources": sources, - "bands": bands, - "ee_buffers": ee_buffers, - "year": year, - "descriptions": descriptions, - "output": output, - "satellites": satellites, - "lock": threading.Lock(), - "tmp_dir": tmp_dir, - } + for j, buffer in enumerate(ee_buffers): - # for buffer in ee_buffers: - # down_buffer(buffer, **download_params) + image, sat = get_image(sources, bands, buffer, year) + if sat is None: + print(f"Year: {year}, Buffer index: {j}") - # download the images in parralel fashion - with concurrent.futures.ThreadPoolExecutor() as executor: # use all the available CPU/GPU - # executor.map(partial(down_buffer, **download_params), ee_buffers) + satellites[year][j] = sat + description = f"{descriptions[year]}_{j}" + name = f"{description}_zipimage" + + # Get the download URL + link = image.getDownloadURL( + { + "name": name, + "region": buffer, + "filePerBand": False, + "scale": cp.getScale(sat), + } + ) + + # Store the necessary information for downloading + ee_results[year].append( + { + "link": link, + "description": description, + "tmp_dir": tmp_dir, + } + ) + + return ee_results, satellites + + +def download_images_in_parallel(ee_results: dict[int, Params], output): + """ + Download images in parallel using ThreadPoolExecutor. + + Returns: + downloaded_files: A dictionary mapping each year to a list of downloaded file paths. + """ + # Create a lock for thread-safe progress updates + progress_lock = threading.Lock() + downloaded_files = {} # To store the downloaded file paths + + for year, download_params_list in ee_results.items(): + downloaded_files[year] = [] + with concurrent.futures.ThreadPoolExecutor() as executor: futures = { - executor.submit( - partial(down_buffer, **download_params), ee_buffer - ): ee_buffer - for ee_buffer in ee_buffers + executor.submit(download_image, params, progress_lock, output): params + for params in download_params_list } - # Check if any future has raised an exception + # Collect results as they complete for future in concurrent.futures.as_completed(futures): e = future.exception() if e: raise e # Rethrow the first exception encountered - # print(satellites) + # Get the result (downloaded file path) and store it + file_path = future.result() + downloaded_files[year].append(file_path) + + return downloaded_files + + +def create_vrt_per_year(downloaded_files, descriptions, tmp_dir): + """ + Create a VRT file for each year by combining the downloaded TIFF files. + + Args: + downloaded_files: A dictionary mapping each year to a list of downloaded file paths. + descriptions: A dictionary mapping each year to its base filename. + tmp_dir: The temporary directory where files are stored. - # create a single vrt per year + Returns: + vrt_list: A dictionary mapping each year to its VRT file path. + """ vrt_list = {} - for year in mosaics: + for year, filepaths in downloaded_files.items(): + # Ensure all file paths are strings + filepaths = [str(f) for f in filepaths] - # retreive the file names - vrt_path = tmp_dir / f"{descriptions[year]}.vrt" - filepaths = [str(f) for f in tmp_dir.glob(f"{descriptions[year]}_*.tif")] + # Define the VRT path using the descriptions to match the expected filenames + vrt_filename = f"{descriptions[year]}.vrt" + vrt_path = tmp_dir / vrt_filename - # build the vrt + # Build the VRT ds = gdal.BuildVRT(str(vrt_path), filepaths) - # if there is no cahe to empty it means that one of the dataset was empty - try: - ds.FlushCache() - except AttributeError: - raise Exception(cm.export.empty_dataset) + # Check if the dataset was properly created + if ds is None: + raise Exception(f"Failed to create VRT for year {year}") + + ds = None # Close the dataset - # check that the file was effectively created (gdal doesn't raise errors) + # Ensure the VRT file exists if not vrt_path.is_file(): - raise Exception(f"the vrt {vrt_path} was not created") + raise Exception(f"The VRT {vrt_path} was not created") vrt_list[year] = vrt_path + return vrt_list + +def generate_title_list(mosaics, satellites, ee_buffers): + """ + Generate a title list mapping each year and buffer index to the satellite name. + + Returns: + title_list: A nested dictionary containing titles per year and buffer index. + """ title_list = { y: { j: f"{y} {cp.getShortname(satellites[y][j])}" @@ -206,67 +296,45 @@ def get_gee_vrt( } for y in mosaics } + return title_list - # return the file - return vrt_list, title_list - - -def down_buffer( - buffer, - sources, - bands, - ee_buffers: list, - year, - descriptions, - output: sw.Alert, - satellites, - tmp_dir: str, - lock=None, -): - """download the image for a specific buffer.""" - # get back the image index - j = ee_buffers.index(buffer) - - # get the image - image, sat = getImage(sources, bands, buffer, year) - if sat is None: - print(f"year: {year}, j: {j}") +def download_image(params: Params, progress_lock=None, output=None): + """ + Download a single image and update progress. - if lock: - with lock: - satellites[year][j] = sat + Args: + params: A dictionary containing 'link', 'description', 'tmp_dir'. + progress_lock: A threading.Lock() instance for thread-safe progress updates. + output: The output alert object to update progress. - description = f"{descriptions[year]}_{j}" + Returns: + dst: The path to the downloaded TIFF file. + """ + print(params) + link = params["link"] + description = params["description"] + tmp_dir = params["tmp_dir"] dst = tmp_dir / f"{description}.tif" if not dst.is_file(): - name = f"{description}_zipimage" - link = image.getDownloadURL( - { - "name": name, - "region": buffer, - "filePerBand": False, - "scale": cp.getScale(sat), - } - ) - tmp = tmp_dir.joinpath(f"{name}.zip") urlretrieve(link, tmp) - # unzip the file + # Unzip the file with zipfile.ZipFile(tmp, "r") as zip_: data = zip_.read(zip_.namelist()[0]) - dst.write_bytes(data) - # remove the zip + # Remove the zip file tmp.unlink() - # update the output - output.update_progress() + # Update the output progress safely (if provided) + if progress_lock and output: + with progress_lock: + output.update_progress() return dst diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 0000000..9ded720 --- /dev/null +++ b/test.ipynb @@ -0,0 +1,140 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sepal_ui.sepalwidgets as sw" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomAlert(sw.Alert):\n", + " \"\"\"Custom alert that update the progress iteratively.\"\"\"\n", + "\n", + " total_image: int = 0\n", + "\n", + " progress_text: str = \"\"\n", + "\n", + " current_progress = 0\n", + "\n", + " def reset_progress(self, total_image=1, progress_text=\"\"):\n", + " \"\"\"rest progress and setup the totla_image value and the text.\"\"\"\n", + " self.total_image = total_image\n", + " self.progress_text = progress_text\n", + " self.current_progress = 0\n", + "\n", + " super().update_progress(0, self.progress_text, total=self.total_image)\n", + "\n", + " def update_progress(self) -> None:\n", + " \"\"\"increment the progressses by 1.\"\"\"\n", + " self.current_progress = self.current_progress + 1\n", + " return super().update_progress(\n", + " progress=self.current_progress,\n", + " msg=self.progress_text,\n", + " total=self.total_image,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alert = CustomAlert()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alert.reset_progress(4, \"Pdf page created\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alert." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alert.total_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alert.current_progress" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alert.update_progress()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alert" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "clip-time-series", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/test.ipynb b/test/test.ipynb new file mode 100644 index 0000000..21dc1fe --- /dev/null +++ b/test/test.ipynb @@ -0,0 +1,301 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, Path(\"../\").resolve().as_posix())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "from component.scripts.gee import get_gee_vrt\n", + "from component.scripts.utils import get_pdf_path\n", + "import geopandas as gpd\n", + "import component.widget as cw\n", + "import io\n", + "import ee\n", + "from sepal_ui.scripts.utils import init_ee" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "init_ee()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "json_file = io.StringIO(\n", + " \"\"\"{\"type\": \"FeatureCollection\", \"features\": [{\"id\": \"0\", \"type\": \"Feature\", \"properties\": {\"lat\": 5.33469724544027, \"lng\": 13.0256336559457, \"id\": 1}, \"geometry\": {\"type\": \"Point\", \"coordinates\": [13.0256336559457, 5.33469724544027]}}, {\"id\": \"1\", \"type\": \"Feature\", \"properties\": {\"lat\": 5.31724397918854, \"lng\": 13.0145627442248, \"id\": 2}, \"geometry\": {\"type\": \"Point\", \"coordinates\": [13.0145627442248, 5.31724397918854]}}, {\"id\": \"2\", \"type\": \"Feature\", \"properties\": {\"lat\": 5.31816258449969, \"lng\": 13.0320916877829, \"id\": 3}, \"geometry\": {\"type\": \"Point\", \"coordinates\": [13.0320916877829, 5.31816258449969]}}, {\"id\": \"3\", \"type\": \"Feature\", \"properties\": {\"lat\": 5.48440733356101, \"lng\": 12.9075439309229, \"id\": 4}, \"geometry\": {\"type\": \"Point\", \"coordinates\": [12.9075439309229, 5.48440733356101]}}, {\"id\": \"4\", \"type\": \"Feature\", \"properties\": {\"lat\": 5.46236646346553, \"lng\": 12.9093890828764, \"id\": 5}, \"geometry\": {\"type\": \"Point\", \"coordinates\": [12.9093890828764, 5.46236646346553]}}]}\"\"\"\n", + ")\n", + "geometries = gpd.read_file(json_file)\n", + "alert = cw.CustomAlert()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "c_inputs = mosaics, image_size, sources, bands, square_size = (\n", + " [2024, 2023],\n", + " 250,\n", + " [\"sentinel\", \"landsat\"],\n", + " \"Red, Green, Blue\",\n", + " 90,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tmp_dir = Path(tempfile.mkdtemp())\n", + "input_file_path = tmp_dir / \"test_points.csv\"\n", + "\n", + "pdf_filepath = get_pdf_path(\n", + " input_file_path.stem, sources, bands, square_size, image_size\n", + ")\n", + "\n", + "vrt_list, title_list = get_gee_vrt(\n", + " geometries,\n", + " mosaics,\n", + " image_size,\n", + " pdf_filepath.stem,\n", + " bands,\n", + " sources,\n", + " alert,\n", + " tmp_dir,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alert" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from component.scripts.gee import get_ee_image\n", + "from component import parameter as cp\n", + "\n", + "\n", + "satellite_id = \"sentinel_2\"\n", + "year = 2024\n", + "satellites = cp.getSatellites(sources, year)\n", + "start = str(year) + \"-01-01\"\n", + "end = str(year) + \"-12-31\"\n", + "aoi = ee.Geometry.Polygon(\n", + " [\n", + " [\n", + " [13.024513100356552, 5.333572819469696],\n", + " [13.026757769061255, 5.333572819469696],\n", + " [13.026757769061255, 5.335822103684232],\n", + " [13.024513100356552, 5.335822103684232],\n", + " [13.024513100356552, 5.333572819469696],\n", + " ]\n", + " ]\n", + ")\n", + "dataset, ee_image = get_ee_image(satellites, satellite_id, start, end, bands, aoi)\n", + "dataset.getInfo();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "satellites" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ee_image.select(0).reduceRegion(\n", + " reducer=ee.Reducer.count(),\n", + " geometry=aoi,\n", + " scale=ee_image.projection().nominalScale(),\n", + ").values().get(0).getInfo()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pixel_masked = (\n", + " ee_image.select(0)\n", + " .reduceRegion(\n", + " reducer=ee.Reducer.count(),\n", + " geometry=aoi,\n", + " scale=ee_image.projection().nominalScale(),\n", + " )\n", + " .get(band)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create the feature collection name\n", + "dataset = (\n", + " ee.ImageCollection(satellites[satelliteId])\n", + " .filterDate(start, end)\n", + " .filterBounds(mask)\n", + " .map(cp.getCloudMask(satelliteId))\n", + ")\n", + "\n", + "clip = dataset.median().clip(mask).select(cp.getAvailableBands()[bands][satelliteId])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from component.scripts.gee import down_buffer, getImage\n", + "from sepal_ui.scripts.utils import init_ee\n", + "from component.scripts.utils import min_diagonal\n", + "\n", + "init_ee()\n", + "import ee" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "geometry = geometries\n", + "size = image_size" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ee_pts = [ee.Geometry.Point(*g.centroid.coords) for g in geometry.geometry]\n", + "\n", + "# get the optimal size buffer\n", + "size_list = [min_diagonal(g, size) for g in geometry.to_crs(3857).geometry]\n", + "\n", + "# create the buffers\n", + "ee_buffers = [pt.buffer(s / 2).bounds() for pt, s in zip(ee_pts, size_list)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ee_buffers[0].getInfo()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mask = ee_buffers[0]\n", + "year = mosaics[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "getImage(sources, bands, aoi, year)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "getImage(sources, bands, mask, year)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "clip-time-series", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/test_gee.py b/test/test_gee.py index b58c5f1..2192b28 100644 --- a/test/test_gee.py +++ b/test/test_gee.py @@ -14,7 +14,12 @@ from test.gee_results import * from component import parameter as cp -from component.scripts.gee import down_buffer, get_ee_image, get_gee_vrt +from component.scripts.gee import ( + collect_ee_results, + download_image, + get_ee_image, + get_gee_vrt, +) # Test different parameters parameters = [ @@ -127,7 +132,7 @@ def test_get_ee_image(): assert dataset.getInfo() -def test_down_buffer(alert): +def test_download_image(alert): buffer = ee.Geometry.Polygon( [ @@ -146,22 +151,21 @@ def test_down_buffer(alert): ee_buffers = [buffer] year = 2021 descriptions = {2021: "test_2021"} - satellites = cp.getSatellites(sources, year) tmp_dir = Path(tempfile.mkdtemp()) alert.reset_progress(len(ee_buffers), "Progress") - image = down_buffer( - buffer, - sources, - bands, - ee_buffers, - year, - descriptions, - alert, - satellites, - tmp_dir, + ee_results, _ = collect_ee_results( + [year], ee_buffers, descriptions, sources, bands, tmp_dir ) + # Get the first year + year, params = next(iter(ee_results.items())) + + # Get the first buffer + params = params[0] + + image = download_image(params=params) + # open the output .tif image with rasterio and assert it has the right bands with rasterio.open(image) as src: array = src.read() @@ -175,22 +179,21 @@ def test_down_buffer(alert): ee_buffers = [buffer] year = 2021 descriptions = {2021: "sentinel_ndwi_2021"} - satellites = cp.getSatellites(sources, year) tmp_dir = Path(tempfile.mkdtemp()) alert.reset_progress(len(ee_buffers), "Progress") - image = down_buffer( - buffer, - sources, - bands, - ee_buffers, - year, - descriptions, - alert, - satellites, - tmp_dir, + ee_results, _ = collect_ee_results( + [year], ee_buffers, descriptions, sources, bands, tmp_dir ) + # Get the first year + year, params = next(iter(ee_results.items())) + + # Get the first buffer + params = params[0] + + image = download_image(params=params) + with rasterio.open(image) as src: array = src.read() assert array.shape[0] == 1 @@ -201,22 +204,21 @@ def test_down_buffer(alert): ee_buffers = [buffer] year = 2021 descriptions = {2021: "sentinel_ndwi_2021"} - satellites = cp.getSatellites(sources, year) tmp_dir = Path(tempfile.mkdtemp()) alert.reset_progress(len(ee_buffers), "Progress") - image = down_buffer( - buffer, - sources, - bands, - ee_buffers, - year, - descriptions, - alert, - satellites, - tmp_dir, + ee_results, _ = collect_ee_results( + [year], ee_buffers, descriptions, sources, bands, tmp_dir ) + # Get the first year + year, params = next(iter(ee_results.items())) + + # Get the first buffer + params = params[0] + + image = download_image(params=params) + with rasterio.open(image) as src: array = src.read() assert array.shape[0] == 1 diff --git a/ui.ipynb b/ui.ipynb index 8cb1c4f..6948207 100644 --- a/ui.ipynb +++ b/ui.ipynb @@ -6,6 +6,9 @@ "metadata": {}, "outputs": [], "source": [ + "from sepal_ui.scripts.utils import init_ee\n", + "\n", + "init_ee()\n", "from sepal_ui import sepalwidgets as sw\n", "from component.message import cm" ]