From 4e66d78e1cbb48c19d6445f60e6ff7183b03375f Mon Sep 17 00:00:00 2001 From: requiredfield Date: Wed, 13 May 2015 19:45:49 -0400 Subject: [PATCH] first pass at verify_fingerprint --- aiohttp/connector.py | 38 ++++++++++++++++++++++++++++++++++---- aiohttp/errors.py | 4 ++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 392309b2166..94667358d91 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -8,7 +8,9 @@ import traceback import warnings +from binascii import unhexlify from collections import defaultdict +from hashlib import md5, sha1, sha256 from itertools import chain from math import ceil @@ -16,7 +18,7 @@ from .client import ClientRequest from .errors import ServerDisconnectedError from .errors import HttpProxyError, ProxyConnectionError -from .errors import ClientOSError, ClientTimeoutError +from .errors import ClientOSError, ClientTimeoutError, FingerprintMismatch from .helpers import BasicAuth @@ -25,6 +27,12 @@ PY_34 = sys.version_info >= (3, 4) PY_343 = sys.version_info >= (3, 4, 3) +HASHFUNC_BY_DIGESTLEN = { + 16: md5, + 20: sha1, + 32: sha256, +} + class Connection(object): @@ -347,13 +355,15 @@ class TCPConnector(BaseConnector): """TCP connector. :param bool verify_ssl: Set to True to check ssl certifications. + :param str verify_fingerprint: Set to a string of hex digits to + verify the ssl cert fingerprint matches. :param bool resolve: Set to True to do DNS lookup for host name. :param family: socket address family :param args: see :class:`BaseConnector` :param kwargs: see :class:`BaseConnector` """ - def __init__(self, *, verify_ssl=True, + def __init__(self, *, verify_ssl=True, verify_fingerprint=None, resolve=False, family=socket.AF_INET, ssl_context=None, **kwargs): super().__init__(**kwargs) @@ -364,6 +374,14 @@ def __init__(self, *, verify_ssl=True, "verify_ssl=False or specify ssl_context, not both.") self._verify_ssl = verify_ssl + if verify_fingerprint: + verify_fingerprint = verify_fingerprint.replace(':', '').lower() + digestlen, odd = divmod(len(verify_fingerprint), 2) + if odd or digestlen not in HASHFUNC_BY_DIGESTLEN: + raise ValueError('Fingerprint is of invalid length.') + self._hashfunc = HASHFUNC_BY_DIGESTLEN[digestlen] + self._fingerprint_bytes = unhexlify(verify_fingerprint) + self._verify_fingerprint = verify_fingerprint self._ssl_context = ssl_context self._family = family self._resolve = resolve @@ -374,6 +392,11 @@ def verify_ssl(self): """Do check for ssl certifications?""" return self._verify_ssl + @property + def verify_fingerprint(self): + """Verify ssl cert fingerprint matches?""" + return self._verify_fingerprint + @property def ssl_context(self): """SSLContext instance for https requests. @@ -464,11 +487,18 @@ def _create_connection(self, req): for hinfo in hosts: try: - return (yield from self._loop.create_connection( + conn = yield from self._loop.create_connection( self._factory, hinfo['host'], hinfo['port'], ssl=sslcontext, family=hinfo['family'], proto=hinfo['proto'], flags=hinfo['flags'], - server_hostname=hinfo['hostname'] if sslcontext else None)) + server_hostname=hinfo['hostname'] if sslcontext else None) + if self._verify_fingerprint: + sock = conn[0]._sock + cert = sock.getpeercert(True) + digest = self._hashfunc(cert).digest() + if digest != self._fingerprint_bytes: + raise FingerprintMismatch + return conn except OSError as e: exc = e else: diff --git a/aiohttp/errors.py b/aiohttp/errors.py index 5c148638c1f..b2e62ca9592 100644 --- a/aiohttp/errors.py +++ b/aiohttp/errors.py @@ -170,3 +170,7 @@ class LineLimitExceededParserError(ParserError): def __init__(self, msg, limit): super().__init__(msg) self.limit = limit + + +class FingerprintMismatch(Exception): + """SSL certificate does not match expected fingerprint."""