fix: remove part of openlineage extraction from S3ToRedshiftOperator #41631

Merged
merged 1 commit on Aug 21, 2024
43 changes: 8 additions & 35 deletions airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -59,7 +59,8 @@ class S3ToRedshiftOperator(BaseOperator):
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:param column_list: list of column names to load
:param column_list: list of column names to load source data fields into specific target columns
https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-column-mapping.html#copy-column-list
:param copy_options: reference to a list of COPY options
:param method: Action to be performed on execution. Available ``APPEND``, ``UPSERT`` and ``REPLACE``.
:param upsert_keys: List of fields to use as key on upsert action
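For context, here is a minimal usage sketch of the operator this docstring describes; the connection IDs, bucket, key, and column names are hypothetical and not taken from this PR:

    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    # Hypothetical example: COPY maps the source data fields onto the
    # target columns named in column_list, in order.
    transfer = S3ToRedshiftOperator(
        task_id="s3_to_redshift_example",
        schema="public",
        table="orders",
        s3_bucket="my-bucket",
        s3_key="exports/orders.csv",
        column_list=["order_id", "amount"],
        copy_options=["CSV"],
        method="APPEND",  # also supports UPSERT and REPLACE
        redshift_conn_id="redshift_default",
        aws_conn_id="aws_default",
    )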
@@ -204,18 +205,13 @@ def execute(self, context: Context) -> None:

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as we will query destination table."""
from pathlib import Path

from airflow.providers.amazon.aws.utils.openlineage import (
get_facets_from_redshift_table,
get_identity_column_lineage_facet,
)
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
Identifier,
LifecycleStateChange,
LifecycleStateChangeDatasetFacet,
SymlinksDatasetFacet,
)
from airflow.providers.openlineage.extractors import OperatorLineage

@@ -235,36 +231,8 @@ def get_openlineage_facets_on_complete(self, task_instance):
database = redshift_sql_hook.conn.schema
authority = redshift_sql_hook.get_openlineage_database_info(redshift_sql_hook.conn).authority
output_dataset_facets = get_facets_from_redshift_table(
redshift_sql_hook, self.table, self.redshift_data_api_kwargs, self.schema
)

input_dataset_facets = {}
if not self.column_list:
# If column_list is not specified, then we know that input file matches columns of output table.
input_dataset_facets["schema"] = output_dataset_facets["schema"]

dataset_name = self.s3_key
if "*" in dataset_name:
# If wildcard ("*") is used in s3 path, we want the name of dataset to be directory name,
# but we create a symlink to the full object path with wildcard.
input_dataset_facets["symlink"] = SymlinksDatasetFacet(
identifiers=[Identifier(namespace=f"s3://{self.s3_bucket}", name=dataset_name, type="file")]
redshift_sql_hook, self.table, {}, self.schema
)
dataset_name = Path(dataset_name).parent.as_posix()
if dataset_name == ".":
# blob path does not have leading slash, but we need root dataset name to be "/"
dataset_name = "/"

input_dataset = Dataset(
namespace=f"s3://{self.s3_bucket}",
name=dataset_name,
facets=input_dataset_facets,
)

output_dataset_facets["columnLineage"] = get_identity_column_lineage_facet(
field_names=[field.name for field in output_dataset_facets["schema"].fields],
input_datasets=[input_dataset],
)

if self.method == "REPLACE":
output_dataset_facets["lifecycleStateChange"] = LifecycleStateChangeDatasetFacet(
@@ -277,4 +245,9 @@ def get_openlineage_facets_on_complete(self, task_instance):
facets=output_dataset_facets,
)

input_dataset = Dataset(
namespace=f"s3://{self.s3_bucket}",
name=self.s3_key,
)

return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
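The net effect of the change above: the S3 input dataset is now named by the raw s3_key (wildcards and all) and carries no facets, while the symlink, input schema, and columnLineage handling has been removed. A minimal sketch of the shape now returned, with hypothetical namespace and table names:

    from airflow.providers.common.compat.openlineage.facet import Dataset
    from airflow.providers.openlineage.extractors import OperatorLineage

    # Input: named by the exact s3_key; no facets are attached anymore.
    input_dataset = Dataset(namespace="s3://my-bucket", name="exports/orders.csv")
    # Output: facets still come from get_facets_from_redshift_table()
    # (plus lifecycleStateChange when method == "REPLACE").
    output_dataset = Dataset(
        namespace="redshift://cluster.region:5439",
        name="database.public.orders",
        facets={},
    )
    lineage = OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])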
144 changes: 67 additions & 77 deletions tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -26,7 +26,13 @@
from airflow.exceptions import AirflowException
from airflow.models.connection import Connection
from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
from airflow.providers.common.compat.openlineage.facet import LifecycleStateChange
from airflow.providers.common.compat.openlineage.facet import (
DocumentationDatasetFacet,
LifecycleStateChange,
LifecycleStateChangeDatasetFacet,
SchemaDatasetFacet,
SchemaDatasetFacetFields,
)
from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces


@@ -502,8 +508,9 @@ def test_using_redshift_data_api(self, mock_rs, mock_run, mock_session, mock_con
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
def test_get_openlineage_facets_on_complete_default(
self, mock_run, mock_session, mock_connection, mock_hook
self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook
):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
@@ -515,6 +522,11 @@ def test_get_openlineage_facets_on_complete_default(
mock_connection.return_value = mock.MagicMock(
schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
)
mock_facets = {
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
"documentation": DocumentationDatasetFacet(description="mock_description"),
}
mock_get_facets.return_value = mock_facets

schema = "schema"
table = "table"
@@ -531,33 +543,30 @@ def test_get_openlineage_facets_on_complete_default(
redshift_conn_id="redshift_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
dag=None,
)
op.execute(None)

lineage = op.get_openlineage_facets_on_complete(None)
# Hook called two times - on operator execution, and on querying data in redshift to fetch schema
assert mock_run.call_count == 2
# Hook called only once - on operator execution; the query that fetches the schema is mocked
assert mock_run.call_count == 1

assert len(lineage.inputs) == 1
assert len(lineage.outputs) == 1
assert lineage.inputs[0].name == s3_key
assert lineage.inputs[0].namespace == f"s3://{s3_bucket}"
assert lineage.outputs[0].name == f"database.{schema}.{table}"
assert lineage.outputs[0].namespace == "redshift://cluster.region:5439"

assert lineage.outputs[0].facets.get("schema") is not None
assert lineage.outputs[0].facets.get("columnLineage") is not None

assert lineage.inputs[0].facets.get("schema") is not None
# As the method was not REPLACE, there should be no lifecycleStateChange facet
assert "lifecycleStateChange" not in lineage.outputs[0].facets
assert lineage.outputs[0].facets == mock_facets
assert lineage.inputs[0].facets == {}

@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
def test_get_openlineage_facets_on_complete_replace(
self, mock_run, mock_session, mock_connection, mock_hook
self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook
):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
@@ -569,6 +578,11 @@ def test_get_openlineage_facets_on_complete_replace(
mock_connection.return_value = mock.MagicMock(
schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
)
mock_facets = {
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
"documentation": DocumentationDatasetFacet(description="mock_description"),
}
mock_get_facets.return_value = mock_facets

schema = "schema"
table = "table"
@@ -586,59 +600,25 @@ def test_get_openlineage_facets_on_complete_replace(
redshift_conn_id="redshift_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
dag=None,
)
op.execute(None)

lineage = op.get_openlineage_facets_on_complete(None)

assert (
lineage.outputs[0].facets["lifecycleStateChange"].lifecycleStateChange
== LifecycleStateChange.OVERWRITE
)

@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
def test_get_openlineage_facets_on_complete_column_list(
self, mock_run, mock_session, mock_connection, mock_hook
):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
mock_session.return_value.access_key = access_key
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None

mock_connection.return_value = mock.MagicMock(
schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
)

schema = "schema"
table = "table"
s3_bucket = "bucket"
s3_key = "key"
copy_options = ""

op = S3ToRedshiftOperator(
schema=schema,
table=table,
s3_bucket=s3_bucket,
s3_key=s3_key,
copy_options=copy_options,
column_list=["column1", "column2"],
redshift_conn_id="redshift_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
dag=None,
)
op.execute(None)

lineage = op.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 1
assert len(lineage.outputs) == 1
assert lineage.inputs[0].name == s3_key
assert lineage.inputs[0].namespace == f"s3://{s3_bucket}"
assert lineage.outputs[0].name == f"database.{schema}.{table}"
assert lineage.outputs[0].namespace == "redshift://cluster.region:5439"

assert lineage.outputs[0].facets.get("schema") is not None
assert lineage.inputs[0].facets.get("schema") is None
assert lineage.outputs[0].facets == {
**mock_facets,
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.OVERWRITE
),
}
assert lineage.inputs[0].facets == {}

@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
Expand All @@ -648,8 +628,9 @@ def test_get_openlineage_facets_on_complete_column_list(
"airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name",
new_callable=mock.PropertyMock,
)
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
def test_get_openlineage_facets_on_complete_using_redshift_data_api(
self, mock_rs_region, mock_rs, mock_session, mock_connection, mock_hook
self, mock_get_facets, mock_rs_region, mock_rs, mock_session, mock_connection, mock_hook
):
"""
Using the Redshift Data API instead of the SQL-based connection
Expand All @@ -666,6 +647,11 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
mock_rs.describe_statement.return_value = {"Status": "FINISHED"}

mock_rs_region.return_value = "region"
mock_facets = {
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
"documentation": DocumentationDatasetFacet(description="mock_description"),
}
mock_get_facets.return_value = mock_facets

schema = "schema"
table = "table"
@@ -689,7 +675,7 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
redshift_conn_id="redshift_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
dag=None,
method="REPLACE",
redshift_data_api_kwargs=dict(
database=database,
cluster_identifier=cluster_identifier,
@@ -705,15 +691,17 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
assert len(lineage.inputs) == 1
assert len(lineage.outputs) == 1
assert lineage.inputs[0].name == s3_key
assert lineage.inputs[0].namespace == f"s3://{s3_bucket}"
assert lineage.outputs[0].name == f"database.{schema}.{table}"
assert lineage.outputs[0].namespace == "redshift://cluster.region:5439"

assert lineage.outputs[0].facets.get("schema") is not None
assert lineage.outputs[0].facets.get("columnLineage") is not None

assert lineage.inputs[0].facets.get("schema") is not None
# As the method was not REPLACE, there should be no lifecycleStateChange facet
assert "lifecycleStateChange" not in lineage.outputs[0].facets
assert lineage.outputs[0].facets == {
**mock_facets,
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.OVERWRITE
),
}
assert lineage.inputs[0].facets == {}

@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
@@ -724,8 +712,9 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
"airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name",
new_callable=mock.PropertyMock,
)
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned(
self, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook
self, mock_get_facets, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook
):
"""
Ensure that both supported hooks - RedshiftDataHook and RedshiftSQLHook - return the same lineage.
@@ -745,6 +734,11 @@ def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned(
mock_rs.describe_statement.return_value = {"Status": "FINISHED"}

mock_rs_region.return_value = "region"
mock_facets = {
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
"documentation": DocumentationDatasetFacet(description="mock_description"),
}
mock_get_facets.return_value = mock_facets

schema = "schema"
table = "table"
@@ -794,13 +788,9 @@ def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned(
op_rs_sql.execute(None)
rs_sql_lineage = op_rs_sql.get_openlineage_facets_on_complete(None)

assert rs_sql_lineage.inputs == rs_data_lineage.inputs
assert len(rs_sql_lineage.inputs) == 1
assert len(rs_sql_lineage.outputs) == 1
assert len(rs_data_lineage.outputs) == 1
assert rs_sql_lineage.outputs[0].facets["schema"] == rs_data_lineage.outputs[0].facets["schema"]
assert (
rs_sql_lineage.outputs[0].facets["columnLineage"]
== rs_data_lineage.outputs[0].facets["columnLineage"]
)
assert rs_sql_lineage.outputs[0].name == rs_data_lineage.outputs[0].name
assert rs_sql_lineage.outputs[0].namespace == rs_data_lineage.outputs[0].namespace
assert rs_sql_lineage.inputs == rs_data_lineage.inputs
assert rs_sql_lineage.outputs == rs_data_lineage.outputs
assert rs_sql_lineage.job_facets == rs_data_lineage.job_facets
assert rs_sql_lineage.run_facets == rs_data_lineage.run_facets