Skip to content

Commit

Permalink
support new TCPConnector param fingerprint
Browse files Browse the repository at this point in the history
enables ssl certificate pinning
  • Loading branch information
requiredfield committed May 19, 2015
1 parent fc7cbbf commit 0e1219a
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ CHANGES
0.16.0 (XX-XX-XXXX)
-------------------

- Support new `fingerprint` param of TCPConnector to enable verifying
ssl certificates via md5, sha1, or sha256 digest

- Setup uploaded filename if field value is binary and transfer
encoding is not specified #349

Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Olaf Conradi
Paul Colomiets
Philipp A.
Raúl Cumplido
"Required Field" <[email protected]>
Robert Lu
Sebastian Hanula
Simon Kennedy
Expand Down
48 changes: 44 additions & 4 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings

from collections import defaultdict
from hashlib import md5, sha1, sha256
from itertools import chain
from math import ceil

Expand All @@ -17,6 +18,7 @@
from .errors import ServerDisconnectedError
from .errors import HttpProxyError, ProxyConnectionError
from .errors import ClientOSError, ClientTimeoutError
from .errors import FingerprintMismatch
from .helpers import BasicAuth


Expand All @@ -25,6 +27,12 @@
PY_34 = sys.version_info >= (3, 4)
PY_343 = sys.version_info >= (3, 4, 3)

HASHFUNC_BY_DIGESTLEN = {
16: md5,
20: sha1,
32: sha256,
}


class Connection(object):

Expand Down Expand Up @@ -347,13 +355,17 @@ class TCPConnector(BaseConnector):
"""TCP connector.
:param bool verify_ssl: Set to True to check ssl certifications.
:param bytes fingerprint: Pass the binary md5, sha1, or sha256
digest of the expected certificate in DER format to verify
the cert the server presents matches. See also
https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning
:param bool resolve: Set to True to do DNS lookup for host name.
:param family: socket address family
:param args: see :class:`BaseConnector`
:param kwargs: see :class:`BaseConnector`
"""

def __init__(self, *, verify_ssl=True,
def __init__(self, *, verify_ssl=True, fingerprint=None,
resolve=False, family=socket.AF_INET, ssl_context=None,
**kwargs):
super().__init__(**kwargs)
Expand All @@ -364,6 +376,15 @@ def __init__(self, *, verify_ssl=True,
"verify_ssl=False or specify ssl_context, not both.")

self._verify_ssl = verify_ssl

if fingerprint:
digestlen = len(fingerprint)
hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen)
if not hashfunc:
raise ValueError('fingerprint has invalid length')
self._hashfunc = hashfunc
self._fingerprint = fingerprint

self._ssl_context = ssl_context
self._family = family
self._resolve = resolve
Expand All @@ -374,6 +395,11 @@ def verify_ssl(self):
"""Do check for ssl certifications?"""
return self._verify_ssl

@property
def fingerprint(self):
"""Expected ssl certificate fingerprint."""
return self._fingerprint

