
Commit

fix: select_query should have precedence over default query in RedshiftToS3Operator

Signed-off-by: Kacper Muda <[email protected]>
kacpermuda committed Aug 21, 2024
1 parent 23e9716 commit 11e50d6
Showing 2 changed files with 130 additions and 7 deletions.
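
At a glance, the behavioral change inside execute(), condensed from the diff below:

# Before this commit: a user-supplied select_query was silently overwritten
# whenever schema and table were also set.
if self.schema and self.table:
    self.select_query = f"SELECT * FROM {self.schema}.{self.table}"

# After: the user's query takes precedence, and the schema/table default only
# fills the gap when no query was given.
self.select_query = self.select_query or self.default_select_query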
20 changes: 13 additions & 7 deletions airflow/providers/amazon/aws/transfers/redshift_to_s3.py

@@ -44,11 +44,12 @@ class RedshiftToS3Operator(BaseOperator):
     :param s3_bucket: reference to a specific S3 bucket
     :param s3_key: reference to a specific S3 key. If ``table_as_file_name`` is set
         to False, this param must include the desired file name
-    :param schema: reference to a specific schema in redshift database
-        Applicable when ``table`` param provided.
-    :param table: reference to a specific table in redshift database
-        Used when ``select_query`` param not provided.
-    :param select_query: custom select query to fetch data from redshift database
+    :param schema: reference to a specific schema in redshift database,
+        used when ``table`` param provided and ``select_query`` param not provided
+    :param table: reference to a specific table in redshift database,
+        used when ``schema`` param provided and ``select_query`` param not provided
+    :param select_query: custom select query to fetch data from redshift database,
+        has precedence over default query `SELECT * FROM ``schema``.``table``
     :param redshift_conn_id: reference to a specific redshift database
     :param aws_conn_id: reference to a specific S3 connection
         If the AWS connection contains 'aws_iam_role' in ``extras``
@@ -138,12 +138,17 @@ def _build_unload_query(
             {unload_options};
         """
 
+    @property
+    def default_select_query(self) -> str | None:
+        if self.schema and self.table:
+            return f"SELECT * FROM {self.schema}.{self.table}"
+        return None
+
     def execute(self, context: Context) -> None:
         if self.table and self.table_as_file_name:
             self.s3_key = f"{self.s3_key}/{self.table}_"
 
-        if self.schema and self.table:
-            self.select_query = f"SELECT * FROM {self.schema}.{self.table}"
+        self.select_query = self.select_query or self.default_select_query
 
         if self.select_query is None:
             raise ValueError(
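
In effect, a user-supplied select_query now survives even when schema and table are also set. A minimal usage sketch of the patched behavior (the bucket, key, query, and connection ids below are illustrative, not taken from the commit):

from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

# schema/table still drive the default S3 key suffix, but after this fix the
# custom query is what gets unloaded.
op = RedshiftToS3Operator(
    task_id="unload_orders",
    select_query="SELECT id, total FROM sales.orders WHERE total > 100",
    schema="sales",
    table="orders",
    s3_bucket="example-bucket",
    s3_key="exports",
    redshift_conn_id="redshift_default",
    aws_conn_id="aws_default",
)

# The new property builds the fallback only from schema and table ...
assert op.default_select_query == "SELECT * FROM sales.orders"
# ... and execute() resolves `self.select_query or self.default_select_query`,
# so the custom SELECT above is what ends up inside UNLOAD ($$ ... $$).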
117 changes: 117 additions & 0 deletions tests/providers/amazon/aws/transfers/test_redshift_to_s3.py

@@ -305,6 +305,123 @@ def test_custom_select_query_unloading_with_single_quotes(
         assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
         assert f"UNLOAD ($${expected_query}$$)" in unload_query
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    def test_custom_select_query_has_precedence_over_table_and_schema(
+        self,
+        mock_run,
+        mock_session,
+        mock_connection,
+        mock_hook,
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+        s3_bucket = "bucket"
+        s3_key = "key"
+        unload_options = [
+            "HEADER",
+        ]
+        select_query = "select column from table"
+
+        op = RedshiftToS3Operator(
+            select_query=select_query,
+            table="table",
+            schema="schema",
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            unload_options=unload_options,
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+
+        op.execute(None)
+
+        unload_options = "\n\t\t\t".join(unload_options)
+        credentials_block = build_credentials_block(mock_session.return_value)
+
+        unload_query = op._build_unload_query(credentials_block, select_query, "key/table_", unload_options)
+
+        assert mock_run.call_count == 1
+        assert access_key in unload_query
+        assert secret_key in unload_query
+        assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    def test_default_select_query_used_when_table_and_schema_missing(
+        self,
+        mock_run,
+        mock_session,
+        mock_connection,
+        mock_hook,
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+        s3_bucket = "bucket"
+        s3_key = "key"
+        unload_options = [
+            "HEADER",
+        ]
+        default_query = "SELECT * FROM schema.table"
+
+        op = RedshiftToS3Operator(
+            table="table",
+            schema="schema",
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            unload_options=unload_options,
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+
+        op.execute(None)
+
+        unload_options = "\n\t\t\t".join(unload_options)
+        credentials_block = build_credentials_block(mock_session.return_value)
+
+        unload_query = op._build_unload_query(credentials_block, default_query, "key/table_", unload_options)
+
+        assert mock_run.call_count == 1
+        assert access_key in unload_query
+        assert secret_key in unload_query
+        assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
+
+    def test_lack_of_select_query_and_schema_and_table_raises_error(self):
+        op = RedshiftToS3Operator(
+            s3_bucket="bucket",
+            s3_key="key",
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+
+        with pytest.raises(ValueError):
+            op.execute(None)
+
     @pytest.mark.parametrize("table_as_file_name, expected_s3_key", [[True, "key/table_"], [False, "key"]])
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
     @mock.patch("airflow.models.connection.Connection")
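
Taken together, the three new tests pin down the query-resolution order. A condensed, self-contained sketch of the same three cases (illustrative only; the values mirror the tests, and no connection is touched because execute() is never called here):

from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

common = dict(s3_bucket="bucket", s3_key="key", task_id="task_id", dag=None)

# 1. An explicit select_query wins over schema/table; the resolution below
#    mirrors the line now used in execute().
op = RedshiftToS3Operator(
    select_query="select column from table", schema="schema", table="table", **common
)
assert (op.select_query or op.default_select_query) == "select column from table"

# 2. With only schema/table provided, the default query is built.
op = RedshiftToS3Operator(schema="schema", table="table", **common)
assert op.default_select_query == "SELECT * FROM schema.table"

# 3. With neither a query nor schema/table, execute() raises ValueError,
#    as the third test asserts with pytest.raises(ValueError).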
