diff --git a/providers/src/airflow/providers/amazon/aws/operators/redshift_data.py b/providers/src/airflow/providers/amazon/aws/operators/redshift_data.py index 3d00c6d22edf7..0813918854464 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/redshift_data.py +++ b/providers/src/airflow/providers/amazon/aws/operators/redshift_data.py @@ -28,7 +28,10 @@ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: - from mypy_boto3_redshift_data.type_defs import GetStatementResultResponseTypeDef + from mypy_boto3_redshift_data.type_defs import ( + DescribeStatementResponseTypeDef, + GetStatementResultResponseTypeDef, + ) from airflow.utils.context import Context @@ -37,7 +40,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]): """ Executes SQL Statements against an Amazon Redshift cluster using Redshift Data. - .. seealso:: + .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:RedshiftDataOperator` @@ -84,7 +87,6 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]): ) template_ext = (".sql",) template_fields_renderers = {"sql": "sql"} - statement_id: str | None def __init__( self, @@ -124,12 +126,11 @@ def __init__( poll_interval, ) self.return_sql_result = return_sql_result - self.statement_id: str | None = None self.deferrable = deferrable self.session_id = session_id self.session_keep_alive_seconds = session_keep_alive_seconds - def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str: + def execute(self, context: Context) -> list[GetStatementResultResponseTypeDef] | list[str]: """Execute a statement against Amazon Redshift.""" self.log.info("Executing statement: %s", self.sql) @@ -154,13 +155,14 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str: session_keep_alive_seconds=self.session_keep_alive_seconds, ) - self.statement_id = query_execution_output.statement_id + # Pull the 
statement ID, session ID + self.statement_id: str = query_execution_output.statement_id if query_execution_output.session_id: self.xcom_push(context, key="session_id", value=query_execution_output.session_id) if self.deferrable and self.wait_for_completion: - is_finished = self.hook.check_query_is_finished(self.statement_id) + is_finished: bool = self.hook.check_query_is_finished(self.statement_id) if not is_finished: self.defer( timeout=self.execution_timeout, @@ -176,16 +178,13 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str: method_name="execute_complete", ) - if self.return_sql_result: - result = self.hook.conn.get_statement_result(Id=self.statement_id) - self.log.debug("Statement result: %s", result) - return result - else: - return self.statement_id + # Use the get_sql_results method to return the results of the SQL query, or the statement_ids, + # depending on the value of self.return_sql_result + return self.get_sql_results(statement_id=self.statement_id, return_sql_result=self.return_sql_result) def execute_complete( self, context: Context, event: dict[str, Any] | None = None - ) -> GetStatementResultResponseTypeDef | str: + ) -> list[GetStatementResultResponseTypeDef] | list[str]: event = validate_execute_complete_event(event) if event["status"] == "error": @@ -197,16 +196,40 @@ def execute_complete( raise AirflowException("statement_id should not be empty.") self.log.info("%s completed successfully.", self.task_id) - if self.return_sql_result: - result = self.hook.conn.get_statement_result(Id=statement_id) - self.log.debug("Statement result: %s", result) - return result - return statement_id + # Use the get_sql_results method to return the results of the SQL query, or the statement_ids, + # depending on the value of self.return_sql_result + return self.get_sql_results(statement_id=statement_id, return_sql_result=self.return_sql_result) + + def get_sql_results( + self, statement_id: str, return_sql_result: bool + ) -> 
list[GetStatementResultResponseTypeDef] | list[str]: + """ + Retrieve either the result of the SQL query, or the statement ID(s). + + :param statement_id: Statement ID of the running queries + :param return_sql_result: Boolean, true if results should be returned + """ + # ISSUE-40427: Pull the statement, and check to see if there are sub-statements. If that is the + # case, pull each of the sub-statement ID's, and grab the results. Otherwise, just use statement_id + statement: DescribeStatementResponseTypeDef = self.hook.conn.describe_statement(Id=statement_id) + statement_ids: list[str] = ( + [sub_statement["Id"] for sub_statement in statement["SubStatements"]] + if len(statement.get("SubStatements", [])) > 0 + else [statement_id] + ) + + # If returning the SQL result, use get_statement_result to return the records for each query + if return_sql_result: + results: list = [self.hook.conn.get_statement_result(Id=sid) for sid in statement_ids] + self.log.debug("Statement result(s): %s", results) + return results + else: + return statement_ids def on_kill(self) -> None: """Cancel the submitted redshift query.""" - if self.statement_id: + if hasattr(self, "statement_id"): self.log.info("Received a kill signal.") self.log.info("Stopping Query with statementId - %s", self.statement_id) diff --git a/providers/tests/amazon/aws/operators/test_redshift_data.py b/providers/tests/amazon/aws/operators/test_redshift_data.py index d367fc3ca9596..da9a486df43af 100644 --- a/providers/tests/amazon/aws/operators/test_redshift_data.py +++ b/providers/tests/amazon/aws/operators/test_redshift_data.py @@ -88,7 +88,8 @@ def test_init(self): assert op.hook._config is None @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_execute(self, mock_exec_query): + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute(self, mock_conn, mock_exec_query): cluster_identifier = "cluster_identifier" 
workgroup_name = None db_user = "db_user" @@ -96,8 +97,9 @@ def test_execute(self, mock_exec_query): statement_name = "statement_name" parameters = [{"name": "id", "value": "1"}] poll_interval = 5 - wait_for_completion = True + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} + mock_conn.describe_statement.return_value = {"Status": "FINISHED"} mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=None) operator = RedshiftDataOperator( @@ -113,8 +115,11 @@ def test_execute(self, mock_exec_query): wait_for_completion=True, poll_interval=poll_interval, ) + + # Mock the TaskInstance, call the execute method mock_ti = mock.MagicMock(name="MockedTaskInstance") - operator.execute({"ti": mock_ti}) + actual_result = operator.execute({"ti": mock_ti}) + mock_exec_query.assert_called_once_with( sql=SQL, database=DATABASE, @@ -125,16 +130,19 @@ def test_execute(self, mock_exec_query): statement_name=statement_name, parameters=parameters, with_event=False, - wait_for_completion=wait_for_completion, + wait_for_completion=True, # Matches above poll_interval=poll_interval, session_id=None, session_keep_alive_seconds=None, ) + # Check that the result returned is a list of the statement_id's + assert actual_result == [STATEMENT_ID] mock_ti.xcom_push.assert_not_called() @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_execute_with_workgroup_name(self, mock_exec_query): + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute_with_workgroup_name(self, mock_conn, mock_exec_query): cluster_identifier = None workgroup_name = "workgroup_name" db_user = "db_user" @@ -142,7 +150,11 @@ def test_execute_with_workgroup_name(self, mock_exec_query): statement_name = "statement_name" parameters = [{"name": "id", "value": "1"}] poll_interval = 5 - wait_for_completion = True + + # Like before, return a statement ID and a status + 
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} + mock_conn.describe_statement.return_value = {"Status": "FINISHED"} + mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=None) operator = RedshiftDataOperator( aws_conn_id=CONN_ID, @@ -157,26 +169,32 @@ def test_execute_with_workgroup_name(self, mock_exec_query): wait_for_completion=True, poll_interval=poll_interval, ) + + # Mock the TaskInstance, call the execute method mock_ti = mock.MagicMock(name="MockedTaskInstance") - operator.execute({"ti": mock_ti}) + actual_result = operator.execute({"ti": mock_ti}) + + # Assertions + assert actual_result == [STATEMENT_ID] mock_exec_query.assert_called_once_with( sql=SQL, database=DATABASE, cluster_identifier=cluster_identifier, - workgroup_name=workgroup_name, + workgroup_name=workgroup_name, # Called with workgroup_name db_user=db_user, secret_arn=secret_arn, statement_name=statement_name, parameters=parameters, with_event=False, - wait_for_completion=wait_for_completion, + wait_for_completion=True, poll_interval=poll_interval, session_id=None, session_keep_alive_seconds=None, ) @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_execute_new_session(self, mock_exec_query): + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute_new_session(self, mock_conn, mock_exec_query): cluster_identifier = "cluster_identifier" workgroup_name = None db_user = "db_user" @@ -186,6 +204,9 @@ def test_execute_new_session(self, mock_exec_query): poll_interval = 5 wait_for_completion = True + # Like before, return a statement ID and a status + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} + mock_conn.describe_statement.return_value = {"Status": "FINISHED"} mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=SESSION_ID) operator = RedshiftDataOperator( @@ -203,8 +224,10 @@ def 
test_execute_new_session(self, mock_exec_query): session_keep_alive_seconds=123, ) + # Mock the TaskInstance and call the execute method mock_ti = mock.MagicMock(name="MockedTaskInstance") operator.execute({"ti": mock_ti}) + mock_exec_query.assert_called_once_with( sql=SQL, database=DATABASE, @@ -256,14 +279,17 @@ def test_on_kill_with_query(self, mock_conn): @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") def test_return_sql_result(self, mock_conn): - expected_result = {"Result": True} - mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": SESSION_ID} - mock_conn.describe_statement.return_value = {"Status": "FINISHED"} - mock_conn.get_statement_result.return_value = expected_result + expected_result = [{"Result": True}] cluster_identifier = "cluster_identifier" db_user = "db_user" secret_arn = "secret_arn" statement_name = "statement_name" + + # Mock the conn object + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID, "SessionId": SESSION_ID} + mock_conn.describe_statement.return_value = {"Status": "FINISHED"} + mock_conn.get_statement_result.return_value = {"Result": True} + operator = RedshiftDataOperator( task_id=TASK_ID, cluster_identifier=cluster_identifier, @@ -275,8 +301,11 @@ def test_return_sql_result(self, mock_conn): aws_conn_id=CONN_ID, return_sql_result=True, ) + + # Mock the TaskInstance, run the execute method mock_ti = mock.MagicMock(name="MockedTaskInstance") actual_result = operator.execute({"ti": mock_ti}) + assert actual_result == expected_result mock_conn.execute_statement.assert_called_once_with( Database=DATABASE, @@ -287,17 +316,18 @@ def test_return_sql_result(self, mock_conn): StatementName=statement_name, WithEvent=False, ) - mock_conn.get_statement_result.assert_called_once_with( - Id=STATEMENT_ID, - ) + mock_conn.get_statement_result.assert_called_once_with(Id=STATEMENT_ID) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") 
@mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer") @mock.patch( "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished", return_value=True, ) @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_finished, mock_defer): + def test_execute_finished_before_defer( + self, mock_exec_query, check_query_is_finished, mock_defer, mock_conn + ): cluster_identifier = "cluster_identifier" workgroup_name = None db_user = "db_user" @@ -366,25 +396,26 @@ def test_execute_complete_exception(self, deferrable_operator): deferrable_operator.execute_complete(context=None, event=None) assert exc.value.args[0] == "Trigger error: event is None" - def test_execute_complete(self, deferrable_operator): + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute_complete(self, mock_conn, deferrable_operator): """Asserts that logging occurs as expected""" deferrable_operator.statement_id = "uuid" with mock.patch.object(deferrable_operator.log, "info") as mock_log_info: - assert ( - deferrable_operator.execute_complete( - context=None, - event={"status": "success", "message": "Job completed", "statement_id": "uuid"}, - ) - == "uuid" - ) + assert deferrable_operator.execute_complete( + context=None, + event={"status": "success", "message": "Job completed", "statement_id": "uuid"}, + ) == ["uuid"] mock_log_info.assert_called_with("%s completed successfully.", TASK_ID) @mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer") @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def 
test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finished, mock_defer): + def test_no_wait_for_completion( + self, mock_exec_query, mock_conn, mock_check_query_is_finished, mock_defer + ): """Tests that the operator does not check for completion nor defers when wait_for_completion is False, no matter the value of deferrable""" cluster_identifier = "cluster_identifier" @@ -393,9 +424,12 @@ def test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finis statement_name = "statement_name" parameters = [{"name": "id", "value": "1"}] poll_interval = 5 - wait_for_completion = False + # Mock the describe_statement call + mock_conn.describe_statement.return_value = {"Status": "FINISHED"} + mock_exec_query.return_value = QueryExecutionOutput(statement_id=STATEMENT_ID, session_id=SESSION_ID) + for deferrable in [True, False]: operator = RedshiftDataOperator( aws_conn_id=CONN_ID, @@ -411,11 +445,14 @@ def test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finis poll_interval=poll_interval, deferrable=deferrable, ) + + # Mock the TaskInstance, call the execute method mock_ti = mock.MagicMock(name="MockedTaskInstance") - operator.execute({"ti": mock_ti}) + actual_results = operator.execute({"ti": mock_ti}) assert not mock_check_query_is_finished.called assert not mock_defer.called + assert actual_results == [STATEMENT_ID] def test_template_fields(self): operator = RedshiftDataOperator( diff --git a/providers/tests/system/amazon/aws/example_redshift.py b/providers/tests/system/amazon/aws/example_redshift.py index 146a20da451ba..1d8784e4176cd 100644 --- a/providers/tests/system/amazon/aws/example_redshift.py +++ b/providers/tests/system/amazon/aws/example_redshift.py @@ -146,13 +146,15 @@ cluster_identifier=redshift_cluster_identifier, database=DB_NAME, db_user=DB_LOGIN, - sql=""" + sql=[ + """ CREATE TABLE IF NOT EXISTS fruit ( fruit_id INTEGER, name VARCHAR NOT NULL, color VARCHAR NOT NULL ); - """, + """ + ], 
poll_interval=POLL_INTERVAL, wait_for_completion=True, ) @@ -163,14 +165,14 @@ cluster_identifier=redshift_cluster_identifier, database=DB_NAME, db_user=DB_LOGIN, - sql=""" - INSERT INTO fruit VALUES ( 1, 'Banana', 'Yellow'); - INSERT INTO fruit VALUES ( 2, 'Apple', 'Red'); - INSERT INTO fruit VALUES ( 3, 'Lemon', 'Yellow'); - INSERT INTO fruit VALUES ( 4, 'Grape', 'Purple'); - INSERT INTO fruit VALUES ( 5, 'Pear', 'Green'); - INSERT INTO fruit VALUES ( 6, 'Strawberry', 'Red'); - """, + sql=[ + "INSERT INTO fruit VALUES ( 1, 'Banana', 'Yellow');", + "INSERT INTO fruit VALUES ( 2, 'Apple', 'Red');", + "INSERT INTO fruit VALUES ( 3, 'Lemon', 'Yellow');", + "INSERT INTO fruit VALUES ( 4, 'Grape', 'Purple');", + "INSERT INTO fruit VALUES ( 5, 'Pear', 'Green');", + "INSERT INTO fruit VALUES ( 6, 'Strawberry', 'Red');", + ], poll_interval=POLL_INTERVAL, wait_for_completion=True, ) @@ -181,13 +183,15 @@ cluster_identifier=redshift_cluster_identifier, database=DB_NAME, db_user=DB_LOGIN, - sql=""" + sql=[ + """ CREATE TEMPORARY TABLE tmp_people ( id INTEGER, first_name VARCHAR(100), age INTEGER ); - """, + """ + ], poll_interval=POLL_INTERVAL, wait_for_completion=True, session_keep_alive_seconds=600, @@ -195,11 +199,11 @@ insert_data_reuse_session = RedshiftDataOperator( task_id="insert_data_reuse_session", - sql=""" - INSERT INTO tmp_people VALUES ( 1, 'Bob', 30); - INSERT INTO tmp_people VALUES ( 2, 'Alice', 35); - INSERT INTO tmp_people VALUES ( 3, 'Charlie', 40); - """, + sql=[ + "INSERT INTO tmp_people VALUES ( 1, 'Bob', 30);", + "INSERT INTO tmp_people VALUES ( 2, 'Alice', 35);", + "INSERT INTO tmp_people VALUES ( 3, 'Charlie', 40);", + ], poll_interval=POLL_INTERVAL, wait_for_completion=True, session_id="{{ task_instance.xcom_pull(task_ids='create_tmp_table_data_api', key='session_id') }}",