diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index 73578ea539b71..ef3cebdae9838 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -44,11 +44,12 @@ class RedshiftToS3Operator(BaseOperator):
     :param s3_bucket: reference to a specific S3 bucket
     :param s3_key: reference to a specific S3 key. If ``table_as_file_name`` is set
         to False, this param must include the desired file name
-    :param schema: reference to a specific schema in redshift database
-        Applicable when ``table`` param provided.
-    :param table: reference to a specific table in redshift database
-        Used when ``select_query`` param not provided.
-    :param select_query: custom select query to fetch data from redshift database
+    :param schema: reference to a specific schema in redshift database,
+        used when ``table`` param provided and ``select_query`` param not provided
+    :param table: reference to a specific table in redshift database,
+        used when ``schema`` param provided and ``select_query`` param not provided
+    :param select_query: custom select query to fetch data from redshift database,
+        has precedence over default query `SELECT * FROM ``schema``.``table``
     :param redshift_conn_id: reference to a specific redshift database
     :param aws_conn_id: reference to a specific S3 connection
         If the AWS connection contains 'aws_iam_role' in ``extras``
@@ -138,12 +139,17 @@ def _build_unload_query(
             {unload_options};
         """

+    @property
+    def default_select_query(self) -> str | None:
+        if self.schema and self.table:
+            return f"SELECT * FROM {self.schema}.{self.table}"
+        return None
+
     def execute(self, context: Context) -> None:
         if self.table and self.table_as_file_name:
             self.s3_key = f"{self.s3_key}/{self.table}_"

-        if self.schema and self.table:
-            self.select_query = f"SELECT * FROM {self.schema}.{self.table}"
+        self.select_query = self.select_query or self.default_select_query

         if self.select_query is None:
             raise ValueError(
diff --git a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
index d2af90a445e29..2d28acd22e7e6 100644
--- a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
@@ -305,6 +305,123 @@ def test_custom_select_query_unloading_with_single_quotes(
         assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
         assert f"UNLOAD ($${expected_query}$$)" in unload_query

+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    def test_custom_select_query_has_precedence_over_table_and_schema(
+        self,
+        mock_run,
+        mock_session,
+        mock_connection,
+        mock_hook,
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+        s3_bucket = "bucket"
+        s3_key = "key"
+        unload_options = [
+            "HEADER",
+        ]
+        select_query = "select column from table"
+
+        op = RedshiftToS3Operator(
+            select_query=select_query,
+            table="table",
+            schema="schema",
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            unload_options=unload_options,
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+
+        op.execute(None)
+
+        unload_options = "\n\t\t\t".join(unload_options)
+        credentials_block = build_credentials_block(mock_session.return_value)
+
+        unload_query = op._build_unload_query(credentials_block, select_query, "key/table_", unload_options)
+
+        assert mock_run.call_count == 1
+        assert access_key in unload_query
+        assert secret_key in unload_query
+        assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    def test_default_select_query_used_when_table_and_schema_missing(
+        self,
+        mock_run,
+        mock_session,
+        mock_connection,
+        mock_hook,
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+        s3_bucket = "bucket"
+        s3_key = "key"
+        unload_options = [
+            "HEADER",
+        ]
+        default_query = "SELECT * FROM schema.table"
+
+        op = RedshiftToS3Operator(
+            table="table",
+            schema="schema",
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            unload_options=unload_options,
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+
+        op.execute(None)
+
+        unload_options = "\n\t\t\t".join(unload_options)
+        credentials_block = build_credentials_block(mock_session.return_value)
+
+        unload_query = op._build_unload_query(credentials_block, default_query, "key/table_", unload_options)
+
+        assert mock_run.call_count == 1
+        assert access_key in unload_query
+        assert secret_key in unload_query
+        assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
+
+    def test_lack_of_select_query_and_schema_and_table_raises_error(self):
+        op = RedshiftToS3Operator(
+            s3_bucket="bucket",
+            s3_key="key",
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+
+        with pytest.raises(ValueError):
+            op.execute(None)
+
     @pytest.mark.parametrize("table_as_file_name, expected_s3_key", [[True, "key/table_"], [False, "key"]])
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
     @mock.patch("airflow.models.connection.Connection")