Skip to content

Commit

Permalink
avoid gee calls in the threads
Browse files Browse the repository at this point in the history
  • Loading branch information
dfguerrerom committed Sep 20, 2024
1 parent 0c6a400 commit c88779b
Show file tree
Hide file tree
Showing 6 changed files with 656 additions and 142 deletions.
4 changes: 2 additions & 2 deletions component/scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
from pypdf import PdfMerger
from pypdf import PdfWriter
from sepal_ui.scripts.utils import init_ee
from unidecode import unidecode

Expand Down Expand Up @@ -171,7 +171,7 @@ def get_pdf(

# merge all the pdf files
output.add_live_msg("merge all pdf files")
merger = PdfMerger()
merger = PdfWriter()
for pdf in pdf_tmps:
merger.append(pdf)
merger.write(str(pdf_filepath))
Expand Down
278 changes: 173 additions & 105 deletions component/scripts/gee.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,67 @@
import concurrent.futures
import threading
import zipfile
from functools import partial
from pathlib import Path
from typing import Literal, Tuple
from typing import List, Literal, Tuple
from urllib.request import urlretrieve

import ee
from osgeo import gdal
from sepal_ui import sepalwidgets as sw
from sepal_ui.scripts.utils import init_ee

from component import parameter as cp
from component import widget as cw
from component.message import cm

from .utils import get_buffers, get_vrt_filename

init_ee()

from typing import TypedDict


class Params(TypedDict):
link: str # The URL link to download the image from
description: str # A description of the image
tmp_dir: str # The temporary directory to store the downloaded image


def get_gee_vrt(
geometry,
mosaics,
image_size,
filename: str,
bands: str,
sources,
output: cw.CustomAlert,
tmp_dir: Path,
):
filename = get_vrt_filename(filename, sources, bands, image_size)
ee_buffers = get_buffers(gdf=geometry, size=image_size, gee=True)

# Create a filename list
descriptions = {year: f"{filename}_{year}" for year in mosaics}

nb_points = max(1, len(ee_buffers))
total_images = len(mosaics) * nb_points
output.reset_progress(total_images, "Progress")

# Collect EE API results
ee_results, satellites = collect_ee_results(
mosaics, ee_buffers, descriptions, sources, bands, tmp_dir
)

# Download images in parallel and get the downloaded file paths
downloaded_files = download_images_in_parallel(ee_results, output)

# Create VRT files per year using the downloaded file paths and descriptions
vrt_list = create_vrt_per_year(downloaded_files, descriptions, tmp_dir)

# Generate title list
title_list = generate_title_list(mosaics, satellites, ee_buffers)

# Return the file
return vrt_list, title_list


def get_ee_image(
satellites: dict,
Expand Down Expand Up @@ -79,7 +122,7 @@ def visible_pixel(ee_image: ee.Image, aoi: ee.geometry.Geometry, scale: int) ->
return ee.Number(pixel_masked).divide(ee.Number(pixel)).multiply(100).getInfo()


def getImage(
def get_image(
sources: Literal["sentinel", "landsat"],
bands: str,
aoi: ee.geometry.Geometry,
Expand Down Expand Up @@ -114,159 +157,184 @@ def getImage(
return (ee_image, satellite_id)


def get_gee_vrt(
geometry,
def collect_ee_results(
mosaics,
image_size,
filename: str,
bands: str,
ee_buffers,
descriptions,
sources,
output: cw.CustomAlert,
tmp_dir: Path,
):

filename = get_vrt_filename(filename, sources, bands, image_size)
ee_buffers = get_buffers(gdf=geometry, size=image_size, gee=True)

# create a filename list
descriptions = {}
bands,
tmp_dir,
) -> Tuple[dict[int, List[Params]], dict]:
"""
Collect Earth Engine API results for each buffer and year.
Returns:
ee_results: A dictionary containing download parameters per year.
satellites: A dictionary tracking the satellites used per year and buffer.
"""
satellites = {}
ee_results = {}
for year in mosaics:
descriptions[year] = f"{filename}_{year}"
# load the data directly in SEPAL
satellites = {} # contain the names of the used satellites

nb_points = max(1, len(ee_buffers))
total_images = len(mosaics) * nb_points
output.reset_progress(total_images, "Progress")

for year in mosaics:
satellites[year] = [None] * len(ee_buffers)
ee_results[year] = []

download_params = {
"sources": sources,
"bands": bands,
"ee_buffers": ee_buffers,
"year": year,
"descriptions": descriptions,
"output": output,
"satellites": satellites,
"lock": threading.Lock(),
"tmp_dir": tmp_dir,
}
for j, buffer in enumerate(ee_buffers):

# for buffer in ee_buffers:
# down_buffer(buffer, **download_params)
image, sat = get_image(sources, bands, buffer, year)
if sat is None:
print(f"Year: {year}, Buffer index: {j}")

# download the images in parralel fashion
with concurrent.futures.ThreadPoolExecutor() as executor: # use all the available CPU/GPU
# executor.map(partial(down_buffer, **download_params), ee_buffers)
satellites[year][j] = sat

description = f"{descriptions[year]}_{j}"
name = f"{description}_zipimage"

# Get the download URL
link = image.getDownloadURL(
{
"name": name,
"region": buffer,
"filePerBand": False,
"scale": cp.getScale(sat),
}
)

# Store the necessary information for downloading
ee_results[year].append(
{
"link": link,
"description": description,
"tmp_dir": tmp_dir,
}
)

return ee_results, satellites


def download_images_in_parallel(ee_results: dict[int, Params], output):
"""
Download images in parallel using ThreadPoolExecutor.
Returns:
downloaded_files: A dictionary mapping each year to a list of downloaded file paths.
"""
# Create a lock for thread-safe progress updates
progress_lock = threading.Lock()
downloaded_files = {} # To store the downloaded file paths

for year, download_params_list in ee_results.items():
downloaded_files[year] = []
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = {
executor.submit(
partial(down_buffer, **download_params), ee_buffer
): ee_buffer
for ee_buffer in ee_buffers
executor.submit(download_image, params, progress_lock, output): params
for params in download_params_list
}

# Check if any future has raised an exception
# Collect results as they complete
for future in concurrent.futures.as_completed(futures):
e = future.exception()
if e:
raise e # Rethrow the first exception encountered

# print(satellites)
# Get the result (downloaded file path) and store it
file_path = future.result()
downloaded_files[year].append(file_path)

return downloaded_files


def create_vrt_per_year(downloaded_files, descriptions, tmp_dir):
"""
Create a VRT file for each year by combining the downloaded TIFF files.
Args:
downloaded_files: A dictionary mapping each year to a list of downloaded file paths.
descriptions: A dictionary mapping each year to its base filename.
tmp_dir: The temporary directory where files are stored.
# create a single vrt per year
Returns:
vrt_list: A dictionary mapping each year to its VRT file path.
"""
vrt_list = {}
for year in mosaics:
for year, filepaths in downloaded_files.items():
# Ensure all file paths are strings
filepaths = [str(f) for f in filepaths]

# retreive the file names
vrt_path = tmp_dir / f"{descriptions[year]}.vrt"
filepaths = [str(f) for f in tmp_dir.glob(f"{descriptions[year]}_*.tif")]
# Define the VRT path using the descriptions to match the expected filenames
vrt_filename = f"{descriptions[year]}.vrt"
vrt_path = tmp_dir / vrt_filename

# build the vrt
# Build the VRT
ds = gdal.BuildVRT(str(vrt_path), filepaths)

# if there is no cahe to empty it means that one of the dataset was empty
try:
ds.FlushCache()
except AttributeError:
raise Exception(cm.export.empty_dataset)
# Check if the dataset was properly created
if ds is None:
raise Exception(f"Failed to create VRT for year {year}")

ds = None # Close the dataset

# check that the file was effectively created (gdal doesn't raise errors)
# Ensure the VRT file exists
if not vrt_path.is_file():
raise Exception(f"the vrt {vrt_path} was not created")
raise Exception(f"The VRT {vrt_path} was not created")

vrt_list[year] = vrt_path
return vrt_list


def generate_title_list(mosaics, satellites, ee_buffers):
"""
Generate a title list mapping each year and buffer index to the satellite name.
Returns:
title_list: A nested dictionary containing titles per year and buffer index.
"""
title_list = {
y: {
j: f"{y} {cp.getShortname(satellites[y][j])}"
for j in range(len(ee_buffers))
}
for y in mosaics
}
return title_list

# return the file
return vrt_list, title_list


def down_buffer(
buffer,
sources,
bands,
ee_buffers: list,
year,
descriptions,
output: sw.Alert,
satellites,
tmp_dir: str,
lock=None,
):
"""download the image for a specific buffer."""
# get back the image index
j = ee_buffers.index(buffer)

# get the image
image, sat = getImage(sources, bands, buffer, year)

if sat is None:
print(f"year: {year}, j: {j}")
def download_image(params: Params, progress_lock=None, output=None):
"""
Download a single image and update progress.
if lock:
with lock:
satellites[year][j] = sat
Args:
params: A dictionary containing 'link', 'description', 'tmp_dir'.
progress_lock: A threading.Lock() instance for thread-safe progress updates.
output: The output alert object to update progress.
description = f"{descriptions[year]}_{j}"
Returns:
dst: The path to the downloaded TIFF file.
"""
print(params)
link = params["link"]
description = params["description"]
tmp_dir = params["tmp_dir"]

dst = tmp_dir / f"{description}.tif"

if not dst.is_file():

name = f"{description}_zipimage"

link = image.getDownloadURL(
{
"name": name,
"region": buffer,
"filePerBand": False,
"scale": cp.getScale(sat),
}
)

tmp = tmp_dir.joinpath(f"{name}.zip")
urlretrieve(link, tmp)

# unzip the file
# Unzip the file
with zipfile.ZipFile(tmp, "r") as zip_:
data = zip_.read(zip_.namelist()[0])

dst.write_bytes(data)

# remove the zip
# Remove the zip file
tmp.unlink()

# update the output
output.update_progress()
# Update the output progress safely (if provided)
if progress_lock and output:
with progress_lock:
output.update_progress()

return dst
Loading

0 comments on commit c88779b

Please sign in to comment.