Make RedshiftDataOperator handle multiple queries (apache#42900)
* Adding updates to RedshiftDataOperator and unit tests

* Adding updates to RedshiftDataOperator and unit tests

* Updating how statement_id is used

* Ruff formatting

* Fixing nit
jroachgolf84 authored and harjeevanmaan committed Oct 23, 2024
1 parent 7dae54a commit 204b559
Showing 3 changed files with 130 additions and 66 deletions.
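The headline change: `execute` now returns a list — statement IDs, or result sets when `return_sql_result=True` — with one entry per statement in a multi-statement SQL string. A minimal usage sketch, assuming placeholder cluster, database, and user values that are not taken from this commit:

```python
# Sketch only: cluster, database, and user values are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

with DAG(
    dag_id="redshift_multi_statement_example",
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    run_queries = RedshiftDataOperator(
        task_id="run_queries",
        aws_conn_id="aws_default",
        cluster_identifier="my-cluster",  # placeholder
        database="dev",                   # placeholder
        db_user="awsuser",                # placeholder
        # Several statements in one call; Redshift Data tracks each as a
        # sub-statement, and the operator now returns one entry per statement.
        sql="CREATE TEMP TABLE t (id INT); INSERT INTO t VALUES (1); SELECT * FROM t;",
        return_sql_result=True,  # a list of result dicts instead of statement IDs
        wait_for_completion=True,
    )
```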
providers/src/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -28,7 +28,10 @@
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from mypy_boto3_redshift_data.type_defs import GetStatementResultResponseTypeDef
from mypy_boto3_redshift_data.type_defs import (
DescribeStatementResponseTypeDef,
GetStatementResultResponseTypeDef,
)

from airflow.utils.context import Context

@@ -37,7 +40,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
"""
Executes SQL Statements against an Amazon Redshift cluster using Redshift Data.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:RedshiftDataOperator`
@@ -84,7 +87,6 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
)
template_ext = (".sql",)
template_fields_renderers = {"sql": "sql"}
statement_id: str | None

def __init__(
self,
Expand Down Expand Up @@ -124,12 +126,11 @@ def __init__(
poll_interval,
)
self.return_sql_result = return_sql_result
self.statement_id: str | None = None
self.deferrable = deferrable
self.session_id = session_id
self.session_keep_alive_seconds = session_keep_alive_seconds

def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
def execute(self, context: Context) -> list[GetStatementResultResponseTypeDef] | list[str]:
"""Execute a statement against Amazon Redshift."""
self.log.info("Executing statement: %s", self.sql)

@@ -154,13 +155,14 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
session_keep_alive_seconds=self.session_keep_alive_seconds,
)

self.statement_id = query_execution_output.statement_id
# Pull the statement ID, session ID
self.statement_id: str = query_execution_output.statement_id

if query_execution_output.session_id:
self.xcom_push(context, key="session_id", value=query_execution_output.session_id)

if self.deferrable and self.wait_for_completion:
is_finished = self.hook.check_query_is_finished(self.statement_id)
is_finished: bool = self.hook.check_query_is_finished(self.statement_id)
if not is_finished:
self.defer(
timeout=self.execution_timeout,
@@ -176,16 +178,13 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
method_name="execute_complete",
)

if self.return_sql_result:
result = self.hook.conn.get_statement_result(Id=self.statement_id)
self.log.debug("Statement result: %s", result)
return result
else:
return self.statement_id
# Use the get_sql_results method to return the results of the SQL query, or the statement_ids,
# depending on the value of self.return_sql_result
return self.get_sql_results(statement_id=self.statement_id, return_sql_result=self.return_sql_result)

def execute_complete(
self, context: Context, event: dict[str, Any] | None = None
) -> GetStatementResultResponseTypeDef | str:
) -> list[GetStatementResultResponseTypeDef] | list[str]:
event = validate_execute_complete_event(event)

if event["status"] == "error":
@@ -197,16 +196,40 @@ def execute_complete(
raise AirflowException("statement_id should not be empty.")

self.log.info("%s completed successfully.", self.task_id)
if self.return_sql_result:
result = self.hook.conn.get_statement_result(Id=statement_id)
self.log.debug("Statement result: %s", result)
return result

return statement_id
# Use the get_sql_results method to return the results of the SQL query, or the statement_ids,
# depending on the value of self.return_sql_result
return self.get_sql_results(statement_id=statement_id, return_sql_result=self.return_sql_result)

def get_sql_results(
self, statement_id: str, return_sql_result: bool
) -> list[GetStatementResultResponseTypeDef] | list[str]:
"""
Retrieve either the result of the SQL query, or the statement ID(s).
:param statement_id: Statement ID of the running queries
:param return_sql_result: Boolean, true if results should be returned
"""
# ISSUE-40427: Pull the statement, and check to see if there are sub-statements. If that is the
# case, pull each of the sub-statement ID's, and grab the results. Otherwise, just use statement_id
statement: DescribeStatementResponseTypeDef = self.hook.conn.describe_statement(Id=statement_id)
statement_ids: list[str] = (
[sub_statement["Id"] for sub_statement in statement["SubStatements"]]
if len(statement.get("SubStatements", [])) > 0
else [statement_id]
)

# If returning the SQL result, use get_statement_result to return the records for each query
if return_sql_result:
results: list = [self.hook.conn.get_statement_result(Id=sid) for sid in statement_ids]
self.log.debug("Statement result(s): %s", results)
return results
else:
return statement_ids

def on_kill(self) -> None:
"""Cancel the submitted redshift query."""
if self.statement_id:
if hasattr(self, "statement_id"):
self.log.info("Received a kill signal.")
self.log.info("Stopping Query with statementId - %s", self.statement_id)

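For context, the new `get_sql_results` leans on the Redshift Data API's notion of sub-statements. A sketch of the boto3 calls involved — the calls are real API methods, but the IDs and response contents shown are illustrative, not from the commit:

```python
# Illustrative shape of the boto3 redshift-data responses used by
# get_sql_results; IDs and field values are made up.
import boto3

client = boto3.client("redshift-data")

# A multi-statement execution reports one SubStatements entry per query:
# {"Id": "stmt-1", "Status": "FINISHED",
#  "SubStatements": [{"Id": "stmt-1:1", ...}, {"Id": "stmt-1:2", ...}]}
statement = client.describe_statement(Id="stmt-1")

# Mirror of the operator's logic: prefer sub-statement IDs when present,
# otherwise fall back to the top-level statement ID.
sub_ids = [s["Id"] for s in statement.get("SubStatements", [])] or [statement["Id"]]
results = [client.get_statement_result(Id=sid) for sid in sub_ids]
```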
95 changes: 66 additions & 29 deletions providers/tests/amazon/aws/operators/test_redshift_data.py
@@ -88,16 +88,18 @@ def test_init(self):
assert op.hook._config is None

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute(self, mock_exec_query):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute(self, mock_conn, mock_exec_query):
cluster_identifier = "cluster_identifier"
workgroup_name = None
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
poll_interval = 5
wait_for_completion = True

mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=None)

operator = RedshiftDataOperator(
@@ -113,8 +115,11 @@ def test_execute(self, mock_exec_query):
wait_for_completion=True,
poll_interval=poll_interval,
)

# Mock the TaskInstance, call the execute method
mock_ti = mock.MagicMock(name="MockedTaskInstance")
operator.execute({"ti": mock_ti})
actual_result = operator.execute({"ti": mock_ti})

mock_exec_query.assert_called_once_with(
sql=SQL,
database=DATABASE,
@@ -125,24 +130,31 @@ def test_execute(self, mock_exec_query):
statement_name=statement_name,
parameters=parameters,
with_event=False,
wait_for_completion=wait_for_completion,
wait_for_completion=True, # Matches above
poll_interval=poll_interval,
session_id=None,
session_keep_alive_seconds=None,
)

# Check that the result returned is a list of the statement_id's
assert actual_result == [STATEMENT_ID]
mock_ti.xcom_push.assert_not_called()

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute_with_workgroup_name(self, mock_exec_query):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_with_workgroup_name(self, mock_conn, mock_exec_query):
cluster_identifier = None
workgroup_name = "workgroup_name"
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
poll_interval = 5
wait_for_completion = True

# Like before, return a statement ID and a status
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=None)

operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
@@ -157,26 +169,32 @@ def test_execute_with_workgroup_name(self, mock_exec_query):
wait_for_completion=True,
poll_interval=poll_interval,
)

# Mock the TaskInstance, call the execute method
mock_ti = mock.MagicMock(name="MockedTaskInstance")
operator.execute({"ti": mock_ti})
actual_result = operator.execute({"ti": mock_ti})

# Assertions
assert actual_result == [STATEMENT_ID]
mock_exec_query.assert_called_once_with(
sql=SQL,
database=DATABASE,
cluster_identifier=cluster_identifier,
workgroup_name=workgroup_name,
workgroup_name=workgroup_name, # Called with workgroup_name
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
parameters=parameters,
with_event=False,
wait_for_completion=wait_for_completion,
wait_for_completion=True,
poll_interval=poll_interval,
session_id=None,
session_keep_alive_seconds=None,
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute_new_session(self, mock_exec_query):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_new_session(self, mock_conn, mock_exec_query):
cluster_identifier = "cluster_identifier"
workgroup_name = None
db_user = "db_user"
@@ -186,6 +204,9 @@ def test_execute_new_session(self, mock_exec_query):
poll_interval = 5
wait_for_completion = True

# Like before, return a statement ID and a status
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=SESSION_ID)

operator = RedshiftDataOperator(
@@ -203,8 +224,10 @@ def test_execute_new_session(self, mock_exec_query):
session_keep_alive_seconds=123,
)

# Mock the TaskInstance and call the execute method
mock_ti = mock.MagicMock(name="MockedTaskInstance")
operator.execute({"ti": mock_ti})

mock_exec_query.assert_called_once_with(
sql=SQL,
database=DATABASE,
@@ -256,14 +279,17 @@ def test_on_kill_with_query(self, mock_conn):

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_return_sql_result(self, mock_conn):
expected_result = {"Result": True}
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": SESSION_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_conn.get_statement_result.return_value = expected_result
expected_result = [{"Result": True}]
cluster_identifier = "cluster_identifier"
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"

# Mock the conn object
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": SESSION_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_conn.get_statement_result.return_value = {"Result": True}

operator = RedshiftDataOperator(
task_id=TASK_ID,
cluster_identifier=cluster_identifier,
@@ -275,8 +301,11 @@ def test_return_sql_result(self, mock_conn):
aws_conn_id=CONN_ID,
return_sql_result=True,
)

# Mock the TaskInstance, run the execute method
mock_ti = mock.MagicMock(name="MockedTaskInstance")
actual_result = operator.execute({"ti": mock_ti})

assert actual_result == expected_result
mock_conn.execute_statement.assert_called_once_with(
Database=DATABASE,
@@ -287,17 +316,18 @@ def test_return_sql_result(self, mock_conn):
StatementName=statement_name,
WithEvent=False,
)
mock_conn.get_statement_result.assert_called_once_with(
Id=STATEMENT_ID,
)
mock_conn.get_statement_result.assert_called_once_with(Id=STATEMENT_ID)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer")
@mock.patch(
"airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished",
return_value=True,
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_finished, mock_defer):
def test_execute_finished_before_defer(
self, mock_exec_query, check_query_is_finished, mock_defer, mock_conn
):
cluster_identifier = "cluster_identifier"
workgroup_name = None
db_user = "db_user"
@@ -366,25 +396,26 @@ def test_execute_complete_exception(self, deferrable_operator):
deferrable_operator.execute_complete(context=None, event=None)
assert exc.value.args[0] == "Trigger error: event is None"

def test_execute_complete(self, deferrable_operator):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_complete(self, mock_conn, deferrable_operator):
"""Asserts that logging occurs as expected"""

deferrable_operator.statement_id = "uuid"

with mock.patch.object(deferrable_operator.log, "info") as mock_log_info:
assert (
deferrable_operator.execute_complete(
context=None,
event={"status": "success", "message": "Job completed", "statement_id": "uuid"},
)
== "uuid"
)
assert deferrable_operator.execute_complete(
context=None,
event={"status": "success", "message": "Job completed", "statement_id": "uuid"},
) == ["uuid"]
mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)

@mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finished, mock_defer):
def test_no_wait_for_completion(
self, mock_exec_query, mock_conn, mock_check_query_is_finished, mock_defer
):
"""Tests that the operator does not check for completion nor defers when wait_for_completion is False,
no matter the value of deferrable"""
cluster_identifier = "cluster_identifier"
@@ -393,9 +424,12 @@ def test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finished, mock_defer):
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
poll_interval = 5

wait_for_completion = False

# Mock the describe_statement call
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}
mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=SESSION_ID)

for deferrable in [True, False]:
operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
@@ -411,11 +445,14 @@ def test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finished, mock_defer):
poll_interval=poll_interval,
deferrable=deferrable,
)

# Mock the TaskInstance, call the execute method
mock_ti = mock.MagicMock(name="MockedTaskInstance")
operator.execute({"ti": mock_ti})
actual_results = operator.execute({"ti": mock_ti})

assert not mock_check_query_is_finished.called
assert not mock_defer.called
assert actual_results == [STATEMENT_ID]

def test_template_fields(self):
operator = RedshiftDataOperator(
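A minimal pytest sketch — not part of the commit — exercising the new `get_sql_results` directly for the sub-statement path; the mock target mirrors the tests above, and the statement IDs are made up:

```python
from unittest import mock

from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator


@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_get_sql_results_sub_statements(mock_conn):
    # Two sub-statements -> two statement IDs returned.
    mock_conn.describe_statement.return_value = {
        "Status": "FINISHED",
        "SubStatements": [{"Id": "stmt:1"}, {"Id": "stmt:2"}],
    }
    op = RedshiftDataOperator(task_id="t", database="dev", sql="SELECT 1; SELECT 2;")
    assert op.get_sql_results(statement_id="stmt", return_sql_result=False) == [
        "stmt:1",
        "stmt:2",
    ]
```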