Skip to content

Commit

Permalink
feat: add OpenLineage support for transfer operators between gcs and …
Browse files Browse the repository at this point in the history
…local

Signed-off-by: Kacper Muda <[email protected]>
  • Loading branch information
kacpermuda committed Nov 27, 2024
1 parent b1a44b4 commit 7018877
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 18 deletions.
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/common/io/assets/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def convert_asset_to_openlineage(asset: Asset, lineage_context) -> OpenLineageDa
from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset

parsed = urllib.parse.urlsplit(asset.uri)
return OpenLineageDataset(namespace=f"file://{parsed.netloc}", name=parsed.path)
return OpenLineageDataset(namespace=f"file://{parsed.netloc}" if parsed.netloc else "file", name=parsed.path)
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,12 @@ def execute(self, context: Context):
raise AirflowException("The size of the downloaded file is too large to push to XCom!")
else:
hook.download(bucket_name=self.bucket, object_name=self.object_name, filename=self.filename)

def get_openlineage_facets_on_start(self):
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.openlineage.extractors import OperatorLineage

return OperatorLineage(
inputs=[Dataset(namespace=f"gs://{self.bucket}", name=self.object_name)],
outputs=[Dataset(namespace="file", name=self.filename)] if self.filename else [],
)
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ class LocalFilesystemToGCSOperator(BaseOperator):
def __init__(
self,
*,
src,
dst,
bucket,
gcp_conn_id="google_cloud_default",
mime_type="application/octet-stream",
gzip=False,
src: str | list[str],
dst: str,
bucket: str,
gcp_conn_id: str = "google_cloud_default",
mime_type: str = "application/octet-stream",
gzip: bool = False,
chunk_size: int | None = None,
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
Expand Down Expand Up @@ -120,3 +120,38 @@ def execute(self, context: Context):
gzip=self.gzip,
chunk_size=self.chunk_size,
)

def get_openlineage_facets_on_start(self):
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
Identifier,
SymlinksDatasetFacet,
)
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path, WILDCARD

source_facets = {}
if isinstance(self.src, str): # Single path provided, possibly relative or with wildcard
original_src = f"{self.src}"
absolute_src = os.path.abspath(self.src)
resolved_src = extract_ds_name_from_gcs_path(absolute_src)
if original_src.startswith("/") and not resolved_src.startswith("/"):
resolved_src = "/" + resolved_src
source_objects = [resolved_src]

if WILDCARD in original_src or absolute_src != resolved_src:
# We attach a symlink with unmodified path.
source_facets = {
"symlink": SymlinksDatasetFacet(
identifiers=[Identifier(namespace="file", name=original_src, type="file")]
),
}
else:
source_objects = self.src

dest_object = self.dst if os.path.basename(self.dst) else extract_ds_name_from_gcs_path(self.dst)

return OperatorLineage(
inputs=[Dataset(namespace="file", name=src, facets=source_facets) for src in source_objects],
outputs=[Dataset(namespace=f"gs://{self.bucket}", name=dest_object)],
)
4 changes: 2 additions & 2 deletions providers/tests/common/io/assets/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def test_file_asset():
@pytest.mark.parametrize(
("uri", "ol_dataset"),
(
("file:///valid/path", OpenLineageDataset(namespace="file://", name="/valid/path")),
("file:///valid/path", OpenLineageDataset(namespace="file", name="/valid/path")),
(
"file://127.0.0.1:8080/dir/file.csv",
OpenLineageDataset(namespace="file://127.0.0.1:8080", name="/dir/file.csv"),
),
("file:///C://dir/file", OpenLineageDataset(namespace="file://", name="/C://dir/file")),
("file:///C://dir/file", OpenLineageDataset(namespace="file", name="/C://dir/file")),
),
)
def test_convert_asset_to_openlineage(uri, ol_dataset):
Expand Down
17 changes: 17 additions & 0 deletions providers/tests/google/cloud/transfers/test_gcs_to_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,20 @@ def test_xcom_encoding(self, mock_hook):
bucket_name=TEST_BUCKET, object_name=TEST_OBJECT
)
context["ti"].xcom_push.assert_called_once_with(key=XCOM_KEY, value=FILE_CONTENT_STR)

def test_get_openlineage_facets_on_start_(self):
operator = GCSToLocalFilesystemOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
object_name=TEST_OBJECT,
filename=LOCAL_FILE_PATH,
)
result = operator.get_openlineage_facets_on_start()
assert not result.job_facets
assert not result.run_facets
assert len(result.outputs) == 1
assert len(result.inputs) == 1
assert result.outputs[0].namespace == "file"
assert result.outputs[0].name == LOCAL_FILE_PATH
assert result.inputs[0].namespace == f"gs://{TEST_BUCKET}"
assert result.inputs[0].name == TEST_OBJECT
85 changes: 76 additions & 9 deletions providers/tests/google/cloud/transfers/test_local_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@

