Skip to content

Commit

Permalink
feat: Add openlineage support for CopyFromExternalStageToSnowflakeOpe…
Browse files Browse the repository at this point in the history
…rator (#36535)
  • Loading branch information
kacpermuda authored Jan 8, 2024
1 parent 98f5ce2 commit 3dc99d8
Show file tree
Hide file tree
Showing 2 changed files with 327 additions and 4 deletions.
163 changes: 160 additions & 3 deletions airflow/providers/snowflake/transfers/copy_into_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,12 @@ def __init__(
self.copy_options = copy_options
self.validation_mode = validation_mode

self.hook: SnowflakeHook | None = None
self._sql: str | None = None
self._result: list[dict[str, Any]] = []

def execute(self, context: Any) -> None:
snowflake_hook = SnowflakeHook(
self.hook = SnowflakeHook(
snowflake_conn_id=self.snowflake_conn_id,
warehouse=self.warehouse,
database=self.database,
Expand All @@ -127,7 +131,7 @@ def execute(self, context: Any) -> None:
if self.columns_array:
into = f"{into}({', '.join(self.columns_array)})"

sql = f"""
self._sql = f"""
COPY INTO {into}
FROM @{self.stage}/{self.prefix or ""}
{"FILES=(" + ",".join(map(enclose_param, self.files)) + ")" if self.files else ""}
Expand All @@ -137,5 +141,158 @@ def execute(self, context: Any) -> None:
{self.validation_mode or ""}
"""
self.log.info("Executing COPY command...")
snowflake_hook.run(sql=sql, autocommit=self.autocommit)
self._result = self.hook.run( # type: ignore # mypy does not work well with return_dictionaries=True
sql=self._sql,
autocommit=self.autocommit,
handler=lambda x: x.fetchall(),
return_dictionaries=True,
)
self.log.info("COPY command completed")

@staticmethod
def _extract_openlineage_unique_dataset_paths(
query_result: list[dict[str, Any]],
) -> tuple[list[tuple[str, str]], list[str]]:
"""Extracts and returns unique OpenLineage dataset paths and file paths that failed to be parsed.
Each row in the results is expected to have a 'file' field, which is a URI.
The function parses these URIs and constructs a set of unique OpenLineage (namespace, name) tuples.
Additionally, it captures any URIs that cannot be parsed or processed
and returns them in a separate error list.
For Azure, Snowflake has a unique way of representing URI:
azure://<account_name>.blob.core.windows.net/<container_name>/path/to/file.csv
that is transformed by this function to a Dataset with more universal naming convention:
Dataset(namespace="wasbs://container_name@account_name", name="path/to"), as described at
https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md#wasbs-azure-blob-storage
:param query_result: A list of dictionaries, each containing a 'file' key with a URI value.
:return: Two lists - the first is a sorted list of tuples, each representing a unique dataset path,
and the second contains any URIs that cannot be parsed or processed correctly.
>>> method = CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths
>>> results = [{"file": "azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"}]
>>> method(results)
([('wasbs://azure_container@my_account', 'dir3')], [])
>>> results = [{"file": "azure://my_account.blob.core.windows.net/azure_container"}]
>>> method(results)
([('wasbs://azure_container@my_account', '/')], [])
>>> results = [{"file": "s3://bucket"}, {"file": "gcs://bucket/"}, {"file": "s3://bucket/a.csv"}]
>>> method(results)
([('gcs://bucket', '/'), ('s3://bucket', '/')], [])
>>> results = [{"file": "s3://bucket/dir/file.csv"}, {"file": "gcs://bucket/dir/dir2/a.txt"}]
>>> method(results)
([('gcs://bucket', 'dir/dir2'), ('s3://bucket', 'dir')], [])
>>> results = [
... {"file": "s3://bucket/dir/file.csv"},
... {"file": "azure://my_account.something_new.windows.net/azure_container"},
... ]
>>> method(results)
([('s3://bucket', 'dir')], ['azure://my_account.something_new.windows.net/azure_container'])
"""
import re
from pathlib import Path
from urllib.parse import urlparse

azure_regex = r"azure:\/\/(\w+)?\.blob.core.windows.net\/(\w+)\/?(.*)?"
extraction_error_files = []
unique_dataset_paths = set()

for row in query_result:
uri = urlparse(row["file"])
if uri.scheme == "azure":
match = re.fullmatch(azure_regex, row["file"])
if not match:
extraction_error_files.append(row["file"])
continue
account_name, container_name, name = match.groups()
namespace = f"wasbs://{container_name}@{account_name}"
else:
namespace = f"{uri.scheme}://{uri.netloc}"
name = uri.path.lstrip("/")

name = Path(name).parent.as_posix()
if name in ("", "."):
name = "/"

unique_dataset_paths.add((namespace, name))

return sorted(unique_dataset_paths), sorted(extraction_error_files)

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement _on_complete because we rely on return value of a query."""
import re

from openlineage.client.facet import (
ExternalQueryRunFacet,
ExtractionError,
ExtractionErrorRunFacet,
SqlJobFacet,
)
from openlineage.client.run import Dataset

from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import SQLParser

if not self._sql:
return OperatorLineage()

query_results = self._result or []
# If no files were uploaded we get [{"status": "0 files were uploaded..."}]
if len(query_results) == 1 and query_results[0].get("status"):
query_results = []
unique_dataset_paths, extraction_error_files = self._extract_openlineage_unique_dataset_paths(
query_results
)
input_datasets = [Dataset(namespace=namespace, name=name) for namespace, name in unique_dataset_paths]

run_facets = {}
if extraction_error_files:
self.log.debug(
f"Unable to extract Dataset namespace and name "
f"for the following files: `{extraction_error_files}`."
)
run_facets["extractionError"] = ExtractionErrorRunFacet(
totalTasks=len(query_results),
failedTasks=len(extraction_error_files),
errors=[
ExtractionError(
errorMessage="Unable to extract Dataset namespace and name.",
stackTrace=None,
task=file_uri,
taskNumber=None,
)
for file_uri in extraction_error_files
],
)

connection = self.hook.get_connection(getattr(self.hook, str(self.hook.conn_name_attr)))
database_info = self.hook.get_openlineage_database_info(connection)

dest_name = self.table
schema = self.hook.get_openlineage_default_schema()
database = database_info.database
if schema:
dest_name = f"{schema}.{dest_name}"
if database:
dest_name = f"{database}.{dest_name}"

snowflake_namespace = SQLParser.create_namespace(database_info)
query = SQLParser.normalize_sql(self._sql)
query = re.sub(r"\n+", "\n", re.sub(r" +", " ", query))

run_facets["externalQuery"] = ExternalQueryRunFacet(
externalQueryId=self.hook.query_ids[0], source=snowflake_namespace
)

return OperatorLineage(
inputs=input_datasets,
outputs=[Dataset(namespace=snowflake_namespace, name=dest_name)],
job_facets={"sql": SqlJobFacet(query=query)},
run_facets=run_facets,
)
168 changes: 167 additions & 1 deletion tests/providers/snowflake/transfers/test_copy_into_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,20 @@
# under the License.
from __future__ import annotations

from typing import Callable
from unittest import mock

from openlineage.client.facet import (
ExternalQueryRunFacet,
ExtractionError,
ExtractionErrorRunFacet,
SqlJobFacet,
)
from openlineage.client.run import Dataset
from pytest import mark

from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo
from airflow.providers.snowflake.transfers.copy_into_snowflake import CopyFromExternalStageToSnowflakeOperator


Expand Down Expand Up @@ -62,4 +74,158 @@ def test_execute(self, mock_hook):
validation_mode
"""

mock_hook.return_value.run.assert_called_once_with(sql=sql, autocommit=True)
mock_hook.return_value.run.assert_called_once_with(
sql=sql, autocommit=True, return_dictionaries=True, handler=mock.ANY
)

handler = mock_hook.return_value.run.mock_calls[0].kwargs.get("handler")
assert isinstance(handler, Callable)

@mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
def test_get_openlineage_facets_on_complete(self, mock_hook):
mock_hook().run.return_value = [
{"file": "s3://aws_bucket_name/dir1/file.csv"},
{"file": "s3://aws_bucket_name_2"},
{"file": "gcs://gcs_bucket_name/dir2/file.csv"},
{"file": "gcs://gcs_bucket_name_2"},
{"file": "azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"},
{"file": "azure://my_account.blob.core.windows.net/azure_container_2"},
]
mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
scheme="snowflake_scheme", authority="authority", database="actual_database"
)
mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
mock_hook().query_ids = ["query_id_123"]

expected_inputs = [
Dataset(namespace="gcs://gcs_bucket_name", name="dir2"),
Dataset(namespace="gcs://gcs_bucket_name_2", name="/"),
Dataset(namespace="s3://aws_bucket_name", name="dir1"),
Dataset(namespace="s3://aws_bucket_name_2", name="/"),
Dataset(namespace="wasbs://azure_container@my_account", name="dir3"),
Dataset(namespace="wasbs://azure_container_2@my_account", name="/"),
]
expected_outputs = [
Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
]
expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""

op = CopyFromExternalStageToSnowflakeOperator(
task_id="test",
table="table",
stage="stage",
database="",
schema="schema",
file_format="CSV",
)
op.execute(None)
result = op.get_openlineage_facets_on_complete(None)
assert result == OperatorLineage(
inputs=expected_inputs,
outputs=expected_outputs,
run_facets={
"externalQuery": ExternalQueryRunFacet(
externalQueryId="query_id_123", source="snowflake_scheme://authority"
)
},
job_facets={"sql": SqlJobFacet(query=expected_sql)},
)

@mark.parametrize("rows", (None, []))
@mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
def test_get_openlineage_facets_on_complete_with_empty_inputs(self, mock_hook, rows):
mock_hook().run.return_value = rows
mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
scheme="snowflake_scheme", authority="authority", database="actual_database"
)
mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
mock_hook().query_ids = ["query_id_123"]

expected_outputs = [
Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
]
expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""

op = CopyFromExternalStageToSnowflakeOperator(
task_id="test",
table="table",
stage="stage",
database="",
schema="schema",
file_format="CSV",
)
op.execute(None)
result = op.get_openlineage_facets_on_complete(None)
assert result == OperatorLineage(
inputs=[],
outputs=expected_outputs,
run_facets={
"externalQuery": ExternalQueryRunFacet(
externalQueryId="query_id_123", source="snowflake_scheme://authority"
)
},
job_facets={"sql": SqlJobFacet(query=expected_sql)},
)

@mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hook):
mock_hook().run.return_value = [
{"file": "s3://aws_bucket_name/dir1/file.csv"},
{"file": "gs://gcp_bucket_name/dir2/file.csv"},
{"file": "azure://my_account.weird-url.net/azure_container/dir3/file.csv"},
{"file": "azure://my_account.another_weird-url.net/con/file.csv"},
]
mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
scheme="snowflake_scheme", authority="authority", database="actual_database"
)
mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
mock_hook().query_ids = ["query_id_123"]

expected_inputs = [
Dataset(namespace="gs://gcp_bucket_name", name="dir2"),
Dataset(namespace="s3://aws_bucket_name", name="dir1"),
]
expected_outputs = [
Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
]
expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""
expected_run_facets = {
"extractionError": ExtractionErrorRunFacet(
totalTasks=4,
failedTasks=2,
errors=[
ExtractionError(
errorMessage="Unable to extract Dataset namespace and name.",
stackTrace=None,
task="azure://my_account.another_weird-url.net/con/file.csv",
taskNumber=None,
),
ExtractionError(
errorMessage="Unable to extract Dataset namespace and name.",
stackTrace=None,
task="azure://my_account.weird-url.net/azure_container/dir3/file.csv",
taskNumber=None,
),
],
),
"externalQuery": ExternalQueryRunFacet(
externalQueryId="query_id_123", source="snowflake_scheme://authority"
),
}

op = CopyFromExternalStageToSnowflakeOperator(
task_id="test",
table="table",
stage="stage",
database="",
schema="schema",
file_format="CSV",
)
op.execute(None)
result = op.get_openlineage_facets_on_complete(None)
assert result == OperatorLineage(
inputs=expected_inputs,
outputs=expected_outputs,
run_facets=expected_run_facets,
job_facets={"sql": SqlJobFacet(query=expected_sql)},
)

0 comments on commit 3dc99d8

Please sign in to comment.