From f0d80d67fdac689e38f83908f2c29501e19e2025 Mon Sep 17 00:00:00 2001 From: tim-s-ccs Date: Tue, 16 Jan 2024 12:54:50 +0000 Subject: [PATCH] Because we are using cloudfront as a load balancer for the assets the URL we generate needs to include the region. However boto3 does not do this so we are having to make use of the AWS CLI to create this URL. --- dmutils/__init__.py | 2 +- dmutils/s3.py | 37 +++++++++++++++++++++++++++++-------- tests/test_s3.py | 24 ++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 9 deletions(-) diff --git a/dmutils/__init__.py b/dmutils/__init__.py index b9ff6e14..b6908b25 100644 --- a/dmutils/__init__.py +++ b/dmutils/__init__.py @@ -2,4 +2,4 @@ from .flask_init import init_app -__version__ = '60.11.1' +__version__ = '60.12.0' diff --git a/dmutils/s3.py b/dmutils/s3.py index 992d8272..137b308a 100644 --- a/dmutils/s3.py +++ b/dmutils/s3.py @@ -1,5 +1,6 @@ from __future__ import absolute_import +import subprocess import boto3 import datetime from dateutil.parser import parse as parse_time @@ -135,14 +136,34 @@ def get_signed_url(self, path, expires_in=30): """ path = self._normalize_path(path) if self.path_exists(path): - return self._resource.meta.client.generate_presigned_url( - "get_object", - Params={ - "Bucket": self._bucket.name, - "Key": path, - }, - ExpiresIn=expires_in, - ) + # Because we are using cloudfront as a load balancer for the assets + # the URL we generate needs to include the region. + # However boto3 does not do this so we are having to make use of + # the AWS CLI to create this URL. + if os.getenv("DM_ENVIRONMENT", None) == "native-aws": + result = subprocess.run( + [ + 'aws', + 's3', + 'presign', + f's3://{self._bucket.name}/{path}', + '--expires-in', + str(expires_in) + ], + stdout=subprocess.PIPE, + check=True, + ) + + return result.stdout.decode() + else: + return self._resource.meta.client.generate_presigned_url( + "get_object", + Params={ + "Bucket": self._bucket.name, + "Key": path, + }, + ExpiresIn=expires_in, + ) def _get_key(self, path): path = self._normalize_path(path) diff --git a/tests/test_s3.py b/tests/test_s3.py index 89b2f35a..fcbe8d73 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -165,6 +165,30 @@ def test_get_signed_url(self, bucket_with_file): assert parsed_qs["AWSAccessKeyId"] == ["AKIAIABCDABCDABCDABC"] assert parsed_qs["Signature"] + def test_get_signed_url_aws_native(self, bucket_with_file, monkeypatch): + with mock.patch('dmutils.s3.subprocess.run') as mock_run: + mock_stdout = mock.MagicMock() + mock_stdout.configure_mock( + **{ + "stdout.decode.return_value": 'https://dear-liza.s3.amazonaws.com/with/straw.dear.pdf?' + 'X-Amz-Algorithm=AnAlgorithm&' + 'X-Amz-Credential=SomeCredentials&' + 'X-Amz-Signature=SomeSignature' + } + ) + + mock_run.return_value = mock_stdout + + monkeypatch.setenv('DM_ENVIRONMENT', 'native-aws') + signed_url = S3('dear-liza').get_signed_url('with/straw.dear.pdf') + parsed_signed_url = urlparse(signed_url) + assert "dear-liza" in parsed_signed_url.hostname + assert parsed_signed_url.path == "/with/straw.dear.pdf" + parsed_qs = parse_qs(parsed_signed_url.query) + assert parsed_qs["X-Amz-Algorithm"][0] == "AnAlgorithm" + assert parsed_qs["X-Amz-Credential"][0] == "SomeCredentials" + assert parsed_qs["X-Amz-Signature"][0] == "SomeSignature" + @freeze_time('2015-10-10') def test_get_signed_url_with_expires_at(self, bucket_with_file): signed_url = S3('dear-liza').get_signed_url('with/straw.dear.pdf', expires_in=10)