From 3c2d8dd1e155cbec16f2e7886c878b7a9bad2d5f Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Tue, 16 Apr 2024 15:15:45 +0200 Subject: [PATCH] fix: OpenLineage datasets in GCSTimeSpanFileTransformOperator Signed-off-by: Kacper Muda --- .../providers/google/cloud/operators/gcs.py | 69 +++++++----- .../google/cloud/operators/test_gcs.py | 105 ++++++++++++------ 2 files changed, 115 insertions(+), 59 deletions(-) diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index c311c8b4ed6b8..6c72378a436f3 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -774,8 +774,8 @@ def __init__( self.upload_continue_on_fail = upload_continue_on_fail self.upload_num_attempts = upload_num_attempts - self._source_object_names: list[str] = [] - self._destination_object_names: list[str] = [] + self._source_prefix_interp: str | None = None + self._destination_prefix_interp: str | None = None def execute(self, context: Context) -> list[str]: # Define intervals and prefixes. @@ -803,11 +803,11 @@ def execute(self, context: Context) -> list[str]: timespan_start = timespan_start.in_timezone(timezone.utc) timespan_end = timespan_end.in_timezone(timezone.utc) - source_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix( + self._source_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix( self.source_prefix, timespan_start, ) - destination_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix( + self._destination_prefix_interp = GCSTimeSpanFileTransformOperator.interpolate_prefix( self.destination_prefix, timespan_start, ) @@ -828,9 +828,9 @@ def execute(self, context: Context) -> list[str]: ) # Fetch list of files. - self._source_object_names = source_hook.list_by_timespan( + blobs_to_transform = source_hook.list_by_timespan( bucket_name=self.source_bucket, - prefix=source_prefix_interp, + prefix=self._source_prefix_interp, timespan_start=timespan_start, timespan_end=timespan_end, ) @@ -840,7 +840,7 @@ def execute(self, context: Context) -> list[str]: temp_output_dir_path = Path(temp_output_dir) # TODO: download in parallel. - for blob_to_transform in self._source_object_names: + for blob_to_transform in blobs_to_transform: destination_file = temp_input_dir_path / blob_to_transform destination_file.parent.mkdir(parents=True, exist_ok=True) try: @@ -877,6 +877,8 @@ def execute(self, context: Context) -> list[str]: self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir_path) + files_uploaded = [] + # TODO: upload in parallel. for upload_file in temp_output_dir_path.glob("**/*"): if upload_file.is_dir(): @@ -884,8 +886,8 @@ def execute(self, context: Context) -> list[str]: upload_file_name = str(upload_file.relative_to(temp_output_dir_path)) - if self.destination_prefix is not None: - upload_file_name = f"{destination_prefix_interp}/{upload_file_name}" + if self._destination_prefix_interp is not None: + upload_file_name = f"{self._destination_prefix_interp.rstrip('/')}/{upload_file_name}" self.log.info("Uploading file %s to %s", upload_file, upload_file_name) @@ -897,35 +899,46 @@ def execute(self, context: Context) -> list[str]: chunk_size=self.chunk_size, num_max_attempts=self.upload_num_attempts, ) - self._destination_object_names.append(str(upload_file_name)) + files_uploaded.append(str(upload_file_name)) except GoogleCloudError: if not self.upload_continue_on_fail: raise - return self._destination_object_names + return files_uploaded def get_openlineage_facets_on_complete(self, task_instance): - """Implement on_complete as execute() resolves object names.""" + """Implement on_complete as execute() resolves object prefixes.""" from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage - input_datasets = [ - Dataset( - namespace=f"gs://{self.source_bucket}", - name=object_name, - ) - for object_name in self._source_object_names - ] - output_datasets = [ - Dataset( - namespace=f"gs://{self.destination_bucket}", - name=object_name, - ) - for object_name in self._destination_object_names - ] - - return OperatorLineage(inputs=input_datasets, outputs=output_datasets) + def _parse_prefix(pref): + # Use parent if not a file (dot not in name) and not a dir (ends with slash) + if "." not in pref.split("/")[-1] and not pref.endswith("/"): + pref = Path(pref).parent.as_posix() + return "/" if pref in (".", "/", "") else pref.rstrip("/") + + input_prefix, output_prefix = "/", "/" + if self._source_prefix_interp is not None: + input_prefix = _parse_prefix(self._source_prefix_interp) + + if self._destination_prefix_interp is not None: + output_prefix = _parse_prefix(self._destination_prefix_interp) + + return OperatorLineage( + inputs=[ + Dataset( + namespace=f"gs://{self.source_bucket}", + name=input_prefix, + ) + ], + outputs=[ + Dataset( + namespace=f"gs://{self.destination_bucket}", + name=output_prefix, + ) + ], + ) class GCSDeleteBucketOperator(GoogleCloudBaseOperator): diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py index 2eb96682bdb0e..6236aa5f23ba2 100644 --- a/tests/providers/google/cloud/operators/test_gcs.py +++ b/tests/providers/google/cloud/operators/test_gcs.py @@ -21,6 +21,7 @@ from pathlib import Path from unittest import mock +import pytest from openlineage.client.facet import ( LifecycleStateChange, LifecycleStateChangeDatasetFacet, @@ -483,15 +484,78 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempdir): ] ) + @pytest.mark.parametrize( + ("source_prefix", "dest_prefix", "inputs", "outputs"), + ( + ( + None, + None, + [Dataset(f"gs://{TEST_BUCKET}", "/")], + [Dataset(f"gs://{TEST_BUCKET}_dest", "/")], + ), + ( + None, + "dest_pre/", + [Dataset(f"gs://{TEST_BUCKET}", "/")], + [Dataset(f"gs://{TEST_BUCKET}_dest", "dest_pre")], + ), + ( + "source_pre/", + None, + [Dataset(f"gs://{TEST_BUCKET}", "source_pre")], + [Dataset(f"gs://{TEST_BUCKET}_dest", "/")], + ), + ( + "source_pre/", + "dest_pre/", + [Dataset(f"gs://{TEST_BUCKET}", "source_pre")], + [Dataset(f"gs://{TEST_BUCKET}_dest", "dest_pre")], + ), + ( + "source_pre", + "dest_pre", + [Dataset(f"gs://{TEST_BUCKET}", "/")], + [Dataset(f"gs://{TEST_BUCKET}_dest", "/")], + ), + ( + "dir1/source_pre", + "dir2/dest_pre", + [Dataset(f"gs://{TEST_BUCKET}", "dir1")], + [Dataset(f"gs://{TEST_BUCKET}_dest", "dir2")], + ), + ( + "", + "/", + [Dataset(f"gs://{TEST_BUCKET}", "/")], + [Dataset(f"gs://{TEST_BUCKET}_dest", "/")], + ), + ( + "source/a.txt", + "target/", + [Dataset(f"gs://{TEST_BUCKET}", "source/a.txt")], + [Dataset(f"gs://{TEST_BUCKET}_dest", "target")], + ), + ), + ids=( + "no prefixes", + "dest prefix only", + "source prefix only", + "both with ending slash", + "both without ending slash", + "both as directory with prefix", + "both empty or root", + "source prefix is file path", + ), + ) @mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory") @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess") @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook") - def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mock_tempdir): + def test_get_openlineage_facets_on_complete( + self, mock_hook, mock_subprocess, mock_tempdir, source_prefix, dest_prefix, inputs, outputs + ): source_bucket = TEST_BUCKET - source_prefix = "source_prefix" destination_bucket = TEST_BUCKET + "_dest" - destination_prefix = "destination_prefix" destination = "destination" file1 = "file1" @@ -508,8 +572,8 @@ def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mo mock_tempdir.return_value.__enter__.side_effect = ["source", destination] mock_hook.return_value.list_by_timespan.return_value = [ - f"{source_prefix}/{file1}", - f"{source_prefix}/{file2}", + f"{source_prefix or ''}{file1}", + f"{source_prefix or ''}{file2}", ] mock_proc = mock.MagicMock() @@ -529,7 +593,7 @@ def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mo source_prefix=source_prefix, source_gcp_conn_id="", destination_bucket=destination_bucket, - destination_prefix=destination_prefix, + destination_prefix=dest_prefix, destination_gcp_conn_id="", transform_script="script.py", ) @@ -541,32 +605,11 @@ def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mo ] op.execute(context=context) - expected_inputs = [ - Dataset( - namespace=f"gs://{source_bucket}", - name=f"{source_prefix}/{file1}", - ), - Dataset( - namespace=f"gs://{source_bucket}", - name=f"{source_prefix}/{file2}", - ), - ] - expected_outputs = [ - Dataset( - namespace=f"gs://{destination_bucket}", - name=f"{destination_prefix}/{file1}", - ), - Dataset( - namespace=f"gs://{destination_bucket}", - name=f"{destination_prefix}/{file2}", - ), - ] - lineage = op.get_openlineage_facets_on_complete(None) - assert len(lineage.inputs) == 2 - assert len(lineage.outputs) == 2 - assert lineage.inputs == expected_inputs - assert lineage.outputs == expected_outputs + assert len(lineage.inputs) == len(inputs) + assert len(lineage.outputs) == len(outputs) + assert sorted(lineage.inputs) == sorted(inputs) + assert sorted(lineage.outputs) == sorted(outputs) class TestGCSDeleteBucketOperator: