Skip to content

Commit

Permalink
Shorten too long test UNIX socket path
Browse files Browse the repository at this point in the history
Fixes #3572

Different OS kernels have different fs path length limitations
for it. For Linux, it's 108, for HP-UX it's 92 (or higher) depending
on its version. For most of the BSDs (Open, Free, macOS) it's
mostly 104 but sometimes it can be down to 100.

Ref: https://unix.stackexchange.com/a/367012/27133

This change implements a flexible socket path generator fixture
which guarantees that it's fit into the memory space allocated
by the kernel of the current OS runtime.

(cherry picked from commit 8e9e39b)
  • Loading branch information
webknjaz committed Sep 7, 2022
1 parent 019e1ab commit c6a3779
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 52 deletions.
111 changes: 95 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import hashlib
import pathlib
import shutil
import os
import socket
import ssl
import sys
import tempfile
import uuid
from hashlib import md5, sha256
from pathlib import Path
from tempfile import TemporaryDirectory
from uuid import uuid4

import pytest

Expand All @@ -24,15 +25,14 @@
pytest_plugins = ["aiohttp.pytest_plugin", "pytester"]


@pytest.fixture
def shorttmpdir():
# Provides a temporary directory with a shorter file system path than the
# tmpdir fixture.
tmpdir = pathlib.Path(tempfile.mkdtemp())
yield tmpdir
# str(tmpdir) is required, Python 3.5 doesn't have __fspath__
# concept
shutil.rmtree(str(tmpdir), ignore_errors=True)
IS_HPUX = sys.platform.startswith("hp-ux")
"""Specifies whether the current runtime is HP-UX."""
IS_LINUX = sys.platform.startswith("linux")
"""Specifies whether the current runtime is HP-UX."""
IS_UNIX = hasattr(socket, "AF_UNIX")
"""Specifies whether the current runtime is *NIX."""

needs_unix = pytest.mark.skipif(not IS_UNIX, reason="requires UNIX sockets")


@pytest.fixture
Expand Down Expand Up @@ -85,12 +85,91 @@ def tls_certificate_pem_bytes(tls_certificate):
@pytest.fixture
def tls_certificate_fingerprint_sha256(tls_certificate_pem_bytes):
tls_cert_der = ssl.PEM_cert_to_DER_cert(tls_certificate_pem_bytes.decode())
return hashlib.sha256(tls_cert_der).digest()
return sha256(tls_cert_der).digest()


@pytest.fixture
def unix_sockname(tmp_path, tmp_path_factory):
"""Generate an fs path to the UNIX domain socket for testing.
N.B. Different OS kernels have different fs path length limitations
for it. For Linux, it's 108, for HP-UX it's 92 (or higher) depending
on its version. For for most of the BSDs (Open, Free, macOS) it's
mostly 104 but sometimes it can be down to 100.
Ref: https://github.com/aio-libs/aiohttp/issues/3572
"""
if not IS_UNIX:
pytest.skip("requires UNIX sockets")

max_sock_len = 92 if IS_HPUX else 108 if IS_LINUX else 100
"""Amount of bytes allocated for the UNIX socket path by OS kernel.
Ref: https://unix.stackexchange.com/a/367012/27133
"""

sock_file_name = "unix.sock"
unique_prefix = f"{uuid4()!s}-"
unique_prefix_len = len(unique_prefix.encode())

root_tmp_dir = Path("/tmp").resolve()
os_tmp_dir = Path(os.getenv("TMPDIR", "/tmp")).resolve()
original_base_tmp_path = Path(
str(tmp_path_factory.getbasetemp()),
).resolve()

original_base_tmp_path_hash = md5(
str(original_base_tmp_path).encode(),
).hexdigest()

def make_tmp_dir(base_tmp_dir):
return TemporaryDirectory(
dir=str(base_tmp_dir),
prefix="pt-",
suffix=f"-{original_base_tmp_path_hash!s}",
)

def assert_sock_fits(sock_path):
sock_path_len = len(sock_path.encode())
# exit-check to verify that it's correct and simplify debugging
# in the future
assert sock_path_len <= max_sock_len, (
"Suggested UNIX socket ({sock_path}) is {sock_path_len} bytes "
"long but the current kernel only has {max_sock_len} bytes "
"allocated to hold it so it must be shorter. "
"See https://github.com/aio-libs/aiohttp/issues/3572 "
"for more info."
).format_map(locals())

paths = original_base_tmp_path, os_tmp_dir, root_tmp_dir
unique_paths = [p for n, p in enumerate(paths) if p not in paths[:n]]
paths_num = len(unique_paths)

for num, tmp_dir_path in enumerate(paths, 1):
with make_tmp_dir(tmp_dir_path) as tmpd:
tmpd = Path(tmpd).resolve()
sock_path = str(tmpd / sock_file_name)
sock_path_len = len(sock_path.encode())

if num >= paths_num:
# exit-check to verify that it's correct and simplify
# debugging in the future
assert_sock_fits(sock_path)

if sock_path_len <= max_sock_len:
if max_sock_len - sock_path_len >= unique_prefix_len:
# If we're lucky to have extra space in the path,
# let's also make it more unique
sock_path = str(tmpd / "".join((unique_prefix, sock_file_name)))
# Double-checking it:
assert_sock_fits(sock_path)
yield sock_path
return


@pytest.fixture
def pipe_name():
name = rf"\\.\pipe\{uuid.uuid4().hex}"
name = rf"\\.\pipe\{uuid4().hex}"
return name


Expand Down
12 changes: 3 additions & 9 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from unittest import mock

import pytest
from conftest import needs_unix
from yarl import URL

