Add support for role ARN for AWS creds in storage transfer job operators #38911

Merged 1 commit on May 1, 2024
@@ -88,6 +88,7 @@ class GcpTransferOperationStatus:
AWS_ACCESS_KEY = "awsAccessKey"
AWS_SECRET_ACCESS_KEY = "secretAccessKey"
AWS_S3_DATA_SOURCE = "awsS3DataSource"
AWS_ROLE_ARN = "roleArn"
BODY = "body"
BUCKET_NAME = "bucketName"
COUNTERS = "counters"
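For orientation: the new AWS_ROLE_ARN constant maps to the roleArn field nested under awsS3DataSource in a Storage Transfer Service TransferSpec. A minimal sketch of the body fragment the operators build (the bucket name and ARN below are placeholder values, not taken from this PR):

transfer_spec = {
    "awsS3DataSource": {
        "bucketName": "example-source-bucket",  # placeholder
        # When roleArn is present, the preprocessor no longer injects awsAccessKey credentials.
        "roleArn": "arn:aws:iam::123456789012:role/example-transfer-role",  # placeholder
    }
}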
@@ -28,6 +28,7 @@
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
ACCESS_KEY_ID,
AWS_ACCESS_KEY,
AWS_ROLE_ARN,
AWS_S3_DATA_SOURCE,
BUCKET_NAME,
DAY,
@@ -79,15 +80,23 @@ def __init__(
self.default_schedule = default_schedule

def _inject_aws_credentials(self) -> None:
if TRANSFER_SPEC in self.body and AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC]:
aws_hook = AwsBaseHook(self.aws_conn_id, resource_type="s3")
aws_credentials = aws_hook.get_credentials()
aws_access_key_id = aws_credentials.access_key # type: ignore[attr-defined]
aws_secret_access_key = aws_credentials.secret_key # type: ignore[attr-defined]
self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
ACCESS_KEY_ID: aws_access_key_id,
SECRET_ACCESS_KEY: aws_secret_access_key,
}
if TRANSFER_SPEC not in self.body:
return

if AWS_S3_DATA_SOURCE not in self.body[TRANSFER_SPEC]:
return

if AWS_ROLE_ARN in self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE]:
return

aws_hook = AwsBaseHook(self.aws_conn_id, resource_type="s3")
aws_credentials = aws_hook.get_credentials()
aws_access_key_id = aws_credentials.access_key # type: ignore[attr-defined]
aws_secret_access_key = aws_credentials.secret_key # type: ignore[attr-defined]
self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
ACCESS_KEY_ID: aws_access_key_id,
SECRET_ACCESS_KEY: aws_secret_access_key,
}

def _reformat_date(self, field_key: str) -> None:
schedule = self.body[SCHEDULE]
@@ -819,6 +828,9 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator):
account from the list granting this role to the originating account (templated).
:param delete_job_after_completion: If True, delete the job after complete.
If set to True, 'wait' must be set to True.
:param aws_role_arn: Optional AWS role ARN for workload identity federation. This will
override the `aws_conn_id` for authentication between GCP and AWS; see
https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec#AwsS3Data
"""

template_fields: Sequence[str] = (
@@ -830,6 +842,7 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator):
"description",
"object_conditions",
"google_impersonation_chain",
"aws_role_arn",
)
ui_color = "#e09411"

@@ -851,6 +864,7 @@ def __init__(
timeout: float | None = None,
google_impersonation_chain: str | Sequence[str] | None = None,
delete_job_after_completion: bool = False,
aws_role_arn: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -869,6 +883,7 @@ def __init__(
self.timeout = timeout
self.google_impersonation_chain = google_impersonation_chain
self.delete_job_after_completion = delete_job_after_completion
self.aws_role_arn = aws_role_arn
self._validate_inputs()

def _validate_inputs(self) -> None:
@@ -919,6 +934,9 @@ def _create_body(self) -> dict:
if self.transfer_options is not None:
body[TRANSFER_SPEC][TRANSFER_OPTIONS] = self.transfer_options # type: ignore[index]

if self.aws_role_arn is not None:
body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ROLE_ARN] = self.aws_role_arn # type: ignore[index]

return body


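Taken together, the operator changes above mean that passing aws_role_arn writes roleArn into the job body and the preprocessor skips injecting static AWS keys from aws_conn_id. A hedged usage sketch (the task id, bucket names, and ARN are illustrative, not from this PR):

from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
    CloudDataTransferServiceS3ToGCSOperator,
)

s3_to_gcs_with_role_arn = CloudDataTransferServiceS3ToGCSOperator(
    task_id="s3_to_gcs_with_role_arn",  # illustrative task id
    s3_bucket="example-source-bucket",  # illustrative bucket
    gcs_bucket="example-destination-bucket",  # illustrative bucket
    aws_role_arn="arn:aws:iam::123456789012:role/example-transfer-role",  # illustrative ARN
    wait=True,
)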
@@ -29,6 +29,7 @@
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
ACCESS_KEY_ID,
AWS_ACCESS_KEY,
AWS_ROLE_ARN,
AWS_S3_DATA_SOURCE,
BUCKET_NAME,
FILTER_JOB_NAMES,
@@ -75,6 +76,7 @@
OPERATION_NAME = "transferOperations/transferJobs-123-456"
AWS_BUCKET_NAME = "aws-bucket-name"
GCS_BUCKET_NAME = "gcp-bucket-name"
AWS_ROLE_ARN_INPUT = "aRoleARn"
SOURCE_PATH = None
DESTINATION_PATH = None
DESCRIPTION = "description"
@@ -104,6 +106,9 @@
}

SOURCE_AWS = {AWS_S3_DATA_SOURCE: {BUCKET_NAME: AWS_BUCKET_NAME, PATH: SOURCE_PATH}}
SOURCE_AWS_ROLE_ARN = {
AWS_S3_DATA_SOURCE: {BUCKET_NAME: AWS_BUCKET_NAME, PATH: SOURCE_PATH, AWS_ROLE_ARN: AWS_ROLE_ARN_INPUT}
}
SOURCE_GCS = {GCS_DATA_SOURCE: {BUCKET_NAME: GCS_BUCKET_NAME, PATH: SOURCE_PATH}}
SOURCE_HTTP = {HTTP_DATA_SOURCE: {LIST_URL: "http://example.com"}}

@@ -122,6 +127,8 @@
VALID_TRANSFER_JOB_GCS[TRANSFER_SPEC].update(deepcopy(SOURCE_GCS))
VALID_TRANSFER_JOB_AWS = deepcopy(VALID_TRANSFER_JOB_BASE)
VALID_TRANSFER_JOB_AWS[TRANSFER_SPEC].update(deepcopy(SOURCE_AWS))
VALID_TRANSFER_JOB_AWS_ROLE_ARN = deepcopy(VALID_TRANSFER_JOB_BASE)
VALID_TRANSFER_JOB_AWS_ROLE_ARN[TRANSFER_SPEC].update(deepcopy(SOURCE_AWS_ROLE_ARN))

VALID_TRANSFER_JOB_GCS = {
NAME: JOB_NAME,
@@ -146,6 +153,9 @@
VALID_TRANSFER_JOB_AWS_RAW = deepcopy(VALID_TRANSFER_JOB_RAW)
VALID_TRANSFER_JOB_AWS_RAW[TRANSFER_SPEC].update(deepcopy(SOURCE_AWS))
VALID_TRANSFER_JOB_AWS_RAW[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = TEST_AWS_ACCESS_KEY
VALID_TRANSFER_JOB_AWS_WITH_ROLE_ARN_RAW = deepcopy(VALID_TRANSFER_JOB_RAW)
VALID_TRANSFER_JOB_AWS_WITH_ROLE_ARN_RAW[TRANSFER_SPEC].update(deepcopy(SOURCE_AWS_ROLE_ARN))
VALID_TRANSFER_JOB_AWS_WITH_ROLE_ARN_RAW[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ROLE_ARN] = AWS_ROLE_ARN_INPUT

VALID_OPERATION = {NAME: "operation-name"}

@@ -167,6 +177,16 @@ def test_should_inject_aws_credentials(self, mock_hook):
body = TransferJobPreprocessor(body=body).process_body()
assert body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] == TEST_AWS_ACCESS_KEY

@mock.patch("airflow.providers.google.cloud.operators.cloud_storage_transfer_service.AwsBaseHook")
def test_should_not_inject_aws_credentials(self, mock_hook):
mock_hook.return_value.get_credentials.return_value = Credentials(
TEST_AWS_ACCESS_KEY_ID, TEST_AWS_ACCESS_SECRET, None
)

body = {TRANSFER_SPEC: deepcopy(SOURCE_AWS_ROLE_ARN)}
body = TransferJobPreprocessor(body=body).process_body()
assert AWS_ACCESS_KEY not in body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE]

@pytest.mark.parametrize("field_attr", [SCHEDULE_START_DATE, SCHEDULE_END_DATE])
def test_should_format_date_from_python_to_dict(self, field_attr):
body = {SCHEDULE: {field_attr: NATIVE_DATE}}
@@ -239,7 +259,9 @@ def test_verify_data_source(self, transfer_spec):
"gcsDataSource, awsS3DataSource and httpDataSource." in str(err)
)

@pytest.mark.parametrize("body", [VALID_TRANSFER_JOB_GCS, VALID_TRANSFER_JOB_AWS])
@pytest.mark.parametrize(
"body", [VALID_TRANSFER_JOB_GCS, VALID_TRANSFER_JOB_AWS, VALID_TRANSFER_JOB_AWS_ROLE_ARN]
)
def test_verify_success(self, body):
try:
TransferJobValidator(body=body).validate_body()
@@ -304,6 +326,34 @@ def test_job_create_aws(self, aws_hook, mock_hook):

assert result == VALID_TRANSFER_JOB_AWS

@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
@mock.patch("airflow.providers.google.cloud.operators.cloud_storage_transfer_service.AwsBaseHook")
def test_job_create_aws_with_role_arn(self, aws_hook, mock_hook):
mock_hook.return_value.create_transfer_job.return_value = VALID_TRANSFER_JOB_AWS_ROLE_ARN
body = deepcopy(VALID_TRANSFER_JOB_AWS_ROLE_ARN)
del body["name"]
op = CloudDataTransferServiceCreateJobOperator(
body=body,
task_id=TASK_ID,
google_impersonation_chain=IMPERSONATION_CHAIN,
)

result = op.execute(context=mock.MagicMock())

mock_hook.assert_called_once_with(
api_version="v1",
gcp_conn_id="google_cloud_default",
impersonation_chain=IMPERSONATION_CHAIN,
)

mock_hook.return_value.create_transfer_job.assert_called_once_with(
body=VALID_TRANSFER_JOB_AWS_WITH_ROLE_ARN_RAW
)

assert result == VALID_TRANSFER_JOB_AWS_ROLE_ARN

@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)