openlineage, snowflake: do not run external queries for Snowflake when #39113

Merged · 1 commit · Apr 22, 2024
9 changes: 9 additions & 0 deletions airflow/providers/common/sql/operators/sql.py
@@ -309,6 +309,14 @@ def get_openlineage_facets_on_start(self) -> OperatorLineage | None:

        hook = self.get_db_hook()

+        try:
+            from airflow.providers.openlineage.utils.utils import should_use_external_connection
+
+            use_external_connection = should_use_external_connection(hook)
+        except ImportError:
+            # OpenLineage provider release < 1.8.0 - we always use connection
+            use_external_connection = True
+
        connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
        try:
            database_info = hook.get_openlineage_database_info(connection)
@@ -334,6 +342,7 @@ def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
            database_info=database_info,
            database=self.database,
            sqlalchemy_engine=hook.get_sqlalchemy_engine(),
+            use_connection=use_external_connection,
        )

        return operator_lineage
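
The `try`/`except ImportError` block added above is a version-compatibility guard: `should_use_external_connection` only exists in OpenLineage provider releases 1.8.0 and later, so older installations silently keep the previous behavior (always using the connection). A minimal sketch of the same guard pattern, with a hypothetical wrapper name:

```python
# Sketch of the import-guarded feature detection used above
# (resolve_use_connection is a hypothetical name for illustration).
def resolve_use_connection(hook) -> bool:
    try:
        # Imported lazily so a missing or older openlineage provider is not fatal.
        from airflow.providers.openlineage.utils.utils import should_use_external_connection
    except ImportError:
        # OpenLineage provider < 1.8.0: keep the legacy default.
        return True
    return should_use_external_connection(hook)
```
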
61 changes: 50 additions & 11 deletions airflow/providers/openlineage/sqlparser.py
@@ -29,6 +29,7 @@
    ExtractionErrorRunFacet,
    SqlJobFacet,
)
+from openlineage.client.run import Dataset
from openlineage.common.sql import DbTableMeta, SqlMeta, parse

from airflow.providers.openlineage.extractors.base import OperatorLineage
@@ -40,7 +41,6 @@
from airflow.typing_compat import TypedDict

if TYPE_CHECKING:
-    from openlineage.client.run import Dataset
    from sqlalchemy.engine import Engine

    from airflow.hooks.base import BaseHook
@@ -104,6 +104,18 @@ class DatabaseInfo:
    normalize_name_method: Callable[[str], str] = default_normalize_name_method


+def from_table_meta(
+    table_meta: DbTableMeta, database: str | None, namespace: str, is_uppercase: bool
+) -> Dataset:
+    if table_meta.database:
+        name = table_meta.qualified_name
+    elif database:
+        name = f"{database}.{table_meta.schema}.{table_meta.name}"
+    else:
+        name = f"{table_meta.schema}.{table_meta.name}"
+    return Dataset(namespace=namespace, name=name if not is_uppercase else name.upper())
+
+
class SQLParser:
    """Interface for openlineage-sql.

@@ -117,7 +129,7 @@ def __init__(self, dialect: str | None = None, default_schema: str | None = None

    def parse(self, sql: list[str] | str) -> SqlMeta | None:
        """Parse a single or a list of SQL statements."""
-        return parse(sql=sql, dialect=self.dialect)
+        return parse(sql=sql, dialect=self.dialect, default_schema=self.default_schema)

    def parse_table_schemas(
        self,
@@ -156,6 +168,23 @@ def parse_table_schemas(
            else None,
        )

+    def get_metadata_from_parser(
+        self,
+        inputs: list[DbTableMeta],
+        outputs: list[DbTableMeta],
+        database_info: DatabaseInfo,
+        namespace: str = DEFAULT_NAMESPACE,
+        database: str | None = None,
+    ) -> tuple[list[Dataset], ...]:
+        database = database if database else database_info.database
+        return [
+            from_table_meta(dataset, database, namespace, database_info.is_uppercase_names)
+            for dataset in inputs
+        ], [
+            from_table_meta(dataset, database, namespace, database_info.is_uppercase_names)
+            for dataset in outputs
+        ]
+
    def attach_column_lineage(
        self, datasets: list[Dataset], database: str | None, parse_result: SqlMeta
    ) -> None:
@@ -204,6 +233,7 @@ def generate_openlineage_metadata_from_sql(
        database_info: DatabaseInfo,
        database: str | None = None,
        sqlalchemy_engine: Engine | None = None,
+        use_connection: bool = True,
    ) -> OperatorLineage:
        """Parse SQL statement(s) and generate OpenLineage metadata.

@@ -242,15 +272,24 @@
        )

        namespace = self.create_namespace(database_info=database_info)
-        inputs, outputs = self.parse_table_schemas(
-            hook=hook,
-            inputs=parse_result.in_tables,
-            outputs=parse_result.out_tables,
-            namespace=namespace,
-            database=database,
-            database_info=database_info,
-            sqlalchemy_engine=sqlalchemy_engine,
-        )
+        if use_connection:
+            inputs, outputs = self.parse_table_schemas(
+                hook=hook,
+                inputs=parse_result.in_tables,
+                outputs=parse_result.out_tables,
+                namespace=namespace,
+                database=database,
+                database_info=database_info,
+                sqlalchemy_engine=sqlalchemy_engine,
+            )
+        else:
+            inputs, outputs = self.get_metadata_from_parser(
+                inputs=parse_result.in_tables,
+                outputs=parse_result.out_tables,
+                namespace=namespace,
+                database=database,
+                database_info=database_info,
+            )

        self.attach_column_lineage(outputs, database or database_info.database, parse_result)

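
The new `use_connection=False` path avoids any database round-trip: instead of resolving schemas through `parse_table_schemas` (which queries the database), `get_metadata_from_parser` builds `Dataset` objects purely from the parser output, and `from_table_meta` fills in missing qualifiers from the connection defaults. A hedged illustration of the naming rules, with example values assumed:

```python
# Illustration of from_table_meta naming behavior (example values assumed).
from openlineage.common.sql import DbTableMeta

# Since parse() now receives default_schema, an unqualified table such as
# "orders" typically comes back schema-qualified from the parser:
meta = DbTableMeta("PUBLIC.ORDERS")  # schema known, database still unset

# from_table_meta prepends the default database and, for Snowflake
# (is_uppercase_names=True), upper-cases the final name:
#   name = f"MYDB.{meta.schema}.{meta.name}"  ->  "MYDB.PUBLIC.ORDERS"
```
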
5 changes: 5 additions & 0 deletions airflow/providers/openlineage/utils/utils.py
@@ -384,3 +384,8 @@ def normalize_sql(sql: str | Iterable[str]):
        sql = [stmt for stmt in sql.split(";") if stmt != ""]
    sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""]
    return ";\n".join(sql)
+
+
+def should_use_external_connection(hook) -> bool:
+    # TODO: Add checking overrides
+    return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"]
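
`should_use_external_connection` gates on the hook's class name rather than a capability flag: Snowflake hooks skip the external metadata queries (their lineage can be assembled from parser output alone), while every other hook keeps the old behavior. A quick usage sketch with stand-in hook classes:

```python
# Usage sketch; the two classes below are stand-ins, not the real hooks.
from airflow.providers.openlineage.utils.utils import should_use_external_connection


class SnowflakeHook: ...  # matches by class name

class PostgresHook: ...   # any other SQL hook


assert should_use_external_connection(SnowflakeHook()) is False  # skip external queries
assert should_use_external_connection(PostgresHook()) is True    # keep querying the DB
```
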
24 changes: 6 additions & 18 deletions airflow/providers/snowflake/hooks/snowflake.py
@@ -19,6 +19,7 @@

import os
from contextlib import closing, contextmanager
+from functools import cached_property
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload
@@ -177,6 +178,7 @@ def _get_field(self, extra_dict, field_name):
            return extra_dict[field_name] or None
        return extra_dict.get(backcompat_key) or None

+    @cached_property
    def _get_conn_params(self) -> dict[str, str | None]:
        """Fetch connection params as a dict.

@@ -269,7 +271,7 @@ def _get_conn_params(self) -> dict[str, str | None]:

    def get_uri(self) -> str:
        """Override DbApiHook get_uri method for get_sqlalchemy_engine()."""
-        conn_params = self._get_conn_params()
+        conn_params = self._get_conn_params
        return self._conn_params_to_sqlalchemy_uri(conn_params)

    def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str:
@@ -283,7 +285,7 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str:

    def get_conn(self) -> SnowflakeConnection:
        """Return a snowflake.connection object."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
        conn = connector.connect(**conn_config)
        return conn

@@ -294,7 +296,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
        :return: the created engine.
        """
        engine_kwargs = engine_kwargs or {}
-        conn_params = self._get_conn_params()
+        conn_params = self._get_conn_params
        if "insecure_mode" in conn_params:
            engine_kwargs.setdefault("connect_args", {})
            engine_kwargs["connect_args"]["insecure_mode"] = True
@@ -458,21 +460,7 @@ def get_openlineage_database_dialect(self, _) -> str:
        return "snowflake"

    def get_openlineage_default_schema(self) -> str | None:
-        """
-        Attempt to get current schema.
-
-        Usually ``SELECT CURRENT_SCHEMA();`` should work.
-        However, apparently you may set ``database`` without ``schema``
-        and get results from ``SELECT CURRENT_SCHEMAS();`` but not
-        from ``SELECT CURRENT_SCHEMA();``.
-        It still may return nothing if no database is set in connection.
-        """
-        schema = self._get_conn_params()["schema"]
-        if not schema:
-            current_schemas = self.get_first("SELECT PARSE_JSON(CURRENT_SCHEMAS())[0]::string;")[0]
-            if current_schemas:
-                _, schema = current_schemas.split(".")
-        return schema
+        return self._get_conn_params["schema"]

    def _get_openlineage_authority(self, _) -> str:
        from openlineage.common.provider.snowflake import fix_snowflake_sqlalchemy_uri
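
Making `_get_conn_params` a `cached_property` changes every call site from `hook._get_conn_params()` to plain attribute access, and the connection extras are parsed once per hook instance instead of on every use. It is also what lets `get_openlineage_default_schema` collapse to a dict lookup with no fallback query. A minimal sketch of the pattern on an illustrative class:

```python
# Minimal sketch of the cached_property refactor (DemoHook is illustrative).
from functools import cached_property


class DemoHook:
    @cached_property
    def _get_conn_params(self) -> dict:
        print("parsing connection extras...")  # executed only once per instance
        return {"schema": "PUBLIC", "account": "acc"}

    def get_openlineage_default_schema(self) -> str | None:
        # No fallback query: the cached params already hold the schema.
        return self._get_conn_params["schema"]


hook = DemoHook()
hook.get_openlineage_default_schema()  # parses once, returns "PUBLIC"
hook.get_openlineage_default_schema()  # served from the cache
```
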
8 changes: 4 additions & 4 deletions airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -86,7 +86,7 @@ def __init__(
    @property
    def account_identifier(self) -> str:
        """Returns snowflake account identifier."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
        account_identifier = f"https://{conn_config['account']}"

        if conn_config["region"]:
@@ -147,7 +147,7 @@ def execute_query(
        When executing the statement, Snowflake replaces placeholders (? and :name) in
        the statement with these specified values.
        """
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params

        req_id = uuid.uuid4()
        url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements"
@@ -186,7 +186,7 @@ def execute_query(

    def get_headers(self) -> dict[str, Any]:
        """Form auth headers based on either OAuth token or JWT token from private key."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params

        # Use OAuth if refresh_token and client_id and client_secret are provided
        if all(
@@ -225,7 +225,7 @@ def get_headers(self) -> dict[str, Any]:

    def get_oauth_token(self) -> str:
        """Generate temporary OAuth access token using refresh token in connection details."""
-        conn_config = self._get_conn_params()
+        conn_config = self._get_conn_params
        url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request"
        data = {
            "grant_type": "refresh_token",
6 changes: 4 additions & 2 deletions tests/providers/amazon/aws/operators/test_redshift_sql.py
@@ -158,7 +158,8 @@ def get_db_hook(self):
            "SVV_REDSHIFT_COLUMNS.data_type, "
            "SVV_REDSHIFT_COLUMNS.database_name \n"
            "FROM SVV_REDSHIFT_COLUMNS \n"
-            "WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
+            "WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
+            "AND SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
            "OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
            "AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
            "SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')"
@@ -171,7 +172,8 @@ def get_db_hook(self):
            "SVV_REDSHIFT_COLUMNS.data_type, "
            "SVV_REDSHIFT_COLUMNS.database_name \n"
            "FROM SVV_REDSHIFT_COLUMNS \n"
-            "WHERE SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
+            "WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
+            "AND SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
        ),
    ]

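
These expected queries change because `SQLParser.parse` now forwards `default_schema` (see the sqlparser.py hunk above), so unqualified tables resolve against the connection's schema and the generated `SVV_REDSHIFT_COLUMNS` lookup gains a `schema_name` predicate. A hedged sketch of that parser call, with example values:

```python
# Sketch: parsing with a default schema (SQL, dialect and schema are example values).
from openlineage.common.sql import parse

meta = parse(
    sql="SELECT * FROM little_table",
    dialect="redshift",
    default_schema="database.public",
)
# meta.in_tables[0] now carries schema "database.public", which is why the
# expected metadata query filters on schema_name = 'database.public'.
```
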
37 changes: 14 additions & 23 deletions tests/providers/snowflake/hooks/test_snowflake.py
@@ -270,7 +270,7 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri(
    ):
        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
            assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == expected_uri
-            assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == expected_conn_params
+            assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == expected_conn_params

    def test_get_conn_params_should_support_private_auth_in_connection(
        self, encrypted_temporary_private_key: Path
@@ -288,7 +288,7 @@ def test_get_conn_params_should_support_private_auth_in_connection(
            },
        }
        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
-            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params

    @pytest.mark.parametrize("include_params", [True, False])
    def test_hook_param_beats_extra(self, include_params):
@@ -311,7 +311,7 @@ def test_hook_param_beats_extra(self, include_params):
        assert hook_params != extras
        assert SnowflakeHook(
            snowflake_conn_id="test_conn", **(hook_params if include_params else {})
-        )._get_conn_params() == {
+        )._get_conn_params == {
            "user": None,
            "password": "",
            "application": "AIRFLOW",
@@ -340,7 +340,7 @@ def test_extra_short_beats_long(self, include_unprefixed):
            ).get_uri(),
        ):
            assert list(extras.values()) != list(extras_prefixed.values())
-            assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == {
+            assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == {
                "user": None,
                "password": "",
                "application": "AIRFLOW",
@@ -366,7 +366,7 @@ def test_get_conn_params_should_support_private_auth_with_encrypted_key(
            },
        }
        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
-            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params

    def test_get_conn_params_should_support_private_auth_with_unencrypted_key(
        self, non_encrypted_temporary_private_key
@@ -384,15 +384,15 @@ def test_get_conn_params_should_support_private_auth_with_unencrypted_key(
            },
        }
        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
-            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params
        connection_kwargs["password"] = ""
        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
-            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+            assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params
        connection_kwargs["password"] = _PASSWORD
        with mock.patch.dict(
            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
        ), pytest.raises(TypeError, match="Password was given but private key is not encrypted."):
-            SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+            SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params

    def test_get_conn_params_should_fail_on_invalid_key(self):
        connection_kwargs = {
@@ -419,8 +419,7 @@ def test_should_add_partner_info(self):
            AIRFLOW_SNOWFLAKE_PARTNER="PARTNER_NAME",
        ):
            assert (
-                SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()["application"]
-                == "PARTNER_NAME"
+                SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params["application"] == "PARTNER_NAME"
            )

    def test_get_conn_should_call_connect(self):
@@ -429,7 +428,7 @@ def test_get_conn_should_call_connect(self):
        ), mock.patch("airflow.providers.snowflake.hooks.snowflake.connector") as mock_connector:
            hook = SnowflakeHook(snowflake_conn_id="test_conn")
            conn = hook.get_conn()
-            mock_connector.connect.assert_called_once_with(**hook._get_conn_params())
+            mock_connector.connect.assert_called_once_with(**hook._get_conn_params)
            assert mock_connector.connect.return_value == conn

    def test_get_sqlalchemy_engine_should_support_pass_auth(self):
@@ -516,7 +515,7 @@ def test_hook_parameters_should_take_precedence(self):
            "session_parameters": {"AA": "AAA"},
            "user": "user",
            "warehouse": "TEST_WAREHOUSE",
-        } == hook._get_conn_params()
+        } == hook._get_conn_params
        assert (
            "snowflake://user:pw@TEST_ACCOUNT.TEST_REGION/TEST_DATABASE/TEST_SCHEMA"
            "?application=AIRFLOW&authenticator=TEST_AUTH&role=TEST_ROLE&warehouse=TEST_WAREHOUSE"
@@ -587,22 +586,14 @@ def test_empty_sql_parameter(self):
            hook.run(sql=empty_statement)
        assert err.value.args[0] == "List of SQL statements is empty"

-    @pytest.mark.parametrize(
-        "returned_schema,expected_schema",
-        [([None], ""), (["DATABASE.SCHEMA"], "SCHEMA")],
-    )
-    @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
-    def test_get_openlineage_default_schema_with_no_schema_set(
-        self, mock_get_first, returned_schema, expected_schema
-    ):
+    def test_get_openlineage_default_schema_with_no_schema_set(self):
        connection_kwargs = {
            **BASE_CONNECTION_KWARGS,
-            "schema": None,
+            "schema": "PUBLIC",
        }
        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
            hook = SnowflakeHook(snowflake_conn_id="test_conn")
-            mock_get_first.return_value = returned_schema
-            assert hook.get_openlineage_default_schema() == expected_schema
+            assert hook.get_openlineage_default_schema() == "PUBLIC"

    @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
    def test_get_openlineage_default_schema_with_schema_set(self, mock_get_first):