Make RedshiftDataOperator handle multiple queries #42900

Merged: 5 commits, Oct 15, 2024
@@ -28,7 +28,10 @@
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from mypy_boto3_redshift_data.type_defs import GetStatementResultResponseTypeDef
from mypy_boto3_redshift_data.type_defs import (
DescribeStatementResponseTypeDef,
GetStatementResultResponseTypeDef,
)

from airflow.utils.context import Context

@@ -37,7 +40,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
"""
Executes SQL Statements against an Amazon Redshift cluster using Redshift Data.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:RedshiftDataOperator`

@@ -84,7 +87,6 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
)
template_ext = (".sql",)
template_fields_renderers = {"sql": "sql"}
statement_id: str | None

def __init__(
self,
@@ -124,12 +126,11 @@ def __init__(
poll_interval,
)
self.return_sql_result = return_sql_result
self.statement_id: str | None = None
self.deferrable = deferrable
self.session_id = session_id
self.session_keep_alive_seconds = session_keep_alive_seconds

def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
def execute(self, context: Context) -> list[GetStatementResultResponseTypeDef] | list[str]:
"""Execute a statement against Amazon Redshift."""
self.log.info("Executing statement: %s", self.sql)

@@ -154,13 +155,14 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
session_keep_alive_seconds=self.session_keep_alive_seconds,
)

self.statement_id = query_execution_output.statement_id
# Pull the statement ID, session ID
self.statement_id: str = query_execution_output.statement_id

if query_execution_output.session_id:
self.xcom_push(context, key="session_id", value=query_execution_output.session_id)

if self.deferrable and self.wait_for_completion:
is_finished = self.hook.check_query_is_finished(self.statement_id)
is_finished: bool = self.hook.check_query_is_finished(self.statement_id)
if not is_finished:
self.defer(
timeout=self.execution_timeout,
@@ -176,16 +178,13 @@
method_name="execute_complete",
)

if self.return_sql_result:
result = self.hook.conn.get_statement_result(Id=self.statement_id)
self.log.debug("Statement result: %s", result)
return result
else:
return self.statement_id
# Use the get_sql_results method to return the results of the SQL query, or the statement_ids,
# depending on the value of self.return_sql_result
return self.get_sql_results(statement_id=self.statement_id, return_sql_result=self.return_sql_result)

def execute_complete(
self, context: Context, event: dict[str, Any] | None = None
) -> GetStatementResultResponseTypeDef | str:
) -> list[GetStatementResultResponseTypeDef] | list[str]:
event = validate_execute_complete_event(event)

if event["status"] == "error":
@@ -197,16 +196,40 @@ def execute_complete(
raise AirflowException("statement_id should not be empty.")

self.log.info("%s completed successfully.", self.task_id)
if self.return_sql_result:
result = self.hook.conn.get_statement_result(Id=statement_id)
self.log.debug("Statement result: %s", result)
return result

return statement_id
# Use the get_sql_results method to return the results of the SQL query, or the statement_ids,
# depending on the value of self.return_sql_result
return self.get_sql_results(statement_id=statement_id, return_sql_result=self.return_sql_result)

def get_sql_results(
self, statement_id: str, return_sql_result: bool
) -> list[GetStatementResultResponseTypeDef] | list[str]:
"""
Retrieve either the result of the SQL query, or the statement ID(s).

:param statement_id: Statement ID of the running queries
:param return_sql_result: Boolean, true if results should be returned
"""
# ISSUE-40427: Pull the statement, and check to see if there are sub-statements. If that is the
# case, pull each of the sub-statement IDs, and grab the results. Otherwise, just use statement_id
statement: DescribeStatementResponseTypeDef = self.hook.conn.describe_statement(Id=statement_id)
statement_ids: list[str] = (
[sub_statement["Id"] for sub_statement in statement["SubStatements"]]
if len(statement.get("SubStatements", [])) > 0
else [statement_id]
)

# If returning the SQL result, use get_statement_result to return the records for each query
if return_sql_result:
results: list = [self.hook.conn.get_statement_result(Id=sid) for sid in statement_ids]
self.log.debug("Statement result(s): %s", results)
return results
else:
return statement_ids

def on_kill(self) -> None:
"""Cancel the submitted redshift query."""
if self.statement_id:
if hasattr(self, "statement_id"):
self.log.info("Received a kill signal.")
self.log.info("Stopping Query with statementId - %s", self.statement_id)

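The new `get_sql_results` helper keys off the Redshift Data API's `describe_statement` response, which reports a `SubStatements` list when a single submission contained several SQL statements. Below is a minimal usage sketch of the behavior this PR enables; the task ID, connection ID, cluster name, and SQL script are illustrative assumptions, not part of this PR:

# Hypothetical sketch (not from this PR's diff): after this change, a single
# RedshiftDataOperator run of a multi-statement script returns a list of
# statement IDs via XCom, or a list of result sets when return_sql_result=True.
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

run_batch = RedshiftDataOperator(
    task_id="run_batch",              # illustrative
    aws_conn_id="aws_default",        # illustrative
    cluster_identifier="my-cluster",  # illustrative
    database="dev",
    sql="""
        CREATE TABLE IF NOT EXISTS fruit (name VARCHAR);
        INSERT INTO fruit VALUES ('apple');
        SELECT count(*) FROM fruit;
    """,
    wait_for_completion=True,
    return_sql_result=False,  # XCom value is now a list of statement IDs
)

For a batch like this, `describe_statement` returns one entry per statement under `SubStatements`, each with its own `Id`, and `get_sql_results` fans out over those IDs (calling `get_statement_result` per sub-statement when results are requested); a single-statement run still resolves to the top-level statement ID.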
95 changes: 66 additions & 29 deletions providers/tests/amazon/aws/operators/test_redshift_data.py
@@ -88,16 +88,18 @@ def test_init(self):
assert op.hook._config is None

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute(self, mock_exec_query):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute(self, mock_conn, mock_exec_query):
cluster_identifier = "cluster_identifier"
workgroup_name = None
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
poll_interval = 5
wait_for_completion = True

mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=None)

operator = RedshiftDataOperator(
@@ -113,8 +115,11 @@ def test_execute(self, mock_exec_query):
wait_for_completion=True,
poll_interval=poll_interval,
)

# Mock the TaskInstance, call the execute method
mock_ti = mock.MagicMock(name="MockedTaskInstance")
operator.execute({"ti": mock_ti})
actual_result = operator.execute({"ti": mock_ti})

mock_exec_query.assert_called_once_with(
sql=SQL,
database=DATABASE,
@@ -125,24 +130,31 @@
statement_name=statement_name,
parameters=parameters,
with_event=False,
wait_for_completion=wait_for_completion,
wait_for_completion=True, # Matches above
poll_interval=poll_interval,
session_id=None,
session_keep_alive_seconds=None,
)

# Check that the result returned is a list of the statement_id's
assert actual_result == [STATEMENT_ID]
mock_ti.xcom_push.assert_not_called()

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute_with_workgroup_name(self, mock_exec_query):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_with_workgroup_name(self, mock_conn, mock_exec_query):
cluster_identifier = None
workgroup_name = "workgroup_name"
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
poll_interval = 5
wait_for_completion = True

# Like before, return a statement ID and a status
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=None)

operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
@@ -157,26 +169,32 @@ def test_execute_with_workgroup_name(self, mock_exec_query):
wait_for_completion=True,
poll_interval=poll_interval,
)

# Mock the TaskInstance, call the execute method
mock_ti = mock.MagicMock(name="MockedTaskInstance")
operator.execute({"ti": mock_ti})
actual_result = operator.execute({"ti": mock_ti})

# Assertions
assert actual_result == [STATEMENT_ID]
mock_exec_query.assert_called_once_with(
sql=SQL,
database=DATABASE,
cluster_identifier=cluster_identifier,
workgroup_name=workgroup_name,
workgroup_name=workgroup_name, # Called with workgroup_name
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
parameters=parameters,
with_event=False,
wait_for_completion=wait_for_completion,
wait_for_completion=True,
poll_interval=poll_interval,
session_id=None,
session_keep_alive_seconds=None,
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute_new_session(self, mock_exec_query):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_new_session(self, mock_conn, mock_exec_query):
cluster_identifier = "cluster_identifier"
workgroup_name = None
db_user = "db_user"
@@ -186,6 +204,9 @@ def test_execute_new_session(self, mock_exec_query):
poll_interval = 5
wait_for_completion = True

# Like before, return a statement ID and a status
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=SESSION_ID)

operator = RedshiftDataOperator(
@@ -203,8 +224,10 @@
session_keep_alive_seconds=123,
)

# Mock the TaskInstance and call the execute method
mock_ti = mock.MagicMock(name="MockedTaskInstance")
operator.execute({"ti": mock_ti})

mock_exec_query.assert_called_once_with(
sql=SQL,
database=DATABASE,
@@ -256,14 +279,17 @@ def test_on_kill_with_query(self, mock_conn):

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_return_sql_result(self, mock_conn):
expected_result = {"Result": True}
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": SESSION_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_conn.get_statement_result.return_value = expected_result
expected_result = [{"Result": True}]
cluster_identifier = "cluster_identifier"
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"

# Mock the conn object
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": SESSION_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_conn.get_statement_result.return_value = {"Result": True}

operator = RedshiftDataOperator(
task_id=TASK_ID,
cluster_identifier=cluster_identifier,
@@ -275,8 +301,11 @@
aws_conn_id=CONN_ID,
return_sql_result=True,
)

# Mock the TaskInstance, run the execute method
mock_ti = mock.MagicMock(name="MockedTaskInstance")
actual_result = operator.execute({"ti": mock_ti})

assert actual_result == expected_result
mock_conn.execute_statement.assert_called_once_with(
Database=DATABASE,
@@ -287,17 +316,18 @@
StatementName=statement_name,
WithEvent=False,
)
mock_conn.get_statement_result.assert_called_once_with(
Id=STATEMENT_ID,
)
mock_conn.get_statement_result.assert_called_once_with(Id=STATEMENT_ID)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer")
@mock.patch(
"airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished",
return_value=True,
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_finished, mock_defer):
def test_execute_finished_before_defer(
self, mock_exec_query, check_query_is_finished, mock_defer, mock_conn
):
cluster_identifier = "cluster_identifier"
workgroup_name = None
db_user = "db_user"
@@ -366,25 +396,26 @@ def test_execute_complete_exception(self, deferrable_operator):
deferrable_operator.execute_complete(context=None, event=None)
assert exc.value.args[0] == "Trigger error: event is None"

def test_execute_complete(self, deferrable_operator):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_complete(self, mock_conn, deferrable_operator):
"""Asserts that logging occurs as expected"""

deferrable_operator.statement_id = "uuid"

with mock.patch.object(deferrable_operator.log, "info") as mock_log_info:
assert (
deferrable_operator.execute_complete(
context=None,
event={"status": "success", "message": "Job completed", "statement_id": "uuid"},
)
== "uuid"
)
assert deferrable_operator.execute_complete(
context=None,
event={"status": "success", "message": "Job completed", "statement_id": "uuid"},
) == ["uuid"]
mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)

@mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finished, mock_defer):
def test_no_wait_for_completion(
self, mock_exec_query, mock_conn, mock_check_query_is_finished, mock_defer
):
"""Tests that the operator does not check for completion nor defers when wait_for_completion is False,
no matter the value of deferrable"""
cluster_identifier = "cluster_identifier"
@@ -393,9 +424,12 @@ def test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finis
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
poll_interval = 5

wait_for_completion = False

# Mock the describe_statement call
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=SESSION_ID)

for deferrable in [True, False]:
operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
@@ -411,11 +445,14 @@
poll_interval=poll_interval,
deferrable=deferrable,
)

# Mock the TaskInstance, call the execute method
mock_ti = mock.MagicMock(name="MockedTaskInstance")
operator.execute({"ti": mock_ti})
actual_results = operator.execute({"ti": mock_ti})

assert not mock_check_query_is_finished.called
assert not mock_defer.called
assert actual_results == [STATEMENT_ID]

def test_template_fields(self):
operator = RedshiftDataOperator(
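The session tests above (`test_execute_new_session`) mirror the `session_id` XCom push in `execute`. Here is a short sketch of reusing one Redshift Data session across tasks, modeled on the provider's session-reuse pattern; the task IDs, keep-alive value, and SQL are illustrative assumptions, not from this PR:

# Hypothetical sketch: the first task keeps its session alive, so execute()
# pushes the new session ID to XCom under key "session_id" (see the operator
# diff above); the second task pulls that ID and runs on the same session,
# which is what makes the temporary table visible to it.
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

create_tmp = RedshiftDataOperator(
    task_id="create_tmp",
    cluster_identifier="my-cluster",  # illustrative
    database="dev",
    sql="CREATE TEMPORARY TABLE tmp_fruit (name VARCHAR);",
    session_keep_alive_seconds=600,   # illustrative keep-alive window
)

read_tmp = RedshiftDataOperator(
    task_id="read_tmp",
    database="dev",
    sql="SELECT * FROM tmp_fruit;",
    # Pull the session ID pushed by the upstream task.
    session_id="{{ task_instance.xcom_pull(task_ids='create_tmp', key='session_id') }}",
)

create_tmp >> read_tmp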