Skip to content

Commit

Permalink
remove legacy imports + rename functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dfguerrerom committed Sep 20, 2024
1 parent c88779b commit 76312b9
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 33 deletions.
20 changes: 10 additions & 10 deletions component/scripts/gee.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ def get_gee_vrt(
output.reset_progress(total_images, "Progress")

# Collect EE API results
ee_results, satellites = collect_ee_results(
ee_tasks, satellites = get_ee_tasks(
mosaics, ee_buffers, descriptions, sources, bands, tmp_dir
)

# Download images in parallel and get the downloaded file paths
downloaded_files = download_images_in_parallel(ee_results, output)
downloaded_files = download_images_in_parallel(ee_tasks, output)

# Create VRT files per year using the downloaded file paths and descriptions
vrt_list = create_vrt_per_year(downloaded_files, descriptions, tmp_dir)
Expand Down Expand Up @@ -157,7 +157,7 @@ def get_image(
return (ee_image, satellite_id)


def collect_ee_results(
def get_ee_tasks(
mosaics,
ee_buffers,
descriptions,
Expand All @@ -169,15 +169,15 @@ def collect_ee_results(
Collect Earth Engine API results for each buffer and year.
Returns:
ee_results: A dictionary containing download parameters per year.
ee_tasks: A dictionary containing download parameters per year.
satellites: A dictionary tracking the satellites used per year and buffer.
"""
satellites = {}
ee_results = {}
ee_tasks = {}
for year in mosaics:

satellites[year] = [None] * len(ee_buffers)
ee_results[year] = []
ee_tasks[year] = []

for j, buffer in enumerate(ee_buffers):

Expand All @@ -201,18 +201,18 @@ def collect_ee_results(
)

# Store the necessary information for downloading
ee_results[year].append(
ee_tasks[year].append(
{
"link": link,
"description": description,
"tmp_dir": tmp_dir,
}
)

return ee_results, satellites
return ee_tasks, satellites


def download_images_in_parallel(ee_results: dict[int, Params], output):
def download_images_in_parallel(ee_tasks: dict[int, Params], output):
"""
Download images in parallel using ThreadPoolExecutor.
Expand All @@ -223,7 +223,7 @@ def download_images_in_parallel(ee_results: dict[int, Params], output):
progress_lock = threading.Lock()
downloaded_files = {} # To store the downloaded file paths

for year, download_params_list in ee_results.items():
for year, download_params_list in ee_tasks.items():
downloaded_files[year] = []
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = {
Expand Down
3 changes: 1 addition & 2 deletions component/scripts/planet.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def get_planet_vrt(
}

# download the images in parralel fashion
with concurrent.futures.ThreadPoolExecutor() as executor: # use all the available CPU/GPU
# executor.map(partial(down_buffer, **download_params), ee_buffers)
with concurrent.futures.ThreadPoolExecutor() as executor:

futures = {
executor.submit(partial(get_quad, **download_params), quad_id): quad_id
Expand Down
2 changes: 1 addition & 1 deletion test/test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
"metadata": {},
"outputs": [],
"source": [
"from component.scripts.gee import down_buffer, getImage\n",
"from component.scripts.gee import getImage\n",
"from sepal_ui.scripts.utils import init_ee\n",
"from component.scripts.utils import min_diagonal\n",
"\n",
Expand Down
14 changes: 7 additions & 7 deletions test/test_gee.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

from component import parameter as cp
from component.scripts.gee import (
collect_ee_results,
download_image,
get_ee_image,
get_ee_tasks,
get_gee_vrt,
)

Expand Down Expand Up @@ -154,12 +154,12 @@ def test_download_image(alert):
tmp_dir = Path(tempfile.mkdtemp())
alert.reset_progress(len(ee_buffers), "Progress")

ee_results, _ = collect_ee_results(
ee_tasks, _ = get_ee_tasks(
[year], ee_buffers, descriptions, sources, bands, tmp_dir
)

# Get the first year
year, params = next(iter(ee_results.items()))
year, params = next(iter(ee_tasks.items()))

# Get the first buffer
params = params[0]
Expand All @@ -182,12 +182,12 @@ def test_download_image(alert):
tmp_dir = Path(tempfile.mkdtemp())
alert.reset_progress(len(ee_buffers), "Progress")

ee_results, _ = collect_ee_results(
ee_tasks, _ = get_ee_tasks(
[year], ee_buffers, descriptions, sources, bands, tmp_dir
)

# Get the first year
year, params = next(iter(ee_results.items()))
year, params = next(iter(ee_tasks.items()))

# Get the first buffer
params = params[0]
Expand All @@ -207,12 +207,12 @@ def test_download_image(alert):
tmp_dir = Path(tempfile.mkdtemp())
alert.reset_progress(len(ee_buffers), "Progress")

ee_results, _ = collect_ee_results(
ee_tasks, _ = get_ee_tasks(
[year], ee_buffers, descriptions, sources, bands, tmp_dir
)

# Get the first year
year, params = next(iter(ee_results.items()))
year, params = next(iter(ee_tasks.items()))

# Get the first buffer
params = params[0]
Expand Down
25 changes: 12 additions & 13 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import os
import sys

import ee
import rasterio

from component.scripts.gee import down_buffer
from component.scripts.gee import download_image, get_ee_tasks

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import tempfile
from pathlib import Path
from test.gee_results import *

from component import parameter as cp
from component.scripts.utils import enhance_band


Expand All @@ -33,22 +33,21 @@ def test_enhance_band(alert):
ee_buffers = [buffer]
year = 2021
descriptions = {2021: "test_2021"}
satellites = cp.getSatellites(sources, year)
tmp_dir = Path(tempfile.mkdtemp())
alert.reset_progress(len(ee_buffers), "Progress")

image = down_buffer(
buffer,
sources,
bands,
ee_buffers,
year,
descriptions,
alert,
satellites,
tmp_dir,
ee_tasks, _ = get_ee_tasks(
[year], ee_buffers, descriptions, sources, bands, tmp_dir
)

# Get the first year
year, params = next(iter(ee_tasks.items()))

# Get the first buffer
params = params[0]

image = download_image(params=params)

print("####### Image path", image)

# open and test the enhance function
Expand Down

0 comments on commit 76312b9

Please sign in to comment.