Commit

Fix for when writing to S3 directly. #877

EmileSonneveld committed Nov 7, 2024
1 parent c309640 commit 18287fa

Showing 4 changed files with 81 additions and 32 deletions.
1 change: 1 addition & 0 deletions openeogeotrellis/deploy/batch_job.py
@@ -637,6 +637,7 @@ def write_stac_item_file(asset_id: str, asset: dict) -> Path:
        },
    }

    item_file.parent.mkdir(parents=True, exist_ok=True)
    with open(item_file, "wt") as fi:
        json.dump(stac_item, fi, allow_nan=False)

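The added `mkdir` call is the actual fix in this file: when assets are uploaded directly to S3, the local job directory tree may never have been created, so writing the STAC item file could fail. A minimal sketch of the guard pattern, stdlib only; the path is a hypothetical stand-in (in batch_job.py it is presumably derived from the asset id):

```python
from pathlib import Path

# Hypothetical item path for illustration only.
item_file = Path("/tmp/job_dir/folder1/lat.tif.json")

# parents=True creates any missing intermediate directories, and
# exist_ok=True makes the call idempotent, so repeated writes don't raise.
item_file.parent.mkdir(parents=True, exist_ok=True)
item_file.write_text("{}")
```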
2 changes: 1 addition & 1 deletion openeogeotrellis/deploy/run_graph_locally.py
@@ -23,7 +23,7 @@ def run_graph_locally(process_graph, output_dir):
        process_graph = {"process_graph": process_graph}
    run_job(
        process_graph,
        output_file=output_dir / "random_folder_name",
        output_file=output_dir / "out",  # just like in backend.py
        metadata_file=output_dir / JOB_METADATA_FILENAME,
        api_version="2.0.0",
        job_dir=output_dir,
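Renaming the output file from a random name to `out` keeps local runs consistent with the batch backend, so downstream code that expects `out`-prefixed assets behaves the same in both environments. A hedged usage sketch, based only on the signature visible in the hunk header (`run_graph_locally(process_graph, output_dir)`); the process graph and paths are illustrative:

```python
from pathlib import Path

from openeogeotrellis.deploy.run_graph_locally import run_graph_locally

# Illustrative graph; any valid openEO process graph dict should work here.
process_graph = {
    "loadcollection1": {
        "process_id": "load_collection",
        "arguments": {"id": "TestCollection-LonLat4x4", "temporal_extent": ["2021-01-05", "2021-01-06"]},
    },
    "saveresult1": {
        "process_id": "save_result",
        "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff"},
        "result": True,
    },
}

output_dir = Path("/tmp/local_job")
output_dir.mkdir(parents=True, exist_ok=True)
# After this change the main asset lands at output_dir / "out*", just like in backend.py.
run_graph_locally(process_graph, output_dir)
```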
3 changes: 3 additions & 0 deletions openeogeotrellis/geopysparkdatacube.py
@@ -1903,6 +1903,9 @@ def color_to_int(color):
        if separate_asset_per_band.isDefined():
            gtiff_options.setSeparateAssetPerBand(separate_asset_per_band.get())
        if filepath_per_band:
            if self.metadata.has_temporal_dimension():
                # The user would need a way to encode the date in the filenames
                raise OpenEOApiException("filepath_per_band is not supported with temporal dimension")
            gtiff_options.setFilepathPerBand(get_jvm().scala.Option.apply(filepath_per_band))
        gtiff_options.addHeadTag("PROCESSING_SOFTWARE", softwareversion)
        if description != "":
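The new guard fails fast when `filepath_per_band` is combined with a cube that still has a temporal dimension, since a single filename per band cannot encode dates. A hypothetical client-side sketch of the rejected and accepted shapes, assuming the openeo Python client; the backend URL and collection id are placeholders:

```python
import openeo

connection = openeo.connect("openeo.example.org")  # placeholder backend URL
cube = connection.load_collection("SOME_COLLECTION", temporal_extent=["2021-01-01", "2021-02-01"])

# Rejected after this commit: the cube still has a time dimension, so executing
# this should raise OpenEOApiException("filepath_per_band is not supported
# with temporal dimension") on the backend.
bad = cube.save_result(format="GTiff", options={"filepath_per_band": ["folder1/lat.tif", "lon.tif"]})

# Accepted: reduce the time dimension first, then one file per band is well defined.
good = cube.reduce_dimension(dimension="t", reducer="mean").save_result(
    format="GTiff", options={"filepath_per_band": ["folder1/lat.tif", "lon.tif"]}
)
```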
107 changes: 76 additions & 31 deletions tests/test_batch_result.py
@@ -1,6 +1,8 @@
import json
import os
import shutil
import subprocess
import sys
import uuid
from pathlib import Path
from unittest import mock
@@ -27,6 +29,7 @@
from openeogeotrellis.config import get_backend_config
from openeogeotrellis.deploy.batch_job import run_job
from openeogeotrellis.deploy.batch_job_metadata import extract_result_metadata
from openeogeotrellis.utils import s3_client

from .data import TEST_DATA_ROOT, get_test_data_file

@@ -1138,7 +1141,14 @@ def test_export_workspace_with_asset_per_band(tmp_path):
    shutil.rmtree(workspace_dir, ignore_errors=True)


def test_filepath_per_band(tmp_path):
@pytest.mark.parametrize("use_S3", [True, False])
def test_filepath_per_band(
    tmp_path,
    use_S3,
    mock_s3_bucket,
    moto_server,
    monkeypatch,
):
    process_graph = {
        "loadcollection1": {
            "process_id": "load_collection",
@@ -1179,33 +1189,67 @@
        },
    }

    process = {
        "process_graph": process_graph,
    }
    run_job(
        process,
        output_file=tmp_path / "out",
        metadata_file=tmp_path / JOB_METADATA_FILENAME,
        api_version="2.0.0",
        job_dir=tmp_path,
        dependencies=[],
    )
    if use_S3:
        monkeypatch.setenv("KUBE", "TRUE")
        json_path = tmp_path / "process_graph.json"
        json.dump(process_graph, json_path.open("wt"))

        containing_folder = Path(__file__).parent
        cmd = [
            sys.executable,
            containing_folder.parent / "openeogeotrellis/deploy/run_graph_locally.py",
            json_path,
        ]
        # Run in a separate subprocess so that all environment variables are
        # set correctly at the moment the SparkContext is created:
        try:
            output = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True, env=os.environ)
        except subprocess.CalledProcessError as e:
            print("run_graph_locally failed. Output: " + e.output)
            raise

        print(output)

        s3_instance = s3_client()
        from openeogeotrellis.config import get_backend_config

        with open(json_path, "rb") as f:
            s3_instance.upload_fileobj(
                f, get_backend_config().s3_bucket_name, str((tmp_path / "test.json").relative_to("/"))
            )

        job_dir_files = {
            o["Key"] for o in s3_instance.list_objects(Bucket=get_backend_config().s3_bucket_name)["Contents"]
        }
        job_dir_files = [f[len(str(tmp_path)):] for f in job_dir_files]
    else:
        process = {
            "process_graph": process_graph,
        }
        run_job(
            process,
            output_file=tmp_path / "out",
            metadata_file=tmp_path / JOB_METADATA_FILENAME,
            api_version="2.0.0",
            job_dir=tmp_path,
            dependencies=[],
        )

        job_dir_files = set(os.listdir(tmp_path))
    job_dir_files = set(os.listdir(tmp_path))
    assert len(job_dir_files) > 0
    assert "folder1" in job_dir_files
    assert "lat.tif" in job_dir_files
    assert any(f.startswith("folder1") for f in job_dir_files)

    workspace_files = list(os.listdir(tmp_path))
    assert workspace_files == ListSubSet(
        [
            "collection.json",
            "folder1",
            "job_metadata.json",
            "lat.tif",
            "lat.tif.json",
        ]
    )
    if not use_S3:
        assert "lat.tif" in workspace_files

    stac_collection = pystac.Collection.from_file(str(tmp_path / "collection.json"))
    stac_collection.validate_all()
@@ -1229,23 +1273,24 @@
    assert geotiff_asset.href == "./lon.tif"  # relative to the json file
    assert geotiff_asset.media_type == "image/tiff; application=geotiff"
    assert geotiff_asset.extra_fields["eo:bands"] == [DictSubSet({"name": "Longitude"})]
    assert geotiff_asset.extra_fields["raster:bands"] == [
        {
            "name": "Longitude",
            "statistics": {
                "maximum": 0.75,
                "mean": 0.375,
                "minimum": 0.0,
                "stddev": 0.27950849718747,
                "valid_percent": 100.0,
            },
        }
    ]
    if not use_S3:
        assert geotiff_asset.extra_fields["raster:bands"] == [
            {
                "name": "Longitude",
                "statistics": {
                    "maximum": 0.75,
                    "mean": 0.375,
                    "minimum": 0.0,
                    "stddev": 0.27950849718747,
                    "valid_percent": 100.0,
                },
            }
        ]

    geotiff_asset_copy_path = tmp_path / "file.copy"
    geotiff_asset.copy(str(geotiff_asset_copy_path))  # downloads the asset file
    with rasterio.open(geotiff_asset_copy_path) as dataset:
        assert dataset.driver == "GTiff"
        geotiff_asset_copy_path = tmp_path / "file.copy"
        geotiff_asset.copy(str(geotiff_asset_copy_path))  # downloads the asset file
        with rasterio.open(geotiff_asset_copy_path) as dataset:
            assert dataset.driver == "GTiff"


def test_discard_result(tmp_path):
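In the S3 branch the test lists the bucket with a bare `list_objects` call, which is fine for a handful of assets but truncates at 1000 keys. A sketch of a paginated variant, assuming the same boto3 client that `openeogeotrellis.utils.s3_client()` returns; the helper name is hypothetical:

```python
def list_job_keys(s3, bucket: str, prefix: str = "") -> set:
    """Collect all object keys under a prefix, following pagination."""
    keys = set()
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        keys.update(obj["Key"] for obj in page.get("Contents", []))
    return keys

# Usage mirroring the test, assuming the same config accessor:
# bucket = get_backend_config().s3_bucket_name
# job_dir_files = list_job_keys(s3_client(), bucket, prefix=str(tmp_path).lstrip("/"))
```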
