diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 65261664..acfe87e4 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -380,6 +380,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, passfile=passfile) addrs = [] + have_tcp_addrs = False for h, p in zip(host, port): if h.startswith('/'): # UNIX socket name @@ -389,6 +390,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, else: # TCP host/port addrs.append((h, p)) + have_tcp_addrs = True if not addrs: raise ValueError( @@ -397,6 +399,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if ssl is None: ssl = os.getenv('PGSSLMODE') + if ssl is None and have_tcp_addrs: + ssl = 'prefer' + # ssl_is_advisory is only allowed to come from the sslmode parameter. ssl_is_advisory = None if isinstance(ssl, str): @@ -435,14 +440,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if sslmode <= SSLMODES['require']: ssl.verify_mode = ssl_module.CERT_NONE ssl_is_advisory = sslmode <= SSLMODES['prefer'] - - if ssl: - for addr in addrs: - if isinstance(addr, str): - # UNIX socket - raise exceptions.InterfaceError( - '`ssl` parameter can only be enabled for TCP addresses, ' - 'got a UNIX socket path: {!r}'.format(addr)) + elif ssl is True: + ssl = ssl_module.create_default_context() if server_settings is not None and ( not isinstance(server_settings, dict) or @@ -542,9 +541,6 @@ def connection_lost(self, exc): async def _create_ssl_connection(protocol_factory, host, port, *, loop, ssl_context, ssl_is_advisory=False): - if ssl_context is True: - ssl_context = ssl_module.create_default_context() - tr, pr = await loop.create_connection( lambda: TLSUpgradeProto(loop, host, port, ssl_context, ssl_is_advisory), @@ -625,7 +621,6 @@ async def _connect_addr( if isinstance(addr, str): # UNIX socket - assert not params.ssl connector = loop.create_unix_connection(proto_factory, addr) elif params.ssl: connector = _create_ssl_connection( diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 5942d920..563234dd 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1869,7 +1869,28 @@ async def connect(dsn=None, *, Pass ``True`` or an `ssl.SSLContext `_ instance to require an SSL connection. If ``True``, a default SSL context returned by `ssl.create_default_context() `_ - will be used. + will be used. The value can also be one of the following strings: + + - ``'disable'`` - SSL is disabled (equivalent to ``False``) + - ``'prefer'`` - try SSL first, fallback to non-SSL connection + if SSL connection fails + - ``'allow'`` - currently equivalent to ``'prefer'`` + - ``'require'`` - only try an SSL connection. Certificate + verifiction errors are ignored + - ``'verify-ca'`` - only try an SSL connection, and verify + that the server certificate is issued by a trusted certificate + authority (CA) + - ``'verify-full'`` - only try an SSL connection, verify + that the server certificate is issued by a trusted CA and + that the requested server host name matches that in the + certificate. + + The default is ``'prefer'``: try an SSL connection and fallback to + non-SSL connection if that fails. + + .. note:: + + *ssl* is ignored for Unix domain socket communication. :param dict server_settings: An optional dict of server runtime parameters. Refer to @@ -1926,6 +1947,9 @@ async def connect(dsn=None, *, .. versionchanged:: 0.22.0 Added the *record_class* parameter. + .. versionchanged:: 0.22.0 + The *ssl* argument now defaults to ``'prefer'``. + .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext .. _create_default_context: https://docs.python.org/3/library/ssl.html#ssl.create_default_context diff --git a/tests/test_connect.py b/tests/test_connect.py index 116b8ad9..af927426 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -318,7 +318,9 @@ class TestConnectParams(tb.TestCase): 'result': ([('host', 123)], { 'user': 'user', 'password': 'passw', - 'database': 'testdb'}) + 'database': 'testdb', + 'ssl': True, + 'ssl_is_advisory': True}) }, { @@ -384,7 +386,7 @@ class TestConnectParams(tb.TestCase): 'user': 'user3', 'password': '123123', 'database': 'abcdef', - 'ssl': ssl.SSLContext, + 'ssl': True, 'ssl_is_advisory': True}) }, @@ -461,7 +463,7 @@ class TestConnectParams(tb.TestCase): 'user': 'me', 'password': 'ask', 'database': 'db', - 'ssl': ssl.SSLContext, + 'ssl': True, 'ssl_is_advisory': False}) }, @@ -545,6 +547,7 @@ class TestConnectParams(tb.TestCase): { 'user': 'user', 'database': 'user', + 'ssl': None } ) }, @@ -574,7 +577,9 @@ class TestConnectParams(tb.TestCase): ('localhost', 5433) ], { 'user': 'spam', - 'database': 'db' + 'database': 'db', + 'ssl': True, + 'ssl_is_advisory': True } ) }, @@ -617,7 +622,7 @@ def run_testcase(self, testcase): password = testcase.get('password') passfile = testcase.get('passfile') database = testcase.get('database') - ssl = testcase.get('ssl') + sslmode = testcase.get('ssl') server_settings = testcase.get('server_settings') expected = testcase.get('result') @@ -640,21 +645,26 @@ def run_testcase(self, testcase): addrs, params = connect_utils._parse_connect_dsn_and_args( dsn=dsn, host=host, port=port, user=user, password=password, - passfile=passfile, database=database, ssl=ssl, + passfile=passfile, database=database, ssl=sslmode, connect_timeout=None, server_settings=server_settings) - params = {k: v for k, v in params._asdict().items() - if v is not None} + params = { + k: v for k, v in params._asdict().items() + if v is not None or (expected is not None and k in expected[1]) + } + + if isinstance(params.get('ssl'), ssl.SSLContext): + params['ssl'] = True result = (addrs, params) if expected is not None: - for k, v in expected[1].items(): - # If `expected` contains a type, allow that to "match" any - # instance of that type tyat `result` may contain. We need - # this because different SSLContexts don't compare equal. - if isinstance(v, type) and isinstance(result[1].get(k), v): - result[1][k] = v + if 'ssl' not in expected[1]: + # Avoid the hassle of specifying the default SSL mode + # unless explicitly tested for. + params.pop('ssl', None) + params.pop('ssl_is_advisory', None) + self.assertEqual(expected, result, 'Testcase: {}'.format(testcase)) def test_test_connect_params_environ(self): @@ -1063,16 +1073,6 @@ async def verify_fails(sslmode): await verify_fails('verify-ca') await verify_fails('verify-full') - async def test_connection_ssl_unix(self): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ssl_context.load_verify_locations(SSL_CA_CERT_FILE) - - with self.assertRaisesRegex(asyncpg.InterfaceError, - 'can only be enabled for TCP addresses'): - await self.connect( - host='/tmp', - ssl=ssl_context) - async def test_connection_implicit_host(self): conn_spec = self.get_connection_spec() con = await asyncpg.connect(