Skip to content

Commit

Permalink
Deprecate SnowflakeSqlApiOperatorAsync (#1447)
Browse files Browse the repository at this point in the history
Deprecate SnowflakeSqlApiOperatorAsync and fallback to its Airflow OSS provider's counterpart SnowflakeSqlApiOperator with deferrable=True

closes: #1415
---------

Co-authored-by: Wei Lee <[email protected]>
Co-authored-by: Pankaj Koti <[email protected]>
Co-authored-by: Pankaj Koti <[email protected]>
  • Loading branch information
4 people authored Jan 29, 2024
1 parent 84cfc6b commit 6167579
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 353 deletions.
12 changes: 12 additions & 0 deletions astronomer/providers/snowflake/hooks/snowflake_sql_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import uuid
import warnings
from datetime import timedelta
from pathlib import Path
from typing import Any
Expand All @@ -17,6 +18,9 @@

class SnowflakeSqlApiHookAsync(SnowflakeHook):
"""
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook` instead.
A client to interact with Snowflake using SQL API and allows submitting
multiple SQL statements in a single request. In combination with aiohttp, make post request to submit SQL
statements for execution, poll to check the status of the execution of a statement. Fetch query results
Expand Down Expand Up @@ -58,6 +62,14 @@ def __init__(
*args: Any,
**kwargs: Any,
):
warnings.warn(
(
"This class is deprecated and will be removed in 2.0.0."
"Use `airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook` instead"
),
DeprecationWarning,
stacklevel=2,
)
self.snowflake_conn_id = snowflake_conn_id
self.token_life_time = token_life_time
self.token_renewal_delta = token_renewal_delta
Expand Down
213 changes: 18 additions & 195 deletions astronomer/providers/snowflake/operators/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,24 @@

import logging
import typing
import warnings
from contextlib import closing
from datetime import timedelta
from typing import Any, Callable, List

import requests
from airflow.exceptions import AirflowException

from snowflake.connector import SnowflakeConnection
from snowflake.connector.constants import QueryStatus

try:
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
except ImportError: # pragma: no cover
# For apache-airflow-providers-snowflake > 3.3.0
# currently added type: ignore[no-redef, attr-defined] and pragma: no cover because this import
# path won't be available in current setup
from airflow.providers.common.sql.operators.sql import ( # type: ignore[assignment]
SQLExecuteQueryOperator as SnowflakeOperator,
)
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator, SnowflakeSqlApiOperator

from astronomer.providers.snowflake.hooks.snowflake import (
SnowflakeHookAsync,
fetch_all_snowflake_handler,
)
from astronomer.providers.snowflake.hooks.snowflake_sql_api import (
SnowflakeSqlApiHookAsync,
)
from astronomer.providers.snowflake.triggers.snowflake_trigger import (
SnowflakeSqlApiTrigger,
SnowflakeTrigger,
get_db_hook,
)
from astronomer.providers.utils.typing_compat import Context
from snowflake.connector import SnowflakeConnection
from snowflake.connector.constants import QueryStatus


def _check_queries_finish(conn: SnowflakeConnection, query_ids: list[str]) -> bool:
Expand Down Expand Up @@ -224,183 +209,21 @@ def execute_complete(self, context: Context, event: dict[str, str | list[str]] |
raise AirflowException("Did not receive valid event from the trigerrer")


class SnowflakeSqlApiOperatorAsync(SnowflakeOperator):
class SnowflakeSqlApiOperatorAsync(SnowflakeSqlApiOperator):
"""
This class is deprecated.
Use :class: `~airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator` instead
and set `deferrable` param to `True` instead.
"""
Implemented Async Snowflake SQL API Operator to support multiple SQL statements sequentially,
which is the behavior of the SnowflakeOperator, the Snowflake SQL API allows submitting
multiple SQL statements in a single request. In combination with aiohttp, make post request to submit SQL
statements for execution, poll to check the status of the execution of a statement. Fetch query results
concurrently.
This Operator currently uses key pair authentication, so you need tp provide private key raw content or
private key file path in the snowflake connection along with other details
.. seealso::
`Snowflake SQL API key pair Authentication <https://docs.snowflake.com/en/developer-guide/sql-api/authenticating.html#label-sql-api-authenticating-key-pair>`_
Where can this operator fit in?
- To execute multiple SQL statements in a single request
- To execute the SQL statement asynchronously and to execute standard queries and most DDL and DML statements
- To develop custom applications and integrations that perform queries
- To create provision users and roles, create table, etc.
The following commands are not supported:
- The PUT command (in Snowflake SQL)
- The GET command (in Snowflake SQL)
- The CALL command with stored procedures that return a table(stored procedures with the RETURNS TABLE clause).
.. seealso::
- `Snowflake SQL API <https://docs.snowflake.com/en/developer-guide/sql-api/intro.html#introduction-to-the-sql-api>`_
- `API Reference <https://docs.snowflake.com/en/developer-guide/sql-api/reference.html#snowflake-sql-api-reference>`_
- `Limitation on snowflake SQL API <https://docs.snowflake.com/en/developer-guide/sql-api/intro.html#limitations-of-the-sql-api>`_
:param snowflake_conn_id: Reference to Snowflake connection id
:param sql: the sql code to be executed. (templated)
:param autocommit: if True, each command is automatically committed.
(default value: True)
:param parameters: (optional) the parameters to render the SQL query with.
:param warehouse: name of warehouse (will overwrite any warehouse
defined in the connection's extra JSON)
:param database: name of database (will overwrite database defined
in connection)
:param schema: name of schema (will overwrite schema defined in
connection)
:param role: name of role (will overwrite any role defined in
connection's extra JSON)
:param authenticator: authenticator for Snowflake.
'snowflake' (default) to use the internal Snowflake authenticator
'externalbrowser' to authenticate using your web browser and
Okta, ADFS or any other SAML 2.0-compliant identify provider
(IdP) that has been defined for your account
'https://<your_okta_account_name>.okta.com' to authenticate
through native Okta.
:param session_parameters: You can set session-level parameters at
the time you connect to Snowflake
:param poll_interval: the interval in seconds to poll the query
:param statement_count: Number of SQL statement to be executed
:param token_life_time: lifetime of the JWT Token
:param token_renewal_delta: Renewal time of the JWT Token
:param bindings: (Optional) Values of bind variables in the SQL statement.
When executing the statement, Snowflake replaces placeholders (? and :name) in
the statement with these specified values.
""" # noqa

LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes lifetime
RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54 minutes

def __init__(
self,
*,
snowflake_conn_id: str = "snowflake_default",
warehouse: str | None = None,
database: str | None = None,
role: str | None = None,
schema: str | None = None,
authenticator: str | None = None,
session_parameters: dict[str, Any] | None = None,
poll_interval: int = 5,
statement_count: int = 0,
token_life_time: timedelta = LIFETIME,
token_renewal_delta: timedelta = RENEWAL_DELTA,
bindings: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.warehouse = warehouse
self.database = database
self.role = role
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.snowflake_conn_id = snowflake_conn_id
self.poll_interval = poll_interval
self.statement_count = statement_count
self.token_life_time = token_life_time
self.token_renewal_delta = token_renewal_delta
self.bindings = bindings
self.execute_async = False
if self.__class__.__base__.__name__ != "SnowflakeOperator": # type: ignore[union-attr]
# It's better to do str check of the parent class name because currently SnowflakeOperator
# is deprecated and in future OSS SnowflakeOperator may be removed
if any(
[warehouse, database, role, schema, authenticator, session_parameters]
): # pragma: no cover
hook_params = kwargs.pop("hook_params", {}) # pragma: no cover
kwargs["hook_params"] = {
"warehouse": warehouse,
"database": database,
"role": role,
"schema": schema,
"authenticator": authenticator,
"session_parameters": session_parameters,
**hook_params,
}
super().__init__(conn_id=snowflake_conn_id, **kwargs) # pragma: no cover
else:
super().__init__(**kwargs)

def execute(self, context: Context) -> None:
"""
Make a POST API request to snowflake by using SnowflakeSQL and execute the query to get the ids.
By deferring the SnowflakeSqlApiTrigger class passed along with query ids.
"""
self.log.info("Executing: %s", self.sql)
hook = SnowflakeSqlApiHookAsync(
snowflake_conn_id=self.snowflake_conn_id,
token_life_time=self.token_life_time,
token_renewal_delta=self.token_renewal_delta,
)
hook.execute_query(self.sql, statement_count=self.statement_count, bindings=self.bindings)
self.query_ids = hook.query_ids
self.log.info("List of query ids %s", self.query_ids)

if self.do_xcom_push:
context["ti"].xcom_push(key="query_ids", value=self.query_ids)

succeeded_query_ids = []
for query_id in self.query_ids:
self.log.info("Retrieving status for query id %s", query_id)
header, params, url = hook.get_request_url_header_params(query_id)
with requests.session() as session:
session.headers = header
with session.get(url, params=params) as resp:
event = hook.process_query_status_response(resp.json(), resp.status_code)
if resp.status_code == 202:
break
elif resp.status_code == 200:
succeeded_query_ids.append(query_id)
else:
raise AirflowException(f"{event['status']}: {event['message']}")

if len(self.query_ids) == len(succeeded_query_ids):
self.log.info("%s completed successfully.", self.task_id)
return

self.defer(
timeout=self.execution_timeout,
trigger=SnowflakeSqlApiTrigger(
poll_interval=self.poll_interval,
query_ids=self.query_ids,
snowflake_conn_id=self.snowflake_conn_id,
token_life_time=self.token_life_time,
token_renewal_delta=self.token_renewal_delta,
def __init__(self, *args: Any, **kwargs: Any):
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator` "
"and set `deferrable` param to `True` instead."
),
method_name="execute_complete",
DeprecationWarning,
stacklevel=2,
)

def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
raise AirflowException(f"{event['status']}: {event['message']}")
elif "status" in event and event["status"] == "success":
hook = SnowflakeSqlApiHookAsync(snowflake_conn_id=self.snowflake_conn_id)
query_ids = typing.cast(List[str], event["statement_query_ids"])
hook.check_query_output(query_ids)
self.log.info("%s completed successfully.", self.task_id)
else:
self.log.info("%s completed successfully.", self.task_id)
super().__init__(*args, deferrable=True, **kwargs)
19 changes: 11 additions & 8 deletions astronomer/providers/snowflake/triggers/snowflake_trigger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import warnings
from datetime import timedelta
from typing import Any, AsyncIterator

Expand Down Expand Up @@ -81,14 +82,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

class SnowflakeSqlApiTrigger(BaseTrigger):
"""
SnowflakeSqlApi Trigger inherits from the BaseTrigger,it is fired as
deferred class with params to run the task in trigger worker and
fetch the status for the query ids passed
:param task_id: Reference to task id of the Dag
:param poll_interval: polling period in seconds to check for the status
:param query_ids: List of Query ids to run and poll for the status
:param snowflake_conn_id: Reference to Snowflake connection id
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger` instead.
"""

def __init__(
Expand All @@ -99,6 +94,14 @@ def __init__(
token_life_time: timedelta,
token_renewal_delta: timedelta,
):
warnings.warn(
(
"This class is deprecated and will be removed in 2.0.0."
"Use `airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger` instead"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.poll_interval = poll_interval
self.query_ids = query_ids
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ sftp =
apache-airflow-providers-sftp
asyncssh>=2.12.0
snowflake =
apache-airflow-providers-snowflake
apache-airflow-providers-snowflake>=5.3.0
snowflake-sqlalchemy>=1.4.4 # Temporary solution for https://github.com/astronomer/astronomer-providers/issues/958, we should pin apache-airflow-providers-snowflake version after it pins this package to great than or equal to 1.4.4.
# If in future we move Openlineage extractors out of the repo, this dependency should be removed
openlineage =
Expand Down Expand Up @@ -125,7 +125,7 @@ all =
apache-airflow-providers-databricks>=6.1.0
apache-airflow-providers-google>=10.14.0
apache-airflow-providers-http
apache-airflow-providers-snowflake
apache-airflow-providers-snowflake>=5.3.0
apache-airflow-providers-sftp
apache-airflow-providers-microsoft-azure>=8.5.1
asyncssh>=2.12.0
Expand Down
Loading

0 comments on commit 6167579

Please sign in to comment.