From 866120255888f31136fcf82f9f3d640450f45eb2 Mon Sep 17 00:00:00 2001 From: requiredfield Date: Sun, 17 May 2015 23:40:26 -0400 Subject: [PATCH] support expect_fingerprint as bytes --- aiohttp/connector.py | 40 +++++++++++++++++++++++++++++----------- tests/test_connector.py | 34 +++++++++++++++++++++++----------- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index c1627c53b86..c6671c4af34 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -28,11 +28,20 @@ PY_34 = sys.version_info >= (3, 4) PY_343 = sys.version_info >= (3, 4, 3) -HASHFUNC_BY_DIGESTLEN = { +HASHFUNC_BY_HEXDIGESTLEN = { 32: md5, 40: sha1, 64: sha256, } +HASHFUNC_BY_BINDIGESTLEN = { + 16: md5, + 20: sha1, + 32: sha256, +} +HASHFUNCMAP_BY_DIGEST_TYPE = { + str: HASHFUNC_BY_HEXDIGESTLEN, + bytes: HASHFUNC_BY_BINDIGESTLEN, +} class Connection(object): @@ -356,9 +365,10 @@ class TCPConnector(BaseConnector): """TCP connector. :param bool verify_ssl: Set to True to check ssl certifications. - :param str expect_fingerprint: Set to the md5, sha1, or sha256 fingerprint - (as a hexadecimal string) of the expected certificate (DER-encoded) - to verify the cert matches. May be interspersed with colons. + :param str expect_fingerprint: Pass the md5, sha1, or sha256 fingerprint + as either a hexadecimal string or binary bytestring of the expected + certificate (in DER format) to verify the cert matches. + If passing a hex string, colons and case are ignored. :param bool resolve: Set to True to do DNS lookup for host name. :param family: socket address family :param args: see :class:`BaseConnector` @@ -378,15 +388,23 @@ def __init__(self, *, verify_ssl=True, expect_fingerprint=None, self._verify_ssl = verify_ssl if expect_fingerprint: - expect_fingerprint = expect_fingerprint.replace(':', '').lower() - digestlen = len(expect_fingerprint) - hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen) + xfp = expect_fingerprint + digest_type = type(xfp) + hashfuncmap = HASHFUNCMAP_BY_DIGEST_TYPE.get(digest_type) + if not hashfuncmap: + raise TypeError('expect_fingerprint must be str or bytes') + is_str = digest_type is str + if is_str: + xfp = xfp.replace(':', '').lower() + digestlen = len(xfp) + hashfunc = hashfuncmap.get(digestlen) if not hashfunc: - raise ValueError('Fingerprint is of invalid length.') + raise ValueError('expect_fingerprint has invalid length') self._hashfunc = hashfunc - self._fingerprint_bytes = unhexlify(expect_fingerprint) - - self._expect_fingerprint = expect_fingerprint + self._fingerprint_bytes = unhexlify(xfp) if is_str else xfp + self._expect_fingerprint = xfp if is_str else hexlify(xfp) + else: + self._expect_fingerprint = None self._ssl_context = ssl_context self._family = family self._resolve = resolve diff --git a/tests/test_connector.py b/tests/test_connector.py index 68c8dc78176..3b8a693814d 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -468,19 +468,33 @@ def test_tcp_connector_expect_fingerprint_invalid_len(self): with self.assertRaises(ValueError): aiohttp.TCPConnector(loop=self.loop, expect_fingerprint=invalid) + def test_tcp_connector_expect_fingerprint_invalid_type(self): + invalid = 123 + with self.assertRaises(TypeError): + aiohttp.TCPConnector(loop=self.loop, expect_fingerprint=invalid) + def test_tcp_connector_expect_fingerprint(self): - # the even-index fingerprints below are for sample.crt.der, - # the certificate presented by test_utils.run_server + # The even-index fingerprints below are "expect success" cases + # for ./sample.crt.der, the cert presented by test_utils.run_server. + # The odd-index fingerprints are "expect fail" cases. testcases = ( # md5 - 'a20647adaaf5d85c4a995e62793b063d', # good - 'ffffffffffffffffffffffffffffffff', # bad + 'a2:06:47:ad:aa:f5:d8:5c:4a:99:5e:62:79:3b:06:3d', # good + 'ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff', # bad + + 'A20647ADAAF5D85C4A995E62793B063D', # colons and case ignored + 'FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF', + + b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=', # bytes ok too + b'\xff' * 16, + # sha1 - '7393fd3aed081d6fa9ae71391ae3c57f89e76cf9', # good - 'ffffffffffffffffffffffffffffffffffffffff', # bad + '7393fd3aed081d6fa9ae71391ae3c57f89e76cf9', + 'ffffffffffffffffffffffffffffffffffffffff', + # sha256 - '309ac94483dc9127889111a16497fdcb7e37551444404c11ab99a8aeb714ee8b', # good # flake8: noqa - 'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', # bad # flake8: noqa + '309ac94483dc9127889111a16497fdcb7e37551444404c11ab99a8aeb714ee8b', # flake8: noqa + 'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', # flake8: noqa ) for i, fingerprint in enumerate(testcases): expect_fail = i % 2 @@ -490,10 +504,8 @@ def test_tcp_connector_expect_fingerprint(self): coro = client.request('get', httpd.url('method', 'get'), connector=conn, loop=self.loop) if expect_fail: - with self.assertRaises(FingerprintMismatch) as cm: + with self.assertRaises(FingerprintMismatch): self.loop.run_until_complete(coro) - self.assertEqual(cm.exception.expected, fingerprint) - self.assertEqual(cm.exception.got, testcases[i-1]) else: # should not raise self.loop.run_until_complete(coro)