import aiohttp
Expand Down Expand Up @@ -43,12 +44,6 @@ def ssl_key():
return ConnectionKey("localhost", 80, True, None, None, None, None)


@pytest.fixture
def unix_sockname(shorttmpdir):
sock_path = shorttmpdir / "socket.sock"
return str(sock_path)


@pytest.fixture
def unix_server(loop, unix_sockname):
runners = []
Expand Down Expand Up @@ -1918,7 +1913,7 @@ async def handler(request):
assert r.status == 200


@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires unix socket")
@needs_unix
async def test_unix_connector_not_found(loop) -> None:
connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex, loop=loop)

Expand All @@ -1927,7 +1922,7 @@ async def test_unix_connector_not_found(loop) -> None:
await connector.connect(req, None, ClientTimeout())


@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires unix socket")
@needs_unix
async def test_unix_connector_permission(loop) -> None:
loop.create_unix_connection = make_mocked_coro(raise_exception=PermissionError())
connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex, loop=loop)
Expand Down Expand Up @@ -2086,7 +2081,6 @@ async def handler(request):
await conn.close()


@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires UNIX sockets")
async def test_unix_connector(unix_server, unix_sockname) -> None:
async def handler(request):
return web.Response()
Expand Down
37 changes: 17 additions & 20 deletions tests/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
from uuid import uuid4

import pytest
from conftest import IS_UNIX, needs_unix

from aiohttp import web
from aiohttp.helpers import PY_37
from aiohttp.test_utils import make_mocked_coro
from aiohttp.web_runner import BaseRunner

# Test for features of OS' socket support
_has_unix_domain_socks = hasattr(socket, "AF_UNIX")
if _has_unix_domain_socks:
if IS_UNIX:
_abstract_path_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
_abstract_path_sock.bind(b"\x00" + uuid4().hex.encode("ascii")) # type: ignore
Expand All @@ -37,10 +37,7 @@
skip_if_no_abstract_paths = pytest.mark.skipif(
_abstract_path_failed, reason="Linux-style abstract paths are not supported."
)
skip_if_no_unix_socks = pytest.mark.skipif(
not _has_unix_domain_socks, reason="Unix domain sockets are not supported"
)
del _has_unix_domain_socks, _abstract_path_failed
del IS_UNIX, _abstract_path_failed

HAS_IPV6 = socket.has_ipv6
if HAS_IPV6:
Expand Down Expand Up @@ -509,38 +506,38 @@ def test_run_app_custom_backlog_unix(patched_loop) -> None:
)


@skip_if_no_unix_socks
def test_run_app_http_unix_socket(patched_loop, shorttmpdir) -> None:
def test_run_app_http_unix_socket(patched_loop, unix_sockname) -> None:
app = web.Application()

sock_path = str(shorttmpdir / "socket.sock")
printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, path=sock_path, print=printer, loop=patched_loop)
web.run_app(app, path=unix_sockname, print=printer, loop=patched_loop)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, sock_path, ssl=None, backlog=128
mock.ANY, unix_sockname, ssl=None, backlog=128
)
assert f"http://unix:{sock_path}:" in printer.call_args[0][0]
assert f"http://unix:{unix_sockname}:" in printer.call_args[0][0]


@skip_if_no_unix_socks
def test_run_app_https_unix_socket(patched_loop, shorttmpdir) -> None:
def test_run_app_https_unix_socket(patched_loop, unix_sockname) -> None:
app = web.Application()

sock_path = str(shorttmpdir / "socket.sock")
ssl_context = ssl.create_default_context()
printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(
app, path=sock_path, ssl_context=ssl_context, print=printer, loop=patched_loop
app,
path=unix_sockname,
ssl_context=ssl_context,
print=printer,
loop=patched_loop,
)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, sock_path, ssl=ssl_context, backlog=128
mock.ANY, unix_sockname, ssl=ssl_context, backlog=128
)
assert f"https://unix:{sock_path}:" in printer.call_args[0][0]
assert f"https://unix:{unix_sockname}:" in printer.call_args[0][0]


@skip_if_no_unix_socks
@needs_unix
@skip_if_no_abstract_paths
def test_run_app_abstract_linux_socket(patched_loop) -> None:
sock_path = b"\x00" + uuid4().hex.encode("ascii")
Expand Down Expand Up @@ -592,7 +589,7 @@ def test_run_app_preexisting_inet6_socket(patched_loop) -> None:
assert f"http://[::]:{port}" in printer.call_args[0][0]


@skip_if_no_unix_socks
@needs_unix
def test_run_app_preexisting_unix_socket(patched_loop, mocker) -> None:
app = web.Application()

Expand Down
10 changes: 3 additions & 7 deletions tests/test_web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,17 @@ def test_non_app() -> None:
web.AppRunner(object())


@pytest.mark.skipif(
platform.system() == "Windows", reason="Unix socket support is required"
)
async def test_addresses(make_runner, shorttmpdir) -> None:
async def test_addresses(make_runner, unix_sockname) -> None:
_sock = get_unused_port_socket("127.0.0.1")
runner = make_runner()
await runner.setup()
tcp = web.SockSite(runner, _sock)
await tcp.start()
path = str(shorttmpdir / "tmp.sock")
unix = web.UnixSite(runner, path)
unix = web.UnixSite(runner, unix_sockname)
await unix.start()
actual_addrs = runner.addresses
expected_host, expected_post = _sock.getsockname()[:2]
assert actual_addrs == [(expected_host, expected_post), path]
assert actual_addrs == [(expected_host, expected_post), unix_sockname]


@pytest.mark.skipif(
Expand Down

0 comments on commit c6a3779

Please sign in to comment.