Skip to content

Commit

Permalink
Shorten too long test UNIX socket path
Browse files Browse the repository at this point in the history
Fixes aio-libs#3572

Different OS kernels have different fs path length limitations
for it. For Linux, it's 108, for HP-UX it's 92 (or higher) depending
on its version. For most of the BSDs (Open, Free, macOS) it's
mostly 104 but sometimes it can be down to 100.

Ref: https://unix.stackexchange.com/a/367012/27133

This change implements a flexible socket path generator fixture
which guarantees that it's fit into the memory space allocated
by the kernel of the current OS runtime.
  • Loading branch information
webknjaz authored Jun 11, 2019
1 parent b80fec6 commit 8e9e39b
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 19 deletions.
100 changes: 98 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
import hashlib
import os
import socket
import ssl
import sys
from hashlib import md5, sha256
from pathlib import Path
from tempfile import TemporaryDirectory
from uuid import uuid4

import pytest
import trustme

pytest_plugins = ['aiohttp.pytest_plugin', 'pytester']

IS_HPUX = sys.platform.startswith('hp-ux')
"""Specifies whether the current runtime is HP-UX."""
IS_LINUX = sys.platform.startswith('linux')
"""Specifies whether the current runtime is HP-UX."""
IS_UNIX = hasattr(socket, 'AF_UNIX')
"""Specifies whether the current runtime is *NIX."""

needs_unix = pytest.mark.skipif(not IS_UNIX, reason='requires UNIX sockets')


@pytest.fixture
def tls_certificate_authority():
Expand Down Expand Up @@ -55,4 +70,85 @@ def tls_certificate_pem_bytes(tls_certificate):
@pytest.fixture
def tls_certificate_fingerprint_sha256(tls_certificate_pem_bytes):
tls_cert_der = ssl.PEM_cert_to_DER_cert(tls_certificate_pem_bytes.decode())
return hashlib.sha256(tls_cert_der).digest()
return sha256(tls_cert_der).digest()


@pytest.fixture
def unix_sockname(tmp_path, tmp_path_factory):
"""Generate an fs path to the UNIX domain socket for testing.
N.B. Different OS kernels have different fs path length limitations
for it. For Linux, it's 108, for HP-UX it's 92 (or higher) depending
on its version. For for most of the BSDs (Open, Free, macOS) it's
mostly 104 but sometimes it can be down to 100.
Ref: https://github.com/aio-libs/aiohttp/issues/3572
"""
if not IS_UNIX:
pytest.skip('requires UNIX sockets')

max_sock_len = 92 if IS_HPUX else 108 if IS_LINUX else 100
"""Amount of bytes allocated for the UNIX socket path by OS kernel.
Ref: https://unix.stackexchange.com/a/367012/27133
"""

sock_file_name = 'unix.sock'
unique_prefix = '{!s}-'.format(uuid4())
unique_prefix_len = len(unique_prefix.encode())

root_tmp_dir = Path('/tmp').resolve()
os_tmp_dir = Path(os.getenv('TMPDIR', '/tmp')).resolve()
original_base_tmp_path = Path(
str(tmp_path_factory.getbasetemp()),
).resolve()

original_base_tmp_path_hash = md5(
str(original_base_tmp_path).encode(),
).hexdigest()

def make_tmp_dir(base_tmp_dir):
return TemporaryDirectory(
dir=str(base_tmp_dir),
prefix='pt-',
suffix='-{!s}'.format(original_base_tmp_path_hash),
)

def assert_sock_fits(sock_path):
sock_path_len = len(sock_path.encode())
# exit-check to verify that it's correct and simplify debugging
# in the future
assert sock_path_len <= max_sock_len, (
'Suggested UNIX socket ({sock_path}) is {sock_path_len} bytes '
'long but the current kernel only has {max_sock_len} bytes '
'allocated to hold it so it must be shorter. '
'See https://github.com/aio-libs/aiohttp/issues/3572 '
'for more info.'
).format_map(locals())

paths = original_base_tmp_path, os_tmp_dir, root_tmp_dir
unique_paths = [p for n, p in enumerate(paths) if p not in paths[:n]]
paths_num = len(unique_paths)

for num, tmp_dir_path in enumerate(paths, 1):
with make_tmp_dir(tmp_dir_path) as tmpd:
tmpd = Path(tmpd).resolve()
sock_path = str(tmpd / sock_file_name)
sock_path_len = len(sock_path.encode())

if num >= paths_num:
# exit-check to verify that it's correct and simplify
# debugging in the future
assert_sock_fits(sock_path)

if sock_path_len <= max_sock_len:
if max_sock_len - sock_path_len >= unique_prefix_len:
# If we're lucky to have extra space in the path,
# let's also make it more unique
sock_path = str(
tmpd / ''.join((unique_prefix, sock_file_name))
)
# Double-checking it:
assert_sock_fits(sock_path)
yield sock_path
return
14 changes: 3 additions & 11 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from aiohttp.helpers import PY_37
from aiohttp.test_utils import make_mocked_coro, unused_port
from aiohttp.tracing import Trace
from conftest import needs_unix


@pytest.fixture()
Expand All @@ -42,11 +43,6 @@ def ssl_key():
return ConnectionKey('localhost', 80, True, None, None, None, None)


@pytest.fixture
def unix_sockname(tmp_path):
return str(tmp_path / 'socket.sock')


@pytest.fixture
def unix_server(loop, unix_sockname):
runners = []
Expand Down Expand Up @@ -1956,8 +1952,7 @@ async def handler(request):
assert r.status == 200


@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
reason="requires unix socket")
@needs_unix
async def test_unix_connector_not_found(loop) -> None:
connector = aiohttp.UnixConnector('/' + uuid.uuid4().hex, loop=loop)

Expand All @@ -1968,8 +1963,7 @@ async def test_unix_connector_not_found(loop) -> None:
await connector.connect(req, None, ClientTimeout())


@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
reason="requires unix socket")
@needs_unix
async def test_unix_connector_permission(loop) -> None:
loop.create_unix_connection = make_mocked_coro(
raise_exception=PermissionError())
Expand Down Expand Up @@ -2094,8 +2088,6 @@ async def handler(request):
conn.close()


@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
reason='requires UNIX sockets')
async def test_unix_connector(unix_server, unix_sockname) -> None:
async def handler(request):
return web.Response()
Expand Down
9 changes: 3 additions & 6 deletions tests/test_web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,14 @@ def test_non_app() -> None:
web.AppRunner(object())


@pytest.mark.skipif(platform.system() == "Windows",
reason="Unix socket support is required")
async def test_addresses(make_runner, tmpdir) -> None:
async def test_addresses(make_runner, unix_sockname) -> None:
_sock = get_unused_port_socket('127.0.0.1')
runner = make_runner()
await runner.setup()
tcp = web.SockSite(runner, _sock)
await tcp.start()
path = str(tmpdir / 'tmp.sock')
unix = web.UnixSite(runner, path)
unix = web.UnixSite(runner, unix_sockname)
await unix.start()
actual_addrs = runner.addresses
expected_host, expected_post = _sock.getsockname()[:2]
assert actual_addrs == [(expected_host, expected_post), path]
assert actual_addrs == [(expected_host, expected_post), unix_sockname]

0 comments on commit 8e9e39b

Please sign in to comment.