Skip to content

Commit

Permalink
feat: draft of task controller
Browse files Browse the repository at this point in the history
  • Loading branch information
dfguerrerom committed Sep 20, 2024
1 parent 0417909 commit 6c66154
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 71 deletions.
4 changes: 4 additions & 0 deletions component/scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def get_pdf(
tmp_dir: str,
enhance_method: str = "min_max",
sources: list = [],
shared_variable=None,
):
pdf_filepath = get_pdf_path(
input_file_path.stem, sources, band_combo, image_size, enhance_method
Expand All @@ -65,6 +66,9 @@ def get_pdf(
output.reset_progress(len(buffers), "Pdf page created")
for index, r in buffers.iterrows():

if shared_variable and shared_variable.is_set():
raise Exception("Process interrupted by the user")

name = re.sub("[^a-zA-Z\\d\\-\\_]", "_", unidecode(str(r.id)))

pdf_tmp = tmp_dir / f"{pdf_filepath.stem}_tmp_pts_{name}.pdf"
Expand Down
44 changes: 36 additions & 8 deletions component/scripts/gee.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def get_gee_vrt(
sources,
output: cw.CustomAlert,
tmp_dir: Path,
shared_variable,
):
filename = get_vrt_filename(filename, sources, bands, image_size)
ee_buffers = get_buffers(gdf=geometry, size=image_size, gee=True)
Expand All @@ -47,16 +48,25 @@ def get_gee_vrt(

# Collect EE API results
ee_tasks, satellites = get_ee_tasks(
mosaics, ee_buffers, descriptions, sources, bands, tmp_dir, output
mosaics,
ee_buffers,
descriptions,
sources,
bands,
tmp_dir,
output,
shared_variable,
)

output.reset_progress(total_images, "Downloading images....")

# Download images in parallel and get the downloaded file paths
downloaded_files = download_images_in_parallel(ee_tasks, output)
downloaded_files = download_images_in_parallel(ee_tasks, output, shared_variable)

# Create VRT files per year using the downloaded file paths and descriptions
vrt_list = create_vrt_per_year(downloaded_files, descriptions, tmp_dir)
vrt_list = create_vrt_per_year(
downloaded_files, descriptions, tmp_dir, shared_variable
)

# Generate title list
title_list = generate_title_list(mosaics, satellites, ee_buffers)
Expand Down Expand Up @@ -160,7 +170,7 @@ def get_image(


def get_ee_tasks(
mosaics, ee_buffers, descriptions, sources, bands, tmp_dir, output
mosaics, ee_buffers, descriptions, sources, bands, tmp_dir, output, shared_variable
) -> Tuple[dict[int, List[Params]], dict]:
"""
Collect Earth Engine API results for each buffer and year.
Expand All @@ -169,6 +179,7 @@ def get_ee_tasks(
ee_tasks: A dictionary containing download parameters per year.
satellites: A dictionary tracking the satellites used per year and buffer.
"""

satellites = {}
ee_tasks = {}
for year in mosaics:
Expand All @@ -178,6 +189,9 @@ def get_ee_tasks(

for j, buffer in enumerate(ee_buffers):

if shared_variable and shared_variable.is_set():
raise Exception("The process was interrupted by the user.")

image, sat = get_image(sources, bands, buffer, year)
if sat is None:
print(f"Year: {year}, Buffer index: {j}")
Expand Down Expand Up @@ -211,7 +225,9 @@ def get_ee_tasks(
return ee_tasks, satellites


def download_images_in_parallel(ee_tasks: dict[int, Params], output):
def download_images_in_parallel(
ee_tasks: dict[int, Params], output, shared_variable=None
):
"""
Download images in parallel using ThreadPoolExecutor.
Expand All @@ -226,7 +242,9 @@ def download_images_in_parallel(ee_tasks: dict[int, Params], output):
downloaded_files[year] = []
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = {
executor.submit(download_image, params, progress_lock, output): params
executor.submit(
download_image, params, progress_lock, output, shared_variable
): params
for params in download_params_list
}

Expand All @@ -243,7 +261,7 @@ def download_images_in_parallel(ee_tasks: dict[int, Params], output):
return downloaded_files


def create_vrt_per_year(downloaded_files, descriptions, tmp_dir):
def create_vrt_per_year(downloaded_files, descriptions, tmp_dir, shared_variable=None):
"""
Create a VRT file for each year by combining the downloaded TIFF files.
Expand All @@ -257,6 +275,10 @@ def create_vrt_per_year(downloaded_files, descriptions, tmp_dir):
"""
vrt_list = {}
for year, filepaths in downloaded_files.items():

if shared_variable and shared_variable.is_set():
raise Exception("The process was interrupted by the user.")

# Ensure all file paths are strings
filepaths = [str(f) for f in filepaths]

Expand Down Expand Up @@ -298,7 +320,9 @@ def generate_title_list(mosaics, satellites, ee_buffers):
return title_list


def download_image(params: Params, progress_lock=None, output=None):
def download_image(
params: Params, progress_lock=None, output=None, shared_variable=None
):
"""
Download a single image and update progress.
Expand All @@ -310,6 +334,10 @@ def download_image(params: Params, progress_lock=None, output=None):
Returns:
dst: The path to the downloaded TIFF file.
"""

if shared_variable and shared_variable.is_set():
raise Exception("The process was interrupted by the user.")

print(params)
link = params["link"]
description = params["description"]
Expand Down
6 changes: 6 additions & 0 deletions component/scripts/planet.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def get_planet_vrt(
out: cw.CustomAlert,
tmp_dir: Path,
planet_model: PlanetModel,
shared_variable: threading.Event,
):

filename = get_vrt_filename(filename, ["planet"], bands, image_size)
Expand All @@ -215,6 +216,7 @@ def get_planet_vrt(
quad_ids = quads_dict[mosaic].keys()

download_params = {
"shared_variable": shared_variable,
"filename": filename,
"mosaic_quads": quads_dict[mosaic],
"mosaic_name": mosaic,
Expand Down Expand Up @@ -273,9 +275,13 @@ def get_quad(
out,
lock=None,
tmp_dir: Path = Path(tempfile.mkdtemp()),
shared_variable: threading.Event = None,
):
"""get one single quad from parameters."""
# check file existence
if shared_variable and shared_variable.is_set():
return

file = tmp_dir / f"{filename}_{mosaic_name}_{quad_id}.tif"
print("###PROCESSING FILE", file)

Expand Down
63 changes: 63 additions & 0 deletions component/scripts/task_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import threading


class TaskController:
def __init__(
self,
start_button,
stop_button,
alert,
function,
callback=None,
*function_args,
**function_kwargs,
):
self.alert = alert
self.task_thread = None
self.function = function
self.function_args = function_args
self.function_kwargs = function_kwargs
self.shared_variable = threading.Event()
self.callback = callback

self.start_button = start_button
self.stop_button = stop_button

stop_button.on_event("click", self.stop_task)

def long_running_task(self):
try:
self.alert.reset()
self.start_button.loading = True
result = self.function(
self.shared_variable, *self.function_args, **self.function_kwargs
)
if self.callback:
self.callback(result)
except Exception as e:
self.alert.append_msg(f"Error occurred: {e}", type_="error")
raise e
finally:
self.start_button.loading = False

def start_task(self, *args):
self.shared_variable.clear()
self.start_button.loading = True
self.task_thread = threading.Thread(target=self.long_running_task)
self.task_thread.start()

def stop_task(self, *args):
self.stop_button.loading = True
self.shared_variable.set()
if self.task_thread is not None:
self.task_thread.join()
self.start_button.loading = False
self.stop_button.loading = False
self.start_button.disabled = False

print("stopped")
self.alert.append_msg(
"The process was interrupted by the user.", type_="warning"
)

print("Task thread stopped.")
Loading

0 comments on commit 6c66154

Please sign in to comment.