diff --git a/CHANGES.txt b/CHANGES.txt index 4a9e019f338..fe9e4f9e3e7 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -4,6 +4,9 @@ CHANGES 0.16.0 (XX-XX-XXXX) ------------------- +- Support new `fingerprint` param of TCPConnector to enable verifying + ssl certificates via md5, sha1, or sha256 digest + - Setup uploaded filename if field value is binary and transfer encoding is not specified #349 diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 3a7816cbc44..1dd78aa0492 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -38,6 +38,7 @@ Olaf Conradi Paul Colomiets Philipp A. Raúl Cumplido +"Required Field" Robert Lu Sebastian Hanula Simon Kennedy diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 392309b2166..2cc8ac186e3 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -9,6 +9,7 @@ import warnings from collections import defaultdict +from hashlib import md5, sha1, sha256 from itertools import chain from math import ceil @@ -17,6 +18,7 @@ from .errors import ServerDisconnectedError from .errors import HttpProxyError, ProxyConnectionError from .errors import ClientOSError, ClientTimeoutError +from .errors import FingerprintMismatch from .helpers import BasicAuth @@ -25,6 +27,12 @@ PY_34 = sys.version_info >= (3, 4) PY_343 = sys.version_info >= (3, 4, 3) +HASHFUNC_BY_DIGESTLEN = { + 16: md5, + 20: sha1, + 32: sha256, +} + class Connection(object): @@ -347,13 +355,17 @@ class TCPConnector(BaseConnector): """TCP connector. :param bool verify_ssl: Set to True to check ssl certifications. + :param bytes fingerprint: Pass the binary md5, sha1, or sha256 + digest of the expected certificate in DER format to verify + the cert the server presents matches. See also + https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning :param bool resolve: Set to True to do DNS lookup for host name. :param family: socket address family :param args: see :class:`BaseConnector` :param kwargs: see :class:`BaseConnector` """ - def __init__(self, *, verify_ssl=True, + def __init__(self, *, verify_ssl=True, fingerprint=None, resolve=False, family=socket.AF_INET, ssl_context=None, **kwargs): super().__init__(**kwargs) @@ -364,6 +376,15 @@ def __init__(self, *, verify_ssl=True, "verify_ssl=False or specify ssl_context, not both.") self._verify_ssl = verify_ssl + + if fingerprint: + digestlen = len(fingerprint) + hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen) + if not hashfunc: + raise ValueError('fingerprint has invalid length') + self._hashfunc = hashfunc + self._fingerprint = fingerprint + self._ssl_context = ssl_context self._family = family self._resolve = resolve @@ -374,6 +395,11 @@ def verify_ssl(self): """Do check for ssl certifications?""" return self._verify_ssl + @property + def fingerprint(self): + """Expected ssl certificate fingerprint.""" + return self._fingerprint + @property def ssl_context(self): """SSLContext instance for https requests. @@ -464,11 +490,25 @@ def _create_connection(self, req): for hinfo in hosts: try: - return (yield from self._loop.create_connection( - self._factory, hinfo['host'], hinfo['port'], + host = hinfo['host'] + port = hinfo['port'] + conn = yield from self._loop.create_connection( + self._factory, host, port, ssl=sslcontext, family=hinfo['family'], proto=hinfo['proto'], flags=hinfo['flags'], - server_hostname=hinfo['hostname'] if sslcontext else None)) + server_hostname=hinfo['hostname'] if sslcontext else None) + transport = conn[0] + has_cert = transport.get_extra_info('sslcontext') + if has_cert and self._fingerprint: + sock = transport.get_extra_info('socket') + # gives DER-encoded cert as a sequence of bytes (or None) + cert = sock.getpeercert(binary_form=True) + assert cert + got = self._hashfunc(cert).digest() + expected = self._fingerprint + if got != expected: + raise FingerprintMismatch(expected, got, host, port) + return conn except OSError as e: exc = e else: diff --git a/aiohttp/errors.py b/aiohttp/errors.py index 5c148638c1f..b488c963c2c 100644 --- a/aiohttp/errors.py +++ b/aiohttp/errors.py @@ -13,6 +13,7 @@ 'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError', 'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError', 'ClientRequestError', 'ClientResponseError', + 'FingerprintMismatch', 'WSServerHandshakeError', 'WSClientDisconnectedError') @@ -170,3 +171,18 @@ class LineLimitExceededParserError(ParserError): def __init__(self, msg, limit): super().__init__(msg) self.limit = limit + + +class FingerprintMismatch(ClientConnectionError): + """SSL certificate does not match expected fingerprint.""" + + def __init__(self, expected, got, host, port): + self.expected = expected + self.got = got + self.host = host + self.port = port + + def __repr__(self): + return '<{} expected={} got={} host={} port={}>'.format( + self.__class__.__name__, self.expected, self.got, + self.host, self.port) diff --git a/docs/client.rst b/docs/client.rst index 8304cf0de01..bb58ae39841 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -396,21 +396,49 @@ By default it uses strict checks for HTTPS protocol. Certification checks can be relaxed by passing ``verify_ssl=False``:: >>> conn = aiohttp.TCPConnector(verify_ssl=False) - >>> r = yield from aiohttp.request( - ... 'get', 'https://example.com', connector=conn) + >>> session = aiohttp.ClientSession(connector=conn) + >>> r = yield from session.get('https://example.com') If you need to setup custom ssl parameters (use own certification files for example) you can create a :class:`ssl.SSLContext` instance and pass it into the connector:: - >>> sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - >>> sslcontext.verify_mode = ssl.CERT_REQUIRED - >>> sslcontext.load_verify_locations("/etc/ssl/certs/ca-bundle.crt") + >>> sslcontext = ssl.create_default_context(cafile='/path/to/ca-bundle.crt') >>> conn = aiohttp.TCPConnector(ssl_context=sslcontext) - >>> r = yield from aiohttp.request( - ... 'get', 'https://example.com', connector=conn) - + >>> session = aiohttp.ClientSession(connector=conn) + >>> r = yield from session.get('https://example.com') + +You may also verify certificates via md5, sha1, or sha256 fingerprint:: + + >>> # Attempt to connect to https://www.python.org + >>> # with a pin to a bogus certificate: + >>> bad_md5 = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=' + >>> conn = aiohttp.TCPConnector(fingerprint=bad_md5) + >>> session = aiohttp.ClientSession(connector=conn) + >>> exc = None + >>> try: + ... r = yield from session.get('https://www.python.org') + ... except FingerprintMismatch as e: + ... exc = e + >>> exc is not None + True + >>> exc.expected == bad_md5 + True + >>> exc.got # www.python.org cert's actual md5 + b'\xca;I\x9cuv\x8es\x138N$?\x15\xca\xcb' + +Note that this is the fingerprint of the DER-encoded certificate. +If you have the certificate in PEM format, you can convert it to +DER with e.g. ``openssl x509 -in crt.pem -inform PEM -outform DER > crt.der``. + +Tip: to convert from a hexadecimal digest to a binary bytestring, you can use +:attr:`binascii.unhexlify`:: + + >>> md5_hex = 'ca3b499c75768e7313384e243f15cacb' + >>> from binascii import unhexlify + >>> unhexlify(md5_hex) + b'\xca;I\x9cuv\x8es\x138N$?\x15\xca\xcb' Unix domain sockets ------------------- diff --git a/tests/sample.crt.der b/tests/sample.crt.der new file mode 100644 index 00000000000..ce22b75b9e0 Binary files /dev/null and b/tests/sample.crt.der differ diff --git a/tests/test_connector.py b/tests/test_connector.py index fc81d559045..b2c574af5bf 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -12,6 +12,7 @@ import aiohttp from aiohttp import client from aiohttp import test_utils +from aiohttp.errors import FingerprintMismatch from aiohttp.client import ClientResponse, ClientRequest from aiohttp.connector import Connection @@ -452,10 +453,57 @@ def test_cleanup3(self): def test_tcp_connector_ctor(self): conn = aiohttp.TCPConnector(loop=self.loop) self.assertTrue(conn.verify_ssl) + self.assertIs(conn.fingerprint, None) self.assertFalse(conn.resolve) self.assertEqual(conn.family, socket.AF_INET) self.assertEqual(conn.resolved_hosts, {}) + def test_tcp_connector_ctor_fingerprint_valid(self): + valid = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=' + conn = aiohttp.TCPConnector(loop=self.loop, fingerprint=valid) + self.assertEqual(conn.fingerprint, valid) + + def test_tcp_connector_fingerprint_invalid(self): + invalid = b'\x00' + with self.assertRaises(ValueError): + aiohttp.TCPConnector(loop=self.loop, fingerprint=invalid) + + def test_tcp_connector_fingerprint(self): + # The even-index fingerprints below are "expect success" cases + # for ./sample.crt.der, the cert presented by test_utils.run_server. + # The odd-index fingerprints are "expect fail" cases. + testcases = ( + # md5 + b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=', + b'\x00' * 16, + + # sha1 + b's\x93\xfd:\xed\x08\x1do\xa9\xaeq9\x1a\xe3\xc5\x7f\x89\xe7l\xf9', + b'\x00' * 20, + + # sha256 + b'0\x9a\xc9D\x83\xdc\x91\'\x88\x91\x11\xa1d\x97\xfd\xcb~7U\x14D@L' + b'\x11\xab\x99\xa8\xae\xb7\x14\xee\x8b', + b'\x00' * 32, + ) + for i, fingerprint in enumerate(testcases): + expect_fail = i % 2 + conn = aiohttp.TCPConnector(loop=self.loop, verify_ssl=False, + fingerprint=fingerprint) + with test_utils.run_server(self.loop, use_ssl=True) as httpd: + coro = client.request('get', httpd.url('method', 'get'), + connector=conn, loop=self.loop) + if expect_fail: + with self.assertRaises(FingerprintMismatch) as cm: + self.loop.run_until_complete(coro) + exc = cm.exception + self.assertEqual(exc.expected, fingerprint) + # the previous test case should be what we actually got + self.assertEqual(exc.got, testcases[i-1]) + else: + # should not raise + self.loop.run_until_complete(coro) + def test_tcp_connector_clear_resolved_hosts(self): conn = aiohttp.TCPConnector(loop=self.loop) info = object() @@ -897,7 +945,7 @@ def test_https_connect(self, ClientRequestMock): self.assertEqual(proxy_req.method, 'CONNECT') self.assertEqual(proxy_req.path, 'www.python.org:443') tr.pause_reading.assert_called_once_with() - tr.get_extra_info.assert_called_once_with('socket', default=None) + tr.get_extra_info.assert_called_with('socket', default=None) @unittest.mock.patch('aiohttp.connector.ClientRequest') def test_https_connect_runtime_error(self, ClientRequestMock):