Skip to content

Commit

Permalink
Merge branch 'main' into release
Browse files Browse the repository at this point in the history
  • Loading branch information
dfguerrerom committed Oct 25, 2024
2 parents 5b77b85 + 5665408 commit be67fa4
Show file tree
Hide file tree
Showing 10 changed files with 727 additions and 162 deletions.
4 changes: 2 additions & 2 deletions component/scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
from pypdf import PdfMerger
from pypdf import PdfWriter
from sepal_ui.scripts.utils import init_ee
from unidecode import unidecode

Expand Down Expand Up @@ -171,7 +171,7 @@ def get_pdf(

# merge all the pdf files
output.add_live_msg("merge all pdf files")
merger = PdfMerger()
merger = PdfWriter()
for pdf in pdf_tmps:
merger.append(pdf)
merger.write(str(pdf_filepath))
Expand Down
279 changes: 173 additions & 106 deletions component/scripts/gee.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,69 @@
import concurrent.futures
import threading
import zipfile
from functools import partial
from pathlib import Path
from typing import Literal, Tuple
from typing import List, Literal, Tuple
from urllib.request import urlretrieve

import ee
from osgeo import gdal
from sepal_ui import sepalwidgets as sw
from sepal_ui.scripts.utils import init_ee

from component import parameter as cp
from component import widget as cw
from component.message import cm

from .utils import get_buffers, get_vrt_filename

init_ee()

from typing import TypedDict


class Params(TypedDict):
link: str # The URL link to download the image from
description: str # A description of the image
tmp_dir: str # The temporary directory to store the downloaded image


def get_gee_vrt(
geometry,
mosaics,
image_size,
filename: str,
bands: str,
sources,
output: cw.CustomAlert,
tmp_dir: Path,
):
filename = get_vrt_filename(filename, sources, bands, image_size)
ee_buffers = get_buffers(gdf=geometry, size=image_size, gee=True)

# Create a filename list
descriptions = {year: f"{filename}_{year}" for year in mosaics}

nb_points = max(1, len(ee_buffers))
total_images = len(mosaics) * nb_points
output.reset_progress(total_images, "Requesting images....")

# Collect EE API results
ee_tasks, satellites = get_ee_tasks(
mosaics, ee_buffers, descriptions, sources, bands, tmp_dir, output
)

output.reset_progress(total_images, "Downloading images....")

# Download images in parallel and get the downloaded file paths
downloaded_files = download_images_in_parallel(ee_tasks, output)

# Create VRT files per year using the downloaded file paths and descriptions
vrt_list = create_vrt_per_year(downloaded_files, descriptions, tmp_dir)

# Generate title list
title_list = generate_title_list(mosaics, satellites, ee_buffers)

# Return the file
return vrt_list, title_list


def get_ee_image(
satellites: dict,
Expand Down Expand Up @@ -79,7 +124,7 @@ def visible_pixel(ee_image: ee.Image, aoi: ee.geometry.Geometry, scale: int) ->
return ee.Number(pixel_masked).divide(ee.Number(pixel)).multiply(100).getInfo()


def getImage(
def get_image(
sources: Literal["sentinel", "landsat"],
bands: str,
aoi: ee.geometry.Geometry,
Expand Down Expand Up @@ -114,159 +159,181 @@ def getImage(
return (ee_image, satellite_id)


def get_gee_vrt(
geometry,
mosaics,
image_size,
filename: str,
bands: str,
sources,
output: cw.CustomAlert,
tmp_dir: Path,
):

filename = get_vrt_filename(filename, sources, bands, image_size)
ee_buffers = get_buffers(gdf=geometry, size=image_size, gee=True)
def get_ee_tasks(
mosaics, ee_buffers, descriptions, sources, bands, tmp_dir, output
) -> Tuple[dict[int, List[Params]], dict]:
"""
Collect Earth Engine API results for each buffer and year.
# create a filename list
descriptions = {}
Returns:
ee_tasks: A dictionary containing download parameters per year.
satellites: A dictionary tracking the satellites used per year and buffer.
"""
satellites = {}
ee_tasks = {}
for year in mosaics:
descriptions[year] = f"{filename}_{year}"
# load the data directly in SEPAL
satellites = {} # contain the names of the used satellites

nb_points = max(1, len(ee_buffers))
total_images = len(mosaics) * nb_points
output.reset_progress(total_images, "Progress")

for year in mosaics:
satellites[year] = [None] * len(ee_buffers)
ee_tasks[year] = []

download_params = {
"sources": sources,
"bands": bands,
"ee_buffers": ee_buffers,
"year": year,
"descriptions": descriptions,
"output": output,
"satellites": satellites,
"lock": threading.Lock(),
"tmp_dir": tmp_dir,
}
for j, buffer in enumerate(ee_buffers):

# for buffer in ee_buffers:
# down_buffer(buffer, **download_params)
image, sat = get_image(sources, bands, buffer, year)
if sat is None:
print(f"Year: {year}, Buffer index: {j}")

# download the images in parralel fashion
with concurrent.futures.ThreadPoolExecutor() as executor: # use all the available CPU/GPU
# executor.map(partial(down_buffer, **download_params), ee_buffers)
satellites[year][j] = sat

description = f"{descriptions[year]}_{j}"
name = f"{description}_zipimage"

# Get the download URL
link = image.getDownloadURL(
{
"name": name,
"region": buffer,
"filePerBand": False,
"scale": cp.getScale(sat),
}
)

# Store the necessary information for downloading
ee_tasks[year].append(
{
"link": link,
"description": description,
"tmp_dir": tmp_dir,
}
)

output.update_progress()

return ee_tasks, satellites


def download_images_in_parallel(ee_tasks: dict[int, Params], output):
"""
Download images in parallel using ThreadPoolExecutor.
Returns:
downloaded_files: A dictionary mapping each year to a list of downloaded file paths.
"""
# Create a lock for thread-safe progress updates
progress_lock = threading.Lock()
downloaded_files = {} # To store the downloaded file paths

for year, download_params_list in ee_tasks.items():
downloaded_files[year] = []
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = {
executor.submit(
partial(down_buffer, **download_params), ee_buffer
): ee_buffer
for ee_buffer in ee_buffers
executor.submit(download_image, params, progress_lock, output): params
for params in download_params_list
}

# Check if any future has raised an exception
# Collect results as they complete
for future in concurrent.futures.as_completed(futures):
e = future.exception()
if e:
raise e # Rethrow the first exception encountered

# print(satellites)
# Get the result (downloaded file path) and store it
file_path = future.result()
downloaded_files[year].append(file_path)

return downloaded_files


def create_vrt_per_year(downloaded_files, descriptions, tmp_dir):
"""
Create a VRT file for each year by combining the downloaded TIFF files.
# create a single vrt per year
Args:
downloaded_files: A dictionary mapping each year to a list of downloaded file paths.
descriptions: A dictionary mapping each year to its base filename.
tmp_dir: The temporary directory where files are stored.
Returns:
vrt_list: A dictionary mapping each year to its VRT file path.
"""
vrt_list = {}
for year in mosaics:
for year, filepaths in downloaded_files.items():
# Ensure all file paths are strings
filepaths = [str(f) for f in filepaths]

# retreive the file names
vrt_path = tmp_dir / f"{descriptions[year]}.vrt"
filepaths = [str(f) for f in tmp_dir.glob(f"{descriptions[year]}_*.tif")]
# Define the VRT path using the descriptions to match the expected filenames
vrt_filename = f"{descriptions[year]}.vrt"
vrt_path = tmp_dir / vrt_filename

# build the vrt
# Build the VRT
ds = gdal.BuildVRT(str(vrt_path), filepaths)

# if there is no cahe to empty it means that one of the dataset was empty
try:
ds.FlushCache()
except AttributeError:
raise Exception(cm.export.empty_dataset)
# Check if the dataset was properly created
if ds is None:
raise Exception(f"Failed to create VRT for year {year}")

ds = None # Close the dataset

# check that the file was effectively created (gdal doesn't raise errors)
# Ensure the VRT file exists
if not vrt_path.is_file():
raise Exception(f"the vrt {vrt_path} was not created")
raise Exception(f"The VRT {vrt_path} was not created")

vrt_list[year] = vrt_path
return vrt_list


def generate_title_list(mosaics, satellites, ee_buffers):
"""
Generate a title list mapping each year and buffer index to the satellite name.
Returns:
title_list: A nested dictionary containing titles per year and buffer index.
"""
title_list = {
y: {
j: f"{y} {cp.getShortname(satellites[y][j])}"
for j in range(len(ee_buffers))
}
for y in mosaics
}

# return the file
return vrt_list, title_list
return title_list


def down_buffer(
buffer,
sources,
bands,
ee_buffers: list,
year,
descriptions,
output: sw.Alert,
satellites,
tmp_dir: str,
lock=None,
):
"""download the image for a specific buffer."""
# get back the image index
j = ee_buffers.index(buffer)
def download_image(params: Params, progress_lock=None, output=None):
"""
Download a single image and update progress.
# get the image
image, sat = getImage(sources, bands, buffer, year)
Args:
params: A dictionary containing 'link', 'description', 'tmp_dir'.
progress_lock: A threading.Lock() instance for thread-safe progress updates.
output: The output alert object to update progress.
if sat is None:
print(f"year: {year}, j: {j}")

if lock:
with lock:
satellites[year][j] = sat

description = f"{descriptions[year]}_{j}"
Returns:
dst: The path to the downloaded TIFF file.
"""
print(params)
link = params["link"]
description = params["description"]
tmp_dir = params["tmp_dir"]

dst = tmp_dir / f"{description}.tif"

if not dst.is_file():

name = f"{description}_zipimage"

link = image.getDownloadURL(
{
"name": name,
"region": buffer,
"filePerBand": False,
"scale": cp.getScale(sat),
}
)

tmp = tmp_dir.joinpath(f"{name}.zip")
urlretrieve(link, tmp)

# unzip the file
# Unzip the file
with zipfile.ZipFile(tmp, "r") as zip_:
data = zip_.read(zip_.namelist()[0])

dst.write_bytes(data)

# remove the zip
# Remove the zip file
tmp.unlink()

# update the output
output.update_progress()
# Update the output progress safely (if provided)
if progress_lock and output:
with progress_lock:
output.update_progress()

return dst
3 changes: 1 addition & 2 deletions component/scripts/planet.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def get_planet_vrt(
}

# download the images in parralel fashion
with concurrent.futures.ThreadPoolExecutor() as executor: # use all the available CPU/GPU
# executor.map(partial(down_buffer, **download_params), ee_buffers)
with concurrent.futures.ThreadPoolExecutor() as executor:

futures = {
executor.submit(partial(get_quad, **download_params), quad_id): quad_id
Expand Down
Loading

0 comments on commit be67fa4

Please sign in to comment.