@property
def ssl_context(self):
"""SSLContext instance for https requests.
Expand Down Expand Up @@ -464,11 +490,25 @@ def _create_connection(self, req):

for hinfo in hosts:
try:
return (yield from self._loop.create_connection(
self._factory, hinfo['host'], hinfo['port'],
host = hinfo['host']
port = hinfo['port']
conn = yield from self._loop.create_connection(
self._factory, host, port,
ssl=sslcontext, family=hinfo['family'],
proto=hinfo['proto'], flags=hinfo['flags'],
server_hostname=hinfo['hostname'] if sslcontext else None))
server_hostname=hinfo['hostname'] if sslcontext else None)
transport = conn[0]
has_cert = transport.get_extra_info('sslcontext')
if has_cert and self._fingerprint:
sock = transport.get_extra_info('socket')
# gives DER-encoded cert as a sequence of bytes (or None)
cert = sock.getpeercert(binary_form=True)
assert cert
got = self._hashfunc(cert).digest()
expected = self._fingerprint
if got != expected:
raise FingerprintMismatch(expected, got, host, port)
return conn
except OSError as e:
exc = e
else:
Expand Down
16 changes: 16 additions & 0 deletions aiohttp/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError',
'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError',
'ClientRequestError', 'ClientResponseError',
'FingerprintMismatch',

'WSServerHandshakeError', 'WSClientDisconnectedError')

Expand Down Expand Up @@ -170,3 +171,18 @@ class LineLimitExceededParserError(ParserError):
def __init__(self, msg, limit):
super().__init__(msg)
self.limit = limit


class FingerprintMismatch(ClientConnectionError):
"""SSL certificate does not match expected fingerprint."""

def __init__(self, expected, got, host, port):
self.expected = expected
self.got = got
self.host = host
self.port = port

def __repr__(self):
return '<{} expected={} got={} host={} port={}>'.format(
self.__class__.__name__, self.expected, self.got,
self.host, self.port)
44 changes: 36 additions & 8 deletions docs/client.rst
Original file line number Diff line number Diff line change
Expand Up @@ -396,21 +396,49 @@ By default it uses strict checks for HTTPS protocol. Certification
checks can be relaxed by passing ``verify_ssl=False``::

>>> conn = aiohttp.TCPConnector(verify_ssl=False)
>>> r = yield from aiohttp.request(
... 'get', 'https://example.com', connector=conn)
>>> session = aiohttp.ClientSession(connector=conn)
>>> r = yield from session.get('https://example.com')


If you need to setup custom ssl parameters (use own certification
files for example) you can create a :class:`ssl.SSLContext` instance and
pass it into the connector::

>>> sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
>>> sslcontext.verify_mode = ssl.CERT_REQUIRED
>>> sslcontext.load_verify_locations("/etc/ssl/certs/ca-bundle.crt")
>>> sslcontext = ssl.create_default_context(cafile='/path/to/ca-bundle.crt')
>>> conn = aiohttp.TCPConnector(ssl_context=sslcontext)
>>> r = yield from aiohttp.request(
... 'get', 'https://example.com', connector=conn)

>>> session = aiohttp.ClientSession(connector=conn)
>>> r = yield from session.get('https://example.com')

You may also verify certificates via md5, sha1, or sha256 fingerprint::

>>> # Attempt to connect to https://www.python.org
>>> # with a pin to a bogus certificate:
>>> bad_md5 = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06='
>>> conn = aiohttp.TCPConnector(fingerprint=bad_md5)
>>> session = aiohttp.ClientSession(connector=conn)
>>> exc = None
>>> try:
... r = yield from session.get('https://www.python.org')
... except FingerprintMismatch as e:
... exc = e
>>> exc is not None
True
>>> exc.expected == bad_md5
True
>>> exc.got # www.python.org cert's actual md5
b'\xca;I\x9cuv\x8es\x138N$?\x15\xca\xcb'

Note that this is the fingerprint of the DER-encoded certificate.
If you have the certificate in PEM format, you can convert it to
DER with e.g. ``openssl x509 -in crt.pem -inform PEM -outform DER > crt.der``.

Tip: to convert from a hexadecimal digest to a binary bytestring, you can use
:attr:`binascii.unhexlify`::

>>> md5_hex = 'ca3b499c75768e7313384e243f15cacb'
>>> from binascii import unhexlify
>>> unhexlify(md5_hex)
b'\xca;I\x9cuv\x8es\x138N$?\x15\xca\xcb'

Unix domain sockets
-------------------
Expand Down
Binary file added tests/sample.crt.der
Binary file not shown.
52 changes: 50 additions & 2 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import aiohttp
from aiohttp import client
from aiohttp import test_utils
from aiohttp.errors import FingerprintMismatch
from aiohttp.client import ClientResponse, ClientRequest
from aiohttp.connector import Connection

Expand Down Expand Up @@ -452,10 +453,57 @@ def test_cleanup3(self):
def test_tcp_connector_ctor(self):
conn = aiohttp.TCPConnector(loop=self.loop)
self.assertTrue(conn.verify_ssl)
self.assertIs(conn.fingerprint, None)
self.assertFalse(conn.resolve)
self.assertEqual(conn.family, socket.AF_INET)
self.assertEqual(conn.resolved_hosts, {})

def test_tcp_connector_ctor_fingerprint_valid(self):
valid = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06='
conn = aiohttp.TCPConnector(loop=self.loop, fingerprint=valid)
self.assertEqual(conn.fingerprint, valid)

def test_tcp_connector_fingerprint_invalid(self):
invalid = b'\x00'
with self.assertRaises(ValueError):
aiohttp.TCPConnector(loop=self.loop, fingerprint=invalid)

def test_tcp_connector_fingerprint(self):
# The even-index fingerprints below are "expect success" cases
# for ./sample.crt.der, the cert presented by test_utils.run_server.
# The odd-index fingerprints are "expect fail" cases.
testcases = (
# md5
b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=',
b'\x00' * 16,

# sha1
b's\x93\xfd:\xed\x08\x1do\xa9\xaeq9\x1a\xe3\xc5\x7f\x89\xe7l\xf9',
b'\x00' * 20,

# sha256
b'0\x9a\xc9D\x83\xdc\x91\'\x88\x91\x11\xa1d\x97\xfd\xcb~7U\x14D@L'
b'\x11\xab\x99\xa8\xae\xb7\x14\xee\x8b',
b'\x00' * 32,
)
for i, fingerprint in enumerate(testcases):
expect_fail = i % 2
conn = aiohttp.TCPConnector(loop=self.loop, verify_ssl=False,
fingerprint=fingerprint)
with test_utils.run_server(self.loop, use_ssl=True) as httpd:
coro = client.request('get', httpd.url('method', 'get'),
connector=conn, loop=self.loop)
if expect_fail:
with self.assertRaises(FingerprintMismatch) as cm:
self.loop.run_until_complete(coro)
exc = cm.exception
self.assertEqual(exc.expected, fingerprint)
# the previous test case should be what we actually got
self.assertEqual(exc.got, testcases[i-1])
else:
# should not raise
self.loop.run_until_complete(coro)

def test_tcp_connector_clear_resolved_hosts(self):
conn = aiohttp.TCPConnector(loop=self.loop)
info = object()
Expand Down Expand Up @@ -741,7 +789,7 @@ def test_connect(self, ClientRequestMock):
self.assertIs(conn._protocol, proto)

# resolve_host.assert_called_once_with('proxy.example.com', 80)
self.assertEqual(tr.mock_calls, [])
tr.get_extra_info.assert_called_once_with('sslcontext')

ClientRequestMock.assert_called_with(
'GET', 'http://proxy.example.com',
Expand Down Expand Up @@ -897,7 +945,7 @@ def test_https_connect(self, ClientRequestMock):
self.assertEqual(proxy_req.method, 'CONNECT')
self.assertEqual(proxy_req.path, 'www.python.org:443')
tr.pause_reading.assert_called_once_with()
tr.get_extra_info.assert_called_once_with('socket', default=None)
tr.get_extra_info.assert_called_with('socket', default=None)

@unittest.mock.patch('aiohttp.connector.ClientRequest')
def test_https_connect_runtime_error(self, ClientRequestMock):
Expand Down

0 comments on commit 0e1219a

Please sign in to comment.