from airflow.models.dag import DAG
from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator
from airflow.providers.common.compat.openlineage.facet import (
Identifier,
SymlinksDatasetFacet,
)

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -72,7 +76,7 @@ def test_init(self):
def test_execute(self, mock_hook):
mock_instance = mock_hook.return_value
operator = LocalFilesystemToGCSOperator(
task_id="gcs_to_file_sensor",
task_id="file_to_gcs_operator",
dag=self.dag,
src=self.testfile1,
dst="test/test1.csv",
Expand All @@ -91,7 +95,7 @@ def test_execute(self, mock_hook):
@pytest.mark.db_test
def test_execute_with_empty_src(self):
operator = LocalFilesystemToGCSOperator(
task_id="local_to_sensor",
task_id="file_to_gcs_operator",
dag=self.dag,
src="no_file.txt",
dst="test/no_file.txt",
Expand All @@ -104,7 +108,7 @@ def test_execute_with_empty_src(self):
def test_execute_multiple(self, mock_hook):
mock_instance = mock_hook.return_value
operator = LocalFilesystemToGCSOperator(
task_id="gcs_to_file_sensor", dag=self.dag, src=self.testfiles, dst="test/", **self._config
task_id="file_to_gcs_operator", dag=self.dag, src=self.testfiles, dst="test/", **self._config
)
operator.execute(None)
files_objects = zip(
Expand All @@ -127,7 +131,7 @@ def test_execute_multiple(self, mock_hook):
def test_execute_wildcard(self, mock_hook):
mock_instance = mock_hook.return_value
operator = LocalFilesystemToGCSOperator(
task_id="gcs_to_file_sensor", dag=self.dag, src="/tmp/fake*.csv", dst="test/", **self._config
task_id="file_to_gcs_operator", dag=self.dag, src="/tmp/fake*.csv", dst="test/", **self._config
)
operator.execute(None)
object_names = ["test/" + os.path.basename(fp) for fp in glob("/tmp/fake*.csv")]
Expand All @@ -145,17 +149,80 @@ def test_execute_wildcard(self, mock_hook):
]
mock_instance.upload.assert_has_calls(calls)

@pytest.mark.parametrize(
("src", "dst"),
[
("/tmp/fake*.csv", "test/test1.csv"),
("/tmp/fake*.csv", "test"),
("/tmp/fake*.csv", "test/dir"),
],
)
@mock.patch("airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook", autospec=True)
def test_execute_negative(self, mock_hook):
def test_execute_negative(self, mock_hook, src, dst):
mock_instance = mock_hook.return_value
operator = LocalFilesystemToGCSOperator(
task_id="gcs_to_file_sensor",
task_id="file_to_gcs_operator",
dag=self.dag,
src="/tmp/fake*.csv",
dst="test/test1.csv",
src=src,
dst=dst,
**self._config,
)
print(glob("/tmp/fake*.csv"))
with pytest.raises(ValueError):
operator.execute(None)
mock_instance.assert_not_called()


@pytest.mark.parametrize(
("src", "dst", "expected_input", "expected_output", "symlink"),
[
("/tmp/fake*.csv", "test/", "/tmp", "test", True),
("/tmp/../tmp/fake*.csv", "test/", "/tmp", "test", True),
("/tmp/fake1.csv", "test/test1.csv", "/tmp/fake1.csv", "test/test1.csv", False),
("/tmp/fake1.csv", "test/pre", "/tmp/fake1.csv", "test/pre", False),
],
)
def test_get_openlineage_facets_on_start_with_string_src(self, src, dst, expected_input, expected_output, symlink):
operator = LocalFilesystemToGCSOperator(
task_id="gcs_to_file_sensor",
dag=self.dag,
src=src,
dst=dst,
**self._config,
)
result = operator.get_openlineage_facets_on_start()
assert not result.job_facets
assert not result.run_facets
assert len(result.outputs) == 1
assert len(result.inputs) == 1
assert result.outputs[0].name == expected_output
assert result.inputs[0].name == expected_input
if symlink:
assert result.inputs[0].facets["symlink"] == SymlinksDatasetFacet(
identifiers=[Identifier(namespace="file", name=src, type="file")]
)


@pytest.mark.parametrize(
("src", "dst", "expected_inputs", "expected_output"),
[
(["/tmp/fake1.csv", "/tmp/fake2.csv"], "test/", ["/tmp/fake1.csv", "/tmp/fake2.csv"], "test"),
(["/tmp/fake1.csv", "/tmp/fake2.csv"], "", ["/tmp/fake1.csv", "/tmp/fake2.csv"], "/"),
],
)
def test_get_openlineage_facets_on_start_with_string_src(self, src, dst, expected_inputs, expected_output):
operator = LocalFilesystemToGCSOperator(
task_id="gcs_to_file_sensor",
dag=self.dag,
src=src,
dst=dst,
**self._config,
)
result = operator.get_openlineage_facets_on_start()
assert not result.job_facets
assert not result.run_facets
assert len(result.outputs) == 1
assert len(result.inputs) == len(expected_inputs)
assert result.outputs[0].name == expected_output
assert result.outputs[0].namespace == "gs://dummy"
assert all(inp.name in expected_inputs for inp in result.inputs)
assert all(inp.namespace == "file" for inp in result.inputs)

0 comments on commit 7018877

Please sign in to comment.