diff --git a/docs/backends/azure.rst b/docs/backends/azure.rst index d59282d6..09c05e3b 100644 --- a/docs/backends/azure.rst +++ b/docs/backends/azure.rst @@ -195,6 +195,18 @@ Settings Additionally, this setting can be used to configure the client retry settings. To see how follow the `Python retry docs `__. +``request_options`` or ``AZURE_REQUEST_OPTIONS`` + + Default: ``{"client_request_id": django_guid.get_guid}`` + + A dict of kwarg options to set on each request for the ``BlobServiceClient``. A partial list of options can be found + `in the client docs `__. + + A no-argument callable can be used to set the value at request time. If django-guid is installed then the request id will be + set by default. If you have a custom request id generator and you can use it like so:: + + AZURE_REQUEST_OPTIONS = {"client_request_id": my_request_generator_function} + ``api_version`` or ``AZURE_API_VERSION`` Default: ``None`` diff --git a/storages/backends/azure_storage.py b/storages/backends/azure_storage.py index a696e220..de2884c8 100644 --- a/storages/backends/azure_storage.py +++ b/storages/backends/azure_storage.py @@ -24,6 +24,11 @@ from storages.utils import setting from storages.utils import to_bytes +try: + from django_guid import get_guid +except ImportError: + def get_guid(): + return None @deconstructible class AzureStorageFile(File): @@ -47,7 +52,7 @@ def _get_file(self): if "r" in self._mode or "a" in self._mode: download_stream = self._storage.client.download_blob( - self._path, timeout=self._storage.timeout + self._path, **self._storage._request_options() ) download_stream.readinto(file) if "r" in self._mode: @@ -132,6 +137,22 @@ def __init__(self, **settings): if not self.account_key and "AccountKey" in parsed: self.account_key = parsed["AccountKey"] + def _request_options(self): + """ + If callables were provided in request_options, evaluate them and return + the concrete values. Include "timeout", which was a previously-supported + request option before the introduction of the request_options setting. + """ + if not self.request_options: + return {"timeout": self.timeout} + callable_allowed = ("raw_response_hook", "raw_request_hook") + options = self.request_options.copy() + options["timeout"] = self.timeout + for key, value in self.request_options.items(): + if key not in callable_allowed and callable(value): + options[key] = value() + return options + def get_default_settings(self): return { "account_name": setting("AZURE_ACCOUNT_NAME"), @@ -154,6 +175,7 @@ def get_default_settings(self): "token_credential": setting("AZURE_TOKEN_CREDENTIAL"), "api_version": setting("AZURE_API_VERSION", None), "client_options": setting("AZURE_CLIENT_OPTIONS", {}), + "request_options": setting("AZURE_REQUEST_OPTIONS", {"client_request_id": get_guid}), } def _get_service_client(self): @@ -252,13 +274,13 @@ def exists(self, name): def delete(self, name): try: - self.client.delete_blob(self._get_valid_path(name), timeout=self.timeout) + self.client.delete_blob(self._get_valid_path(name), **self._request_options()) except ResourceNotFoundError: pass def size(self, name): blob_client = self.client.get_blob_client(self._get_valid_path(name)) - properties = blob_client.get_blob_properties(timeout=self.timeout) + properties = blob_client.get_blob_properties(**self._request_options()) return properties.size def _save(self, name, content): @@ -276,8 +298,8 @@ def _save(self, name, content): content, content_settings=ContentSettings(**params), max_concurrency=self.upload_max_conn, - timeout=self.timeout, overwrite=self.overwrite_files, + **self._request_options(), ) return cleaned_name @@ -350,7 +372,7 @@ def get_modified_time(self, name): USE_TZ is True, otherwise returns a naive datetime in the local timezone. """ blob_client = self.client.get_blob_client(self._get_valid_path(name)) - properties = blob_client.get_blob_properties(timeout=self.timeout) + properties = blob_client.get_blob_properties(**self._request_options()) if not setting("USE_TZ", False): return timezone.make_naive(properties.last_modified) @@ -372,7 +394,7 @@ def list_all(self, path=""): return [ blob.name for blob in self.client.list_blobs( - name_starts_with=path, timeout=self.timeout + name_starts_with=path, **self._request_options() ) ] diff --git a/tests/test_azure.py b/tests/test_azure.py index 3700b60d..466ebe27 100644 --- a/tests/test_azure.py +++ b/tests/test_azure.py @@ -270,6 +270,7 @@ def test_storage_save(self): max_concurrency=2, timeout=20, overwrite=True, + client_request_id=None, ) c_mocked.assert_called_once_with( content_type="text/plain", content_encoding=None, cache_control=None @@ -293,6 +294,7 @@ def test_storage_open_write(self): max_concurrency=2, timeout=20, overwrite=True, + client_request_id=None, ) def test_storage_exists(self): @@ -308,7 +310,9 @@ def test_storage_exists(self): def test_delete_blob(self): self.storage.delete("name") - self.storage._client.delete_blob.assert_called_once_with("name", timeout=20) + self.storage._client.delete_blob.assert_called_once_with( + "name", timeout=20, client_request_id=None + ) def test_storage_listdir_base(self): file_names = ["some/path/1.txt", "2.txt", "other/path/3.txt", "4.txt"] @@ -322,7 +326,7 @@ def test_storage_listdir_base(self): dirs, files = self.storage.listdir("") self.storage._client.list_blobs.assert_called_with( - name_starts_with="", timeout=20 + name_starts_with="", timeout=20, client_request_id=None ) self.assertEqual(len(dirs), 0) @@ -378,3 +382,23 @@ def test_client_settings(self, bsc): bsc.assert_called_once_with( "https://test.blob.core.windows.net", credential=None, api_version="1.3" ) + + def test_lazy_evaluated_request_options(self): + foo = mock.MagicMock() + foo.side_effect = [1, 2] # return different values the two times it is called + with override_settings( + AZURE_REQUEST_OPTIONS={"key1": 5, "client_request_id": foo} + ): + storage = azure_storage.AzureStorage() + client_mock = mock.MagicMock() + storage._client = client_mock + + _, _ = storage.listdir("") + client_mock.list_blobs.assert_called_with( + name_starts_with="", timeout=20, key1=5, client_request_id=1 + ) + + _, _ = storage.listdir("") + client_mock.list_blobs.assert_called_with( + name_starts_with="", timeout=20, key1=5, client_request_id=2 + )