Add acl_policy into GCSToS3Operator (#10804)
amaterasu-coder committed Oct 5, 2020
1 parent baa980f · commit d9586ff
Showing 2 changed files with 40 additions and 1 deletion.
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -82,6 +82,9 @@ class GCSToS3Operator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type google_impersonation_chain: Union[str, Sequence[str]]
:param s3_acl_policy: Optional string specifying the canned ACL policy for the
object to be uploaded to S3
:type s3_acl_policy: str
"""

template_fields: Iterable[str] = (
@@ -109,6 +112,7 @@ def __init__(
replace=False,
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
dest_s3_extra_args: Optional[Dict] = None,
s3_acl_policy: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -133,6 +137,7 @@ def __init__(
self.replace = replace
self.google_impersonation_chain = google_impersonation_chain
self.dest_s3_extra_args = dest_s3_extra_args or {}
self.s3_acl_policy = s3_acl_policy

def execute(self, context):
# list all files in a Google Cloud Storage bucket
@@ -177,7 +182,9 @@ def execute(self, context):
dest_key = self.dest_s3_key + file
self.log.info("Saving file to %s", dest_key)

- s3_hook.load_bytes(file_bytes, key=dest_key, replace=self.replace)
+ s3_hook.load_bytes(
+     file_bytes, key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy
+ )

self.log.info("All done, uploaded %d files to S3", len(files))
else:
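For context, here is a minimal usage sketch (not part of this commit) showing how the new parameter could be passed to the operator. The DAG setup, bucket names, and connection IDs are illustrative assumptions; the parameter names mirror those exercised by the test below.

# Hypothetical usage sketch; the DAG, bucket names, and connection IDs are assumptions.
from airflow import DAG
from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="gcs_to_s3_acl_example",
    start_date=days_ago(1),
    schedule_interval=None,
) as dag:
    copy_with_acl = GCSToS3Operator(
        task_id="copy_gcs_to_s3",
        bucket="example-gcs-bucket",  # source GCS bucket (assumed)
        prefix="data/",  # copy only objects under this prefix
        dest_aws_conn_id="aws_default",
        dest_s3_key="s3://example-s3-bucket/",  # destination bucket/prefix
        replace=False,
        s3_acl_policy="bucket-owner-full-control",  # canned ACL applied to each uploaded object
    )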
32 changes: 32 additions & 0 deletions tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
@@ -34,6 +34,7 @@
PREFIX = 'TEST'
S3_BUCKET = 's3://bucket/'
MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
S3_ACL_POLICY = "private-read"


class TestGCSToS3Operator(unittest.TestCase):
@@ -240,3 +241,34 @@ def test_execute_should_pass_dest_s3_extra_args_to_s3_hook(self, s3_mock_hook, m
s3_mock_hook.assert_called_once_with(
aws_conn_id='aws_default', extra_args={'ContentLanguage': 'value'}, verify=None
)

# Test6: s3_acl_policy parameter is set
@mock_s3
@mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
@mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
@mock.patch('airflow.providers.amazon.aws.hooks.s3.S3Hook.load_bytes')
def test_execute_with_s3_acl_policy(self, mock_load_bytes, mock_gcs_hook, mock_gcs_hook2):
mock_gcs_hook.return_value.list.return_value = MOCK_FILES
mock_gcs_hook.return_value.download.return_value = b"testing"
mock_gcs_hook2.return_value.list.return_value = MOCK_FILES

operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
dest_s3_key=S3_BUCKET,
replace=False,
s3_acl_policy=S3_ACL_POLICY,
)

# Create dest bucket without files
hook = S3Hook(aws_conn_id='airflow_gcs_test')
bucket = hook.get_bucket('bucket')
bucket.create()

operator.execute(None)

# Make sure the acl_policy parameter is passed to the upload method
self.assertEqual(mock_load_bytes.call_args.kwargs['acl_policy'], S3_ACL_POLICY)
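For reference, the hook layer ultimately hands the canned ACL to boto3. Below is a rough sketch of the equivalent raw boto3 upload; the exact plumbing inside S3Hook.load_bytes is an assumption here, since it is not shown in this diff.

# Rough boto3-level equivalent of s3_hook.load_bytes(..., acl_policy=...).
# This mirrors the assumed hook behavior for illustration; it is not code from this commit.
import io
import boto3

client = boto3.client("s3")
client.upload_fileobj(
    io.BytesIO(b"testing"),  # file bytes downloaded from GCS
    "example-s3-bucket",  # destination bucket (assumed)
    "TEST/TEST1.csv",  # destination key (assumed)
    ExtraArgs={"ACL": "bucket-owner-full-control"},  # canned ACL taken from s3_acl_policy
)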
