diff --git a/docs/api.md b/docs/api.md
index 9c62f39d..15a8466a 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -95,7 +95,7 @@ progress_hooks (list): List of callbacks
 
 data_dir (str): Path to custom update folder
 
-headers (dict): A urllib3.utils.make_headers compatible dictionary
+headers (dict): A dictionary of generic and/or urllib3.util.make_headers compatible headers
 
 test (bool): Used to initialize a test client
 
diff --git a/pyupdater/client/__init__.py b/pyupdater/client/__init__.py
index f0785e38..4a2ecea6 100644
--- a/pyupdater/client/__init__.py
+++ b/pyupdater/client/__init__.py
@@ -85,7 +85,7 @@ class Client(object):
 
         data_dir (str): Path to custom update folder
 
-        headers (dict): A urllib3.utils.make_headers compatible dictionary
+        headers (dict): A dictionary of generic and/or urllib3.util.make_headers compatible headers
 
         test (bool): Used to initialize a test client
 
@@ -130,8 +130,6 @@ def __init__(self, obj, **kwargs):
                 raise SyntaxError("progress_hooks must be provided as a list.")
             self.progress_hooks += progress_hooks
 
-        obj.URLLIB3_HEADERS = headers
-
         # A super dict used to save config info & be dot accessed
         config = _Config()
         config.from_object(obj)
@@ -198,8 +196,8 @@ def __init__(self, obj, **kwargs):
         # The name of the key file to download
         self.key_file = settings.KEY_FILE_FILENAME
 
-        # urllib3 headers
-        self.urllib3_headers = obj.URLLIB3_HEADERS
+        # headers
+        self.headers = headers
 
         # Creating data & update directories
         self._setup()
@@ -244,7 +242,7 @@ def _gen_file_downloader_options(self):
             "http_timeout": self.http_timeout,
             "max_download_retries": self.max_download_retries,
             "progress_hooks": self.progress_hooks,
-            "urllib3_headers": self.urllib3_headers,
+            "headers": self.headers,
             "verify": self.verify,
         }
 
@@ -317,7 +315,7 @@ def _update_check(self, name, version, channel, strict):
             "verify": self.verify,
             "max_download_retries": self.max_download_retries,
             "progress_hooks": list(set(self.progress_hooks)),
-            "urllib3_headers": self.urllib3_headers,
+            "headers": self.headers,
             "downloader": self.downloader,
         }
 
@@ -436,7 +434,8 @@ def _get_manifest_from_http(self):
             vf,
             self.update_urls,
             verify=self.verify,
-            urllb3_headers=self.urllib3_headers,
+            headers=self.headers,
+            http_timeout=self.http_timeout
         )
         data = fd.download_verify_return()
         try:
@@ -468,7 +467,8 @@ def _get_key_data(self):
             self.key_file,
             self.update_urls,
             verify=self.verify,
-            urllb3_headers=self.urllib3_headers,
+            headers=self.headers,
+            http_timeout=self.http_timeout
         )
         data = fd.download_verify_return()
         try:
diff --git a/pyupdater/client/downloader.py b/pyupdater/client/downloader.py
index da2d1c0c..2bd3cc08 100755
--- a/pyupdater/client/downloader.py
+++ b/pyupdater/client/downloader.py
@@ -24,6 +24,7 @@
 # ------------------------------------------------------------------------------
 from __future__ import unicode_literals
 import hashlib
+import inspect
 import logging
 import os
 import time
@@ -131,7 +132,9 @@ def __init__(self, *args, **kwargs):
         self.content_length = None
 
         # Extra headers
-        self.headers = kwargs.get("urllb3_headers")
+        self.headers = kwargs.get("headers")
+
+        self.http_timeout = kwargs.get("http_timeout")
 
         if self.verify is True:
             self.http_pool = self._get_http_pool()
@@ -141,15 +144,25 @@ def _get_http_pool(self, secure=True):
         if secure:
             _http = urllib3.PoolManager(
-                cert_reqs=str("CERT_REQUIRED"), ca_certs=certifi.where()
+                cert_reqs=str("CERT_REQUIRED"), ca_certs=certifi.where(), timeout=self.http_timeout
             )
         else:
-            _http = urllib3.PoolManager()
+            _http = urllib3.PoolManager(timeout=self.http_timeout)
 
         if self.headers:
-            _headers = urllib3.util.make_headers(**self.headers)
+            if six.PY3:
+                # Python3
+                urllib_keys = inspect.getfullargspec(urllib3.util.make_headers).args
+            else:
+                # Python2 fallback
+                urllib_keys = inspect.getargspec(urllib3.util.make_headers).args
+            urllib_headers = {header: value for header, value in six.iteritems(self.headers) if header in urllib_keys}
+            other_headers = {header: value for header, value in six.iteritems(self.headers) if header not in urllib_keys}
+            _headers = urllib3.util.make_headers(**urllib_headers)
+            _headers.update(other_headers)
             _http.headers.update(_headers)
             log.debug(_http.headers)
 
+        log.debug("HTTP Timeout is " + str(self.http_timeout))
         return _http
 
     def download_verify_write(self):
diff --git a/pyupdater/client/patcher.py b/pyupdater/client/patcher.py
index 4747393d..3536f279 100755
--- a/pyupdater/client/patcher.py
+++ b/pyupdater/client/patcher.py
@@ -67,7 +67,9 @@ class Patcher(object):
 
         max_download_retries (int): Number of times to retry a download
 
-        urllib3_headers (dict): Headers to be used with http request
+        headers (dict): Headers to be used with the http request. Accepts generic and/or urllib3.util.make_headers compatible headers.
+
+        http_timeout (int): HTTP timeout in seconds, or None
     """
 
     def __init__(self, **kwargs):
@@ -81,8 +83,9 @@ def __init__(self, **kwargs):
         self.update_urls = kwargs.get("update_urls", [])
         self.verify = kwargs.get("verify", True)
         self.max_download_retries = kwargs.get("max_download_retries")
-        self.urllib3_headers = kwargs.get("urllib3_headers")
+        self.headers = kwargs.get("headers")
         self.downloader = kwargs.get("downloader")
+        self.http_timeout = kwargs.get("http_timeout")
 
         # Progress hooks to be called
         self.progress_hooks = kwargs.get("progress_hooks", [])
@@ -302,7 +305,8 @@ def _download_verify_patches(self):
                 hexdigest=p["patch_hash"],
                 verify=self.verify,
                 max_download_retries=self.max_download_retries,
-                urllb3_headers=self.urllib3_headers,
+                headers=self.headers,
+                http_timeout=self.http_timeout
             )
 
             # Attempt to download resource
diff --git a/pyupdater/client/updates.py b/pyupdater/client/updates.py
index 79109c24..e48ad75d 100644
--- a/pyupdater/client/updates.py
+++ b/pyupdater/client/updates.py
@@ -418,8 +418,8 @@ def __init__(self, data=None):
         # Weather or not the verify the https connection
         self.verify = data.get("verify", True)
 
-        # Extra headers to pass to urllib3
-        self.urllib3_headers = data.get("urllib3_headers")
+        # Extra headers
+        self.headers = data.get("headers")
 
         # The amount of times to retry a url before giving up
         self.max_download_retries = data.get("max_download_retries")
@@ -684,7 +684,8 @@ def _full_update(self):
             verify=self.verify,
             progress_hooks=self.progress_hooks,
             max_download_retries=self.max_download_retries,
-            urllb3_headers=self.urllib3_headers,
+            headers=self.headers,
+            http_timeout=self.http_timeout
         )
         result = fd.download_verify_write()
         if result:
diff --git a/tests/test_downloader.py b/tests/test_downloader.py
index 3611d896..f9fe70c6 100644
--- a/tests/test_downloader.py
+++ b/tests/test_downloader.py
@@ -63,10 +63,20 @@ def test_return_fail(self, download_max_size):
 
 
 @pytest.mark.usefixtue("cleandir")
-class TestBasicAuth(object):
+class TestBasicAuthUrlLib(object):
     def test_basic_auth(self):
         headers = {"basic_auth": "user:pass"}
-        fd = FileDownloader("test", ["test"], urllb3_headers=headers)
+        fd = FileDownloader("test", ["test"], headers=headers)
+        http = fd._get_http_pool(secure=True)
+        sc = http.request("GET", "https://httpbin.org/basic-auth/user/pass").status
+        assert sc == 200
+
+
+@pytest.mark.usefixtue("cleandir")
+class TestAuthorizationHeader(object):
+    def test_auth_header(self):
+        headers = {"Authorization": "Basic dXNlcjpwYXNz"}
+        fd = FileDownloader("test", ["test"], headers=headers)
         http = fd._get_http_pool(secure=True)
         sc = http.request("GET", "https://httpbin.org/basic-auth/user/pass").status
         assert sc == 200
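
The core of the downloader change is the split between urllib3.util.make_headers keyword arguments and literal HTTP headers. A standalone sketch of that logic follows (Python 3 branch only; the function name split_headers is illustrative, not part of the patch):

import inspect

import urllib3


def split_headers(headers):
    # Keys that urllib3.util.make_headers understands (basic_auth, user_agent, ...)
    # are expanded through it; anything else is treated as a literal HTTP header.
    make_header_args = inspect.getfullargspec(urllib3.util.make_headers).args
    via_make_headers = {k: v for k, v in headers.items() if k in make_header_args}
    literal_headers = {k: v for k, v in headers.items() if k not in make_header_args}
    merged = urllib3.util.make_headers(**via_make_headers)
    merged.update(literal_headers)
    return merged


# {"basic_auth": "user:pass"} and {"Authorization": "Basic dXNlcjpwYXNz"} yield the same
# credentials, which is what the two new test classes assert against httpbin.org.
print(split_headers({"basic_auth": "user:pass", "X-Api-Key": "example-token"}))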
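
From the caller's side, a minimal sketch of passing a mixed headers dictionary to the Client after this change. The client_config import, app name, version, and header values are illustrative placeholders and not part of this patch:

from pyupdater.client import Client

from client_config import ClientConfig  # hypothetical PyUpdater-generated config module

headers = {
    # Expanded through urllib3.util.make_headers (it accepts a basic_auth argument) ...
    "basic_auth": "user:pass",
    # ... while keys make_headers does not recognize are sent through as literal headers.
    "X-Api-Key": "example-token",
}

client = Client(ClientConfig(), refresh=True, headers=headers)
app_update = client.update_check("AcmeApp", "1.0.0")
if app_update is not None and app_update.download():
    app_update.extract_restart()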
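
The new http_timeout value is forwarded unchanged to urllib3.PoolManager(timeout=...). At the urllib3 level that parameter accepts a number of seconds or a urllib3.util.Timeout instance, although the Patcher docstring documents it as an int; a quick sketch of what the pool receives:

import urllib3

# urllib3 accepts either a float number of seconds or a Timeout object here;
# the values 5.0 and 30.0 are arbitrary example timeouts.
pool = urllib3.PoolManager(timeout=urllib3.util.Timeout(connect=5.0, read=30.0))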