From db6be117625ca915edf26f303811800500868591 Mon Sep 17 00:00:00 2001 From: Matt Simpson Date: Thu, 9 Jan 2020 17:29:46 -0600 Subject: [PATCH] Support generic headers and HTTP timeouts Instead of solely urllib3 header support, there is now the ability to pass generic headers such as "Authorization" in addition to the urllib3 headers (e.g. "basic_auth"). Furthermore, http timeouts are honored in the FileDownloader. Add test to check headers --- docs/api.md | 2 +- pyupdater/client/__init__.py | 18 +++++++++--------- pyupdater/client/downloader.py | 21 +++++++++++++++++---- pyupdater/client/patcher.py | 10 +++++++--- pyupdater/client/updates.py | 7 ++++--- tests/test_downloader.py | 14 ++++++++++++-- 6 files changed, 50 insertions(+), 22 deletions(-) diff --git a/docs/api.md b/docs/api.md index 9c62f39d..15a8466a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -95,7 +95,7 @@ progress_hooks (list): List of callbacks data_dir (str): Path to custom update folder -headers (dict): A urllib3.utils.make_headers compatible dictionary +headers (dict): A dictionary of generic and/or urllib3.utils.make_headers compatible headers test (bool): Used to initialize a test client diff --git a/pyupdater/client/__init__.py b/pyupdater/client/__init__.py index f0785e38..4a2ecea6 100644 --- a/pyupdater/client/__init__.py +++ b/pyupdater/client/__init__.py @@ -85,7 +85,7 @@ class Client(object): data_dir (str): Path to custom update folder - headers (dict): A urllib3.utils.make_headers compatible dictionary + headers (dict): A dictionary of generic and/or urllib3.utils.make_headers compatible headers test (bool): Used to initialize a test client @@ -130,8 +130,6 @@ def __init__(self, obj, **kwargs): raise SyntaxError("progress_hooks must be provided as a list.") self.progress_hooks += progress_hooks - obj.URLLIB3_HEADERS = headers - # A super dict used to save config info & be dot accessed config = _Config() config.from_object(obj) @@ -198,8 +196,8 @@ def __init__(self, obj, **kwargs): # The name of the key file to download self.key_file = settings.KEY_FILE_FILENAME - # urllib3 headers - self.urllib3_headers = obj.URLLIB3_HEADERS + # headers + self.headers = headers # Creating data & update directories self._setup() @@ -244,7 +242,7 @@ def _gen_file_downloader_options(self): "http_timeout": self.http_timeout, "max_download_retries": self.max_download_retries, "progress_hooks": self.progress_hooks, - "urllib3_headers": self.urllib3_headers, + "headers": self.headers, "verify": self.verify, } @@ -317,7 +315,7 @@ def _update_check(self, name, version, channel, strict): "verify": self.verify, "max_download_retries": self.max_download_retries, "progress_hooks": list(set(self.progress_hooks)), - "urllib3_headers": self.urllib3_headers, + "headers": self.headers, "downloader": self.downloader, } @@ -436,7 +434,8 @@ def _get_manifest_from_http(self): vf, self.update_urls, verify=self.verify, - urllb3_headers=self.urllib3_headers, + headers=self.headers, + http_timeout=self.http_timeout ) data = fd.download_verify_return() try: @@ -468,7 +467,8 @@ def _get_key_data(self): self.key_file, self.update_urls, verify=self.verify, - urllb3_headers=self.urllib3_headers, + headers=self.headers, + http_timeout=self.http_timeout ) data = fd.download_verify_return() try: diff --git a/pyupdater/client/downloader.py b/pyupdater/client/downloader.py index da2d1c0c..2bd3cc08 100755 --- a/pyupdater/client/downloader.py +++ b/pyupdater/client/downloader.py @@ -24,6 +24,7 @@ # ------------------------------------------------------------------------------ from __future__ import unicode_literals import hashlib +import inspect import logging import os import time @@ -131,7 +132,9 @@ def __init__(self, *args, **kwargs): self.content_length = None # Extra headers - self.headers = kwargs.get("urllb3_headers") + self.headers = kwargs.get("headers") + + self.http_timeout = kwargs.get("http_timeout") if self.verify is True: self.http_pool = self._get_http_pool() @@ -141,15 +144,25 @@ def __init__(self, *args, **kwargs): def _get_http_pool(self, secure=True): if secure: _http = urllib3.PoolManager( - cert_reqs=str("CERT_REQUIRED"), ca_certs=certifi.where() + cert_reqs=str("CERT_REQUIRED"), ca_certs=certifi.where(), timeout=self.http_timeout ) else: - _http = urllib3.PoolManager() + _http = urllib3.PoolManager(timeout=self.http_timeout) if self.headers: - _headers = urllib3.util.make_headers(**self.headers) + if six.PY3: + # Python3 + urllib_keys = inspect.getfullargspec(urllib3.util.make_headers).args + else: + # Python2 fallback + urllib_keys = inspect.getargspec(urllib3.util.make_headers).args + urllib_headers = {header: value for header, value in six.iteritems(self.headers) if header in urllib_keys} + other_headers = {header: value for header, value in six.iteritems(self.headers) if header not in urllib_keys} + _headers = urllib3.util.make_headers(**urllib_headers) + _headers.update(other_headers) _http.headers.update(_headers) log.debug(_http.headers) + log.debug("HTTP Timeout is " + str(self.http_timeout)) return _http def download_verify_write(self): diff --git a/pyupdater/client/patcher.py b/pyupdater/client/patcher.py index 4747393d..3536f279 100755 --- a/pyupdater/client/patcher.py +++ b/pyupdater/client/patcher.py @@ -67,7 +67,9 @@ class Patcher(object): max_download_retries (int): Number of times to retry a download - urllib3_headers (dict): Headers to be used with http request + headers (dict): Headers to be used with http request. Accepts urllib3 and generic headers. + + http_timeout (int): HTTP timeout or None """ def __init__(self, **kwargs): @@ -81,8 +83,9 @@ def __init__(self, **kwargs): self.update_urls = kwargs.get("update_urls", []) self.verify = kwargs.get("verify", True) self.max_download_retries = kwargs.get("max_download_retries") - self.urllib3_headers = kwargs.get("urllib3_headers") + self.headers = kwargs.get("headers") self.downloader = kwargs.get("downloader") + self.http_timeout = kwargs.get("http_timeout") # Progress hooks to be called self.progress_hooks = kwargs.get("progress_hooks", []) @@ -302,7 +305,8 @@ def _download_verify_patches(self): hexdigest=p["patch_hash"], verify=self.verify, max_download_retries=self.max_download_retries, - urllb3_headers=self.urllib3_headers, + headers=self.headers, + http_timeout=self.http_timeout ) # Attempt to download resource diff --git a/pyupdater/client/updates.py b/pyupdater/client/updates.py index 79109c24..e48ad75d 100644 --- a/pyupdater/client/updates.py +++ b/pyupdater/client/updates.py @@ -418,8 +418,8 @@ def __init__(self, data=None): # Weather or not the verify the https connection self.verify = data.get("verify", True) - # Extra headers to pass to urllib3 - self.urllib3_headers = data.get("urllib3_headers") + # Extra headers + self.headers = data.get("headers") # The amount of times to retry a url before giving up self.max_download_retries = data.get("max_download_retries") @@ -684,7 +684,8 @@ def _full_update(self): verify=self.verify, progress_hooks=self.progress_hooks, max_download_retries=self.max_download_retries, - urllb3_headers=self.urllib3_headers, + headers=self.headers, + http_timeout=self.http_timeout ) result = fd.download_verify_write() if result: diff --git a/tests/test_downloader.py b/tests/test_downloader.py index 3611d896..d6350eea 100644 --- a/tests/test_downloader.py +++ b/tests/test_downloader.py @@ -63,10 +63,20 @@ def test_return_fail(self, download_max_size): @pytest.mark.usefixtue("cleandir") -class TestBasicAuth(object): +class TestBasicAuthUrlLib(object): def test_basic_auth(self): headers = {"basic_auth": "user:pass"} - fd = FileDownloader("test", ["test"], urllb3_headers=headers) + fd = FileDownloader("test", ["test"], headers=headers) + http = fd._get_http_pool(secure=True) + sc = http.request("GET", "https://httpbin.org/basic-auth/user/pass").status + assert sc == 200 + + +@pytest.mark.usefixtue("cleandir") +class TestAuthorizationHeader(object): + def test_basic_auth(self): + headers = {"Authorization": "Basic dXNlcjpwYXNz"} + fd = FileDownloader("test", ["test"], headers=headers) http = fd._get_http_pool(secure=True) sc = http.request("GET", "https://httpbin.org/basic-auth/user/pass").status assert sc == 200