Allow for files in subdirectories to keep their relative paths in export_workspace. #877
EmileSonneveld committed Nov 14, 2024
1 parent 5e0f8cb commit cc375f4
Showing 6 changed files with 127 additions and 102 deletions.
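
The core of the change, as a minimal sketch (helper names follow the diff below; the paths and merge prefix are made-up examples): the target key of an exported asset is now built from the asset's path relative to the job directory rather than from its file name alone, so a subdirectory layout such as folder1/lon.tif survives the export.

    from pathlib import Path

    def target_key(common_path: Path, file: Path, merge: str) -> str:
        """Sketch of the key construction in ObjectStorageWorkspace.import_file."""
        subdirectory = merge.lstrip("/")               # stand-in for remove_slash_prefix(merge)
        file_relative = file.relative_to(common_path)  # e.g. folder1/lon.tif
        return f"{subdirectory}/{file_relative}"

    # Previously the key was f"{subdirectory}/{file.name}", which flattened
    # folder1/lon.tif down to lon.tif; with the relative path the folder is kept:
    print(target_key(Path("/data/job-123"), Path("/data/job-123/folder1/lon.tif"), "/my/merge"))
    # my/merge/folder1/lon.tif
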
22 changes: 12 additions & 10 deletions openeogeotrellis/deploy/batch_job.py
@@ -487,7 +487,7 @@ def run_job(
result,
result_metadata,
result_assets_metadata=result_assets_metadata,
stac_metadata_dir=job_dir,
job_dir=job_dir,
remove_exported_assets=job_options.get("remove-exported-assets", False),
)
finally:
@@ -525,7 +525,7 @@ def _export_to_workspaces(
result: SaveResult,
result_metadata: dict,
result_assets_metadata: dict,
stac_metadata_dir: Path,
job_dir: Path,
remove_exported_assets: bool,
):
workspace_repository: WorkspaceRepository = backend_config_workspace_repository
@@ -536,9 +536,7 @@

stac_hrefs = [
f"file:{path}"
for path in _write_exported_stac_collection(
stac_metadata_dir, result_metadata, list(result_assets_metadata.keys())
)
for path in _write_exported_stac_collection(job_dir, result_metadata, list(result_assets_metadata.keys()))
]

workspace_uris = {}
@@ -556,11 +554,15 @@
remove_original = remove_exported_assets and final_export

export_to_workspace = partial(
_export_to_workspace, target=workspace, merge=merge, remove_original=remove_original
_export_to_workspace,
job_dir=job_dir,
target=workspace,
merge=merge,
remove_original=remove_original,
)

for stac_href in stac_hrefs:
export_to_workspace(stac_href)
export_to_workspace(source_uri=stac_href)

for asset_key, asset in result_assets_metadata.items():
workspace_uri = export_to_workspace(source_uri=asset["href"])
@@ -587,13 +589,13 @@
result_metadata["assets"][asset_key]["alternate"] = alternate


def _export_to_workspace(source_uri: str, target: Workspace, merge: str, remove_original: bool) -> str:
def _export_to_workspace(job_dir: str, source_uri: str, target: Workspace, merge: str, remove_original: bool) -> str:
uri_parts = urlparse(source_uri)

if not uri_parts.scheme or uri_parts.scheme.lower() == "file":
return target.import_file(Path(uri_parts.path), merge, remove_original)
return target.import_file(job_dir, Path(uri_parts.path), merge, remove_original)
elif uri_parts.scheme == "s3":
return target.import_object(source_uri, merge, remove_original)
return target.import_object(job_dir, source_uri, merge, remove_original)
else:
raise ValueError(f"unsupported scheme {uri_parts.scheme} for {source_uri}; supported are: file, s3")

17 changes: 10 additions & 7 deletions openeogeotrellis/workspace.py
@@ -1,10 +1,12 @@
import logging
import os
from pathlib import Path
from typing import Union
from urllib.parse import urlparse

from boto3.s3.transfer import TransferConfig

from openeo_driver.utils import remove_slash_prefix
from openeo_driver.workspace import Workspace

from openeogeotrellis.utils import s3_client
@@ -19,14 +21,15 @@ class ObjectStorageWorkspace(Workspace):
def __init__(self, bucket: str):
self.bucket = bucket

def import_file(self, file: Path, merge: str, remove_original: bool = False) -> str:
def import_file(self, common_path: Union[str, Path], file: Path, merge: str, remove_original: bool = False) -> str:
merge = os.path.normpath(merge)
subdirectory = merge[1:] if merge.startswith("/") else merge
subdirectory = remove_slash_prefix(merge)
file_relative = file.relative_to(common_path)

MB = 1024 ** 2
config = TransferConfig(multipart_threshold=self.MULTIPART_THRESHOLD_IN_MB * MB)

key = subdirectory + "/" + file.name
key = f"{subdirectory}/{file_relative}"
s3_client().upload_file(str(file), self.bucket, key, Config=config)

if remove_original:
@@ -37,17 +40,17 @@ def import_file(self, file: Path, merge: str, remove_original: bool = False) ->
_log.debug(f"{'moved' if remove_original else 'uploaded'} {file.absolute()} to {workspace_uri}")
return workspace_uri

def import_object(self, s3_uri: str, merge: str, remove_original: bool = False) -> str:
def import_object(self, common_path: str, s3_uri: str, merge: str, remove_original: bool = False) -> str:
uri_parts = urlparse(s3_uri)

if not uri_parts.scheme or uri_parts.scheme.lower() != "s3":
raise ValueError(s3_uri)

source_bucket = uri_parts.netloc
source_key = uri_parts.path[1:]
filename = source_key.split("/")[-1]
source_key = remove_slash_prefix(uri_parts.path[1:])
file_relative = Path(source_key).relative_to(remove_slash_prefix(common_path))

target_key = f"{merge}/{filename}"
target_key = f"{merge}/{file_relative}"

s3 = s3_client()
s3.copy_object(CopySource={"Bucket": source_bucket, "Key": source_key}, Bucket=self.bucket, Key=target_key)
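
For reference, a hedged usage sketch of the updated signatures (the bucket name, job directory, and merge prefix are hypothetical; it mirrors how batch_job.py above now passes job_dir as the common path):

    from pathlib import Path
    from openeogeotrellis.workspace import ObjectStorageWorkspace

    # Hypothetical values, for illustration only.
    workspace = ObjectStorageWorkspace(bucket="example-bucket")
    job_dir = Path("/data/job-123")

    # Local file: the layout under job_dir is preserved in the object key.
    workspace.import_file(job_dir, job_dir / "folder1/lon.tif", merge="results/run-1")
    # -> s3://example-bucket/results/run-1/folder1/lon.tif

    # Object already in S3: the part of the key below the common path is preserved.
    workspace.import_object("/data/job-123", "s3://source-bucket/data/job-123/folder1/lon.tif", merge="results/run-1")
    # -> s3://example-bucket/results/run-1/folder1/lon.tif
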
4 changes: 3 additions & 1 deletion tests/backend_config.py
@@ -5,6 +5,7 @@
from openeo_driver.workspace import DiskWorkspace

from openeogeotrellis.config import GpsBackendConfig
from openeogeotrellis.workspace import ObjectStorageWorkspace

oidc_providers = [
OidcProvider(
@@ -27,7 +28,8 @@
os.makedirs("/tmp/workspace", exist_ok=True)
workspaces = {
"tmp_workspace": DiskWorkspace(root_directory=Path("/tmp/workspace")),
"tmp": DiskWorkspace(root_directory=Path("/tmp"))
"tmp": DiskWorkspace(root_directory=Path("/tmp")),
"s3_workspace": ObjectStorageWorkspace(bucket="openeo-fake-bucketname"),
}


2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -98,7 +98,7 @@ def is_port_free(port: int) -> bool:
return s.connect_ex(("localhost", port)) != 0


def force_restart_spark_context():
def force_stop_spark_context():
# Restart SparkContext will make sure that the new environment variables are available inside the JVM
# This is a hacky way to allow debugging in the same process.
from pyspark import SparkContext
4 changes: 2 additions & 2 deletions tests/deploy/test_batch_job.py
@@ -1369,9 +1369,9 @@ def test_run_job_to_s3(
from openeogeotrellis.configparams import ConfigParams

if ConfigParams().use_object_storage:
from tests.conftest import force_restart_spark_context
from tests.conftest import force_stop_spark_context

force_restart_spark_context()
force_stop_spark_context()

# Run in the same process, so that we can check the output directly:
from openeogeotrellis.deploy.run_graph_locally import run_graph_locally
180 changes: 99 additions & 81 deletions tests/test_batch_result.py
@@ -30,6 +30,7 @@
from openeogeotrellis.deploy.batch_job import run_job
from openeogeotrellis.deploy.batch_job_metadata import extract_result_metadata
from openeogeotrellis.utils import s3_client
from .conftest import force_stop_spark_context, _setup_local_spark

from .data import TEST_DATA_ROOT, get_test_data_file

@@ -1141,14 +1142,21 @@ def test_export_workspace_with_asset_per_band(tmp_path):
shutil.rmtree(workspace_dir, ignore_errors=True)


@pytest.mark.parametrize("use_s3", [False]) # use_s3 does not work on Jenkins
@pytest.mark.parametrize("use_s3", [False]) # use_s3 is only for debugging locally. Does not work on Jenkins
def test_filepath_per_band(
tmp_path,
use_s3,
mock_s3_bucket,
moto_server,
monkeypatch,
):
if use_s3:
workspace_id = "s3_workspace"
else:
workspace_id = "tmp_workspace"

merge = _random_merge()

process_graph = {
"loadcollection1": {
"process_id": "load_collection",
@@ -1185,43 +1193,31 @@
"filepath_per_band": ["folder1/lon.tif", "lat.tif"],
},
},
"result": False,
},
"exportworkspace1": {
"process_id": "export_workspace",
"arguments": {
"data": {"from_node": "saveresult1"},
"workspace": workspace_id,
"merge": merge,
},
"result": True,
},
}

if use_s3:
monkeypatch.setenv("KUBE", "TRUE")
json_path = tmp_path / "process_graph.json"
json.dump(process_graph, json_path.open("w"))

containing_folder = Path(__file__).parent
cmd = [
sys.executable,
containing_folder.parent / "openeogeotrellis/deploy/run_graph_locally.py",
json_path,
]
# Run in separate subprocess so that all environment variables are
# set correctly at the moment the SparkContext is created:
try:
output = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True, env=os.environ)
except subprocess.CalledProcessError as e:
print("run_graph_locally failed. Output: " + e.output)
raise
force_stop_spark_context() # only use this when running a single test

print(output)
class TerminalReporterMock:
@staticmethod
def write_line(message):
print(message)

_setup_local_spark(TerminalReporterMock(), 0)
s3_instance = s3_client()
from openeogeotrellis.config import get_backend_config

with open(json_path, "rb") as f:
s3_instance.upload_fileobj(
f, get_backend_config().s3_bucket_name, str((tmp_path / "test.json").relative_to("/"))
)

job_dir_files = {o["Key"] for o in
s3_instance.list_objects(Bucket=get_backend_config().s3_bucket_name)["Contents"]}
job_dir_files = [f[len(str(tmp_path)):] for f in job_dir_files]
else:
try:
process = {
"process_graph": process_graph,
}
@@ -1234,61 +1230,83 @@ def test_filepath_per_band(
dependencies=[],
)

if use_s3:
job_dir_files = {
o["Key"] for o in s3_instance.list_objects(Bucket=get_backend_config().s3_bucket_name)["Contents"]
}
print(job_dir_files)
job_dir_files = set(os.listdir(tmp_path))
assert len(job_dir_files) > 0
assert "lat.tif" in job_dir_files
assert any(f.startswith("folder1") for f in job_dir_files)

workspace_files = list(os.listdir(tmp_path))
assert workspace_files == ListSubSet(
[
"collection.json",
"folder1",
"job_metadata.json",
"lat.tif.json",
]
)
if not use_s3:
assert "lat.tif" in workspace_files

stac_collection = pystac.Collection.from_file(str(tmp_path / "collection.json"))
stac_collection.validate_all()
item_links = [item_link for item_link in stac_collection.links if item_link.rel == "item"]
assert len(item_links) == 2
item_link = item_links[0]

assert item_link.media_type == "application/geo+json"
assert item_link.href == "./folder1/lon.tif.json"

items = list(stac_collection.get_items())
assert len(items) == 2

item = items[0]
assert item.id == "folder1/lon.tif"

geotiff_asset = item.get_assets()["folder1/lon.tif"]
assert "data" in geotiff_asset.roles
assert geotiff_asset.href == "./lon.tif" # relative to the json file
assert geotiff_asset.media_type == "image/tiff; application=geotiff"
assert geotiff_asset.extra_fields["eo:bands"] == [DictSubSet({"name": "Longitude"})]
if not use_s3:
assert geotiff_asset.extra_fields["raster:bands"] == [
{
"name": "Longitude",
"statistics": {
"maximum": 0.75,
"mean": 0.375,
"minimum": 0.0,
"stddev": 0.27950849718747,
"valid_percent": 100.0,
},
}
]
assert len(job_dir_files) > 0
assert "lat.tif" in job_dir_files
assert any(f.startswith("folder1") for f in job_dir_files)

geotiff_asset_copy_path = tmp_path / "file.copy"
geotiff_asset.copy(str(geotiff_asset_copy_path)) # downloads the asset file
with rasterio.open(geotiff_asset_copy_path) as dataset:
assert dataset.driver == "GTiff"
stac_collection = pystac.Collection.from_file(str(tmp_path / "collection.json"))
stac_collection.validate_all()
item_links = [item_link for item_link in stac_collection.links if item_link.rel == "item"]
assert len(item_links) == 2
item_link = item_links[0]

assert item_link.media_type == "application/geo+json"
assert item_link.href == "./folder1/lon.tif.json"

items = list(stac_collection.get_items())
assert len(items) == 2

item = items[0]
assert item.id == "folder1/lon.tif"

geotiff_asset = item.get_assets()["folder1/lon.tif"]
assert "data" in geotiff_asset.roles
assert geotiff_asset.href == "./lon.tif" # relative to the json file
assert geotiff_asset.media_type == "image/tiff; application=geotiff"
assert geotiff_asset.extra_fields["eo:bands"] == [DictSubSet({"name": "Longitude"})]
if not use_s3:
assert geotiff_asset.extra_fields["raster:bands"] == [
{
"name": "Longitude",
"statistics": {
"maximum": 0.75,
"mean": 0.375,
"minimum": 0.0,
"stddev": 0.27950849718747,
"valid_percent": 100.0,
},
}
]

geotiff_asset_copy_path = tmp_path / "file.copy"
geotiff_asset.copy(str(geotiff_asset_copy_path)) # downloads the asset file
with rasterio.open(geotiff_asset_copy_path) as dataset:
assert dataset.driver == "GTiff"

workspace = get_backend_config().workspaces[workspace_id]
if use_s3:
# job bucket and workspace bucket are the same
job_dir_files_s3 = [
o["Key"] for o in s3_instance.list_objects(Bucket=get_backend_config().s3_bucket_name)["Contents"]
]
assert job_dir_files_s3 == ListSubSet(
[
f"{merge}/collection.json",
f"{merge}/folder1/lon.tif",
f"{merge}/folder1/lon.tif.json",
f"{merge}/lat.tif",
f"{merge}/lat.tif.json",
]
)

else:
workspace_dir = Path(f"{workspace.root_directory}/{merge}")
assert workspace_dir.exists()
assert (workspace_dir / "lat.tif").exists()
assert (workspace_dir / "folder1/lon.tif").exists()
stac_collection_exported = pystac.Collection.from_file(str(workspace_dir / "collection.json"))
stac_collection_exported.validate_all()
finally:
if not use_s3:
workspace_dir = Path(f"{workspace.root_directory}/{merge}")
shutil.rmtree(workspace_dir, ignore_errors=True)


def test_discard_result(tmp_path):
