[OpenLineage] Fix datasets in GCSTimeSpanFileTransformOperator #39064

Merged · 1 commit · Apr 17, 2024
69 changes: 41 additions & 28 deletions airflow/providers/google/cloud/operators/gcs.py
@@ -774,8 +774,8 @@ def __init__(
         self.upload_continue_on_fail = upload_continue_on_fail
         self.upload_num_attempts = upload_num_attempts
 
-        self._source_object_names: list[str] = []
-        self._destination_object_names: list[str] = []
+        self._source_prefix_interp: str | None = None
+        self._destination_prefix_interp: str | None = None
 
     def execute(self, context: Context) -> list[str]:
         # Define intervals and prefixes.
@@ -803,11 +803,11 @@ def execute(self, context: Context) -> list[str]:
         timespan_start = timespan_start.in_timezone(timezone.utc)
         timespan_end = timespan_end.in_timezone(timezone.utc)
 
-        source_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
+        self._source_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
             self.source_prefix,
             timespan_start,
         )
-        destination_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
+        self._destination_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix(
             self.destination_prefix,
             timespan_start,
         )
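Note: the interpolated prefixes are now stored on the instance so that get_openlineage_facets_on_complete() can consume them after execute() finishes. For context, interpolate_prefix expands strftime-style placeholders in the configured prefix against the timespan start. A minimal sketch of that behavior (assuming the strftime-based interpolation used by this operator; values are illustrative):

import pendulum

# A prefix with strftime placeholders encodes the processing date;
# a prefix without placeholders passes through unchanged.
timespan_start = pendulum.datetime(2024, 4, 17, tz="UTC")
print(timespan_start.strftime("data/%Y/%m/%d/"))  # data/2024/04/17/
print(timespan_start.strftime("source_pre/"))     # source_pre/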
@@ -828,9 +828,9 @@ def execute(self, context: Context) -> list[str]:
         )
 
         # Fetch list of files.
-        self._source_object_names = source_hook.list_by_timespan(
+        blobs_to_transform = source_hook.list_by_timespan(
             bucket_name=self.source_bucket,
-            prefix=source_prefix_interp,
+            prefix=self._source_prefix_interp,
             timespan_start=timespan_start,
             timespan_end=timespan_end,
         )
@@ -840,7 +840,7 @@ def execute(self, context: Context) -> list[str]:
             temp_output_dir_path = Path(temp_output_dir)
 
             # TODO: download in parallel.
-            for blob_to_transform in self._source_object_names:
+            for blob_to_transform in blobs_to_transform:
                 destination_file = temp_input_dir_path / blob_to_transform
                 destination_file.parent.mkdir(parents=True, exist_ok=True)
                 try:
@@ -877,15 +877,17 @@ def execute(self, context: Context) -> list[str]:
 
             self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir_path)
 
+            files_uploaded = []
+
             # TODO: upload in parallel.
             for upload_file in temp_output_dir_path.glob("**/*"):
                 if upload_file.is_dir():
                     continue
 
                 upload_file_name = str(upload_file.relative_to(temp_output_dir_path))
 
-                if self.destination_prefix is not None:
-                    upload_file_name = f"{destination_prefix_interp}/{upload_file_name}"
+                if self._destination_prefix_interp is not None:
+                    upload_file_name = f"{self._destination_prefix_interp.rstrip('/')}/{upload_file_name}"
 
                 self.log.info("Uploading file %s to %s", upload_file, upload_file_name)
 
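The added rstrip('/') keeps a destination prefix that ends in a slash from producing a double slash in the uploaded object name, since the prefix and file name are joined with an explicit "/". A one-line illustration:

# With destination prefix "dest_pre/" and file "file1":
print(f"{'dest_pre/'.rstrip('/')}/file1")  # dest_pre/file1, not dest_pre//file1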
@@ -897,35 +899,46 @@ def execute(self, context: Context) -> list[str]:
                         chunk_size=self.chunk_size,
                         num_max_attempts=self.upload_num_attempts,
                     )
-                    self._destination_object_names.append(str(upload_file_name))
+                    files_uploaded.append(str(upload_file_name))
                 except GoogleCloudError:
                     if not self.upload_continue_on_fail:
                         raise
 
-        return self._destination_object_names
+        return files_uploaded
 
     def get_openlineage_facets_on_complete(self, task_instance):
-        """Implement on_complete as execute() resolves object names."""
+        """Implement on_complete as execute() resolves object prefixes."""
         from openlineage.client.run import Dataset
 
         from airflow.providers.openlineage.extractors import OperatorLineage
 
-        input_datasets = [
-            Dataset(
-                namespace=f"gs://{self.source_bucket}",
-                name=object_name,
-            )
-            for object_name in self._source_object_names
-        ]
-        output_datasets = [
-            Dataset(
-                namespace=f"gs://{self.destination_bucket}",
-                name=object_name,
-            )
-            for object_name in self._destination_object_names
-        ]
-
-        return OperatorLineage(inputs=input_datasets, outputs=output_datasets)
+        def _parse_prefix(pref):
+            # Use parent if not a file (dot not in name) and not a dir (ends with slash)
+            if "." not in pref.split("/")[-1] and not pref.endswith("/"):
+                pref = Path(pref).parent.as_posix()
+            return "/" if pref in (".", "/", "") else pref.rstrip("/")
+
+        input_prefix, output_prefix = "/", "/"
+        if self._source_prefix_interp is not None:
+            input_prefix = _parse_prefix(self._source_prefix_interp)
+
+        if self._destination_prefix_interp is not None:
+            output_prefix = _parse_prefix(self._destination_prefix_interp)
+
+        return OperatorLineage(
+            inputs=[
+                Dataset(
+                    namespace=f"gs://{self.source_bucket}",
+                    name=input_prefix,
+                )
+            ],
+            outputs=[
+                Dataset(
+                    namespace=f"gs://{self.destination_bucket}",
+                    name=output_prefix,
+                )
+            ],
+        )
 
 
 class GCSDeleteBucketOperator(GoogleCloudBaseOperator):
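Taken together, get_openlineage_facets_on_complete now reports one dataset per side at directory granularity instead of one dataset per transferred object. A self-contained sketch of the prefix-to-name rule (mirroring _parse_prefix above; the helper name is illustrative):

from pathlib import Path

def dataset_name_from_prefix(pref: str | None) -> str:
    # None means no prefix was configured: the whole bucket is the dataset.
    if pref is None:
        return "/"
    # A final segment without a dot and without a trailing slash is treated
    # as a partial object name, so fall back to its parent directory.
    if "." not in pref.split("/")[-1] and not pref.endswith("/"):
        pref = Path(pref).parent.as_posix()
    return "/" if pref in (".", "/", "") else pref.rstrip("/")

assert dataset_name_from_prefix(None) == "/"
assert dataset_name_from_prefix("dest_pre/") == "dest_pre"
assert dataset_name_from_prefix("source_pre") == "/"          # partial name at bucket root
assert dataset_name_from_prefix("dir1/source_pre") == "dir1"  # partial name under a directory
assert dataset_name_from_prefix("source/a.txt") == "source/a.txt"

These cases line up with the parametrized expectations in the test diff below.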
105 changes: 74 additions & 31 deletions tests/providers/google/cloud/operators/test_gcs.py
@@ -21,6 +21,7 @@
 from pathlib import Path
 from unittest import mock
 
+import pytest
 from openlineage.client.facet import (
     LifecycleStateChange,
     LifecycleStateChangeDatasetFacet,
@@ -483,15 +484,78 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempdir):
             ]
         )
 
+    @pytest.mark.parametrize(
+        ("source_prefix", "dest_prefix", "inputs", "outputs"),
+        (
+            (
+                None,
+                None,
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                None,
+                "dest_pre/",
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "dest_pre")],
+            ),
+            (
+                "source_pre/",
+                None,
+                [Dataset(f"gs://{TEST_BUCKET}", "source_pre")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                "source_pre/",
+                "dest_pre/",
+                [Dataset(f"gs://{TEST_BUCKET}", "source_pre")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "dest_pre")],
+            ),
+            (
+                "source_pre",
+                "dest_pre",
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                "dir1/source_pre",
+                "dir2/dest_pre",
+                [Dataset(f"gs://{TEST_BUCKET}", "dir1")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "dir2")],
+            ),
+            (
+                "",
+                "/",
+                [Dataset(f"gs://{TEST_BUCKET}", "/")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "/")],
+            ),
+            (
+                "source/a.txt",
+                "target/",
+                [Dataset(f"gs://{TEST_BUCKET}", "source/a.txt")],
+                [Dataset(f"gs://{TEST_BUCKET}_dest", "target")],
+            ),
+        ),
+        ids=(
+            "no prefixes",
+            "dest prefix only",
+            "source prefix only",
+            "both with ending slash",
+            "both without ending slash",
+            "both as directory with prefix",
+            "both empty or root",
+            "source prefix is file path",
+        ),
+    )
     @mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
     @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
-    def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mock_tempdir):
+    def test_get_openlineage_facets_on_complete(
+        self, mock_hook, mock_subprocess, mock_tempdir, source_prefix, dest_prefix, inputs, outputs
+    ):
         source_bucket = TEST_BUCKET
-        source_prefix = "source_prefix"
 
         destination_bucket = TEST_BUCKET + "_dest"
-        destination_prefix = "destination_prefix"
         destination = "destination"
 
         file1 = "file1"
@@ -508,8 +572,8 @@ def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mock_tempdir):
 
         mock_tempdir.return_value.__enter__.side_effect = ["source", destination]
         mock_hook.return_value.list_by_timespan.return_value = [
-            f"{source_prefix}/{file1}",
-            f"{source_prefix}/{file2}",
+            f"{source_prefix or ''}{file1}",
+            f"{source_prefix or ''}{file2}",
         ]
 
         mock_proc = mock.MagicMock()
@@ -529,7 +593,7 @@ def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mock_tempdir):
             source_prefix=source_prefix,
             source_gcp_conn_id="",
             destination_bucket=destination_bucket,
-            destination_prefix=destination_prefix,
+            destination_prefix=dest_prefix,
             destination_gcp_conn_id="",
             transform_script="script.py",
         )
@@ -541,32 +605,11 @@ def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mock_tempdir):
         ]
         op.execute(context=context)
 
-        expected_inputs = [
-            Dataset(
-                namespace=f"gs://{source_bucket}",
-                name=f"{source_prefix}/{file1}",
-            ),
-            Dataset(
-                namespace=f"gs://{source_bucket}",
-                name=f"{source_prefix}/{file2}",
-            ),
-        ]
-        expected_outputs = [
-            Dataset(
-                namespace=f"gs://{destination_bucket}",
-                name=f"{destination_prefix}/{file1}",
-            ),
-            Dataset(
-                namespace=f"gs://{destination_bucket}",
-                name=f"{destination_prefix}/{file2}",
-            ),
-        ]
-
         lineage = op.get_openlineage_facets_on_complete(None)
-        assert len(lineage.inputs) == 2
-        assert len(lineage.outputs) == 2
-        assert lineage.inputs == expected_inputs
-        assert lineage.outputs == expected_outputs
+        assert len(lineage.inputs) == len(inputs)
+        assert len(lineage.outputs) == len(outputs)
+        assert sorted(lineage.inputs) == sorted(inputs)
+        assert sorted(lineage.outputs) == sorted(outputs)
 
 
 class TestGCSDeleteBucketOperator:
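For a concrete picture of the new behavior, consider a hypothetical task (bucket names, prefixes, and script path are made up; the constructor arguments follow the test above, and the lineage shape follows its parametrized cases). Assuming the strftime-based prefix interpolation sketched earlier, a timespan starting on 2024-04-17 interpolates the prefixes to raw/2024/04/17/ and curated/2024/04/17/, so the operator would report Dataset(namespace="gs://example-src", name="raw/2024/04/17") as input and Dataset(namespace="gs://example-dst", name="curated/2024/04/17") as output:

from airflow.providers.google.cloud.operators.gcs import GCSTimeSpanFileTransformOperator

# Hypothetical configuration for illustration only.
op = GCSTimeSpanFileTransformOperator(
    task_id="transform_timespan",
    source_bucket="example-src",
    source_prefix="raw/%Y/%m/%d/",
    source_gcp_conn_id="google_cloud_default",
    destination_bucket="example-dst",
    destination_prefix="curated/%Y/%m/%d/",
    destination_gcp_conn_id="google_cloud_default",
    transform_script="transform.py",
)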