Skip to content

Commit

Permalink
Support plugging in a new way of parsing URLs
Browse files Browse the repository at this point in the history
This way people can customize it to their liking, as there a lot of
opinions about this, as evidenced by the comments on GH-34.

The default parsing is still the same as before, so new version don't
break existing code. But the user has the option of passing in a
settings object, which has a `urlparse` attribute that can be set to a
custom function that processes the URL and splits into a `sockpath` and
a `reqpath`.
  • Loading branch information
msabramo committed Dec 27, 2021
1 parent 8449bc0 commit 49b9339
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 21 deletions.
24 changes: 19 additions & 5 deletions requests_unixsocket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
import requests
import sys

import requests
from requests.compat import urlparse, unquote

from .adapters import UnixAdapter

DEFAULT_SCHEME = 'http+unix://'

def default_urlparse(url):
parsed_url = urlparse(url)
return UnixAdapter.Settings.ParseResult(
sockpath=unquote(parsed_url.netloc),
reqpath=parsed_url.path + '?' + parsed_url.query,
)


default_scheme = 'http+unix://'
default_settings = UnixAdapter.Settings(urlparse=default_urlparse)


class Session(requests.Session):
def __init__(self, url_scheme=DEFAULT_SCHEME, *args, **kwargs):
def __init__(self, url_scheme=default_scheme, settings=None,
*args, **kwargs):
super(Session, self).__init__(*args, **kwargs)
self.mount(url_scheme, UnixAdapter())
self.settings = settings or default_settings
self.mount(url_scheme, UnixAdapter(settings=self.settings))


class monkeypatch(object):
def __init__(self, url_scheme=DEFAULT_SCHEME):
def __init__(self, url_scheme=default_scheme):
self.session = Session()
requests = self._get_global_requests_module()

Expand Down
46 changes: 30 additions & 16 deletions requests_unixsocket/adapters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import socket
from collections import namedtuple

from requests.adapters import HTTPAdapter
from requests.compat import urlparse, unquote
from requests.compat import urlparse

try:
import http.client as httplib
Expand All @@ -18,16 +19,12 @@
# https://github.com/docker/docker-py/blob/master/docker/transport/unixconn.py
class UnixHTTPConnection(httplib.HTTPConnection, object):

def __init__(self, unix_socket_url, timeout=60):
"""Create an HTTP connection to a unix domain socket
:param unix_socket_url: A URL with a scheme of 'http+unix' and the
netloc is a percent-encoded path to a unix domain socket. E.g.:
'http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid'
"""
def __init__(self, url, timeout=60, settings=None):
"""Create an HTTP connection to a unix domain socket"""
super(UnixHTTPConnection, self).__init__('localhost', timeout=timeout)
self.unix_socket_url = unix_socket_url
self.url = url
self.timeout = timeout
self.settings = settings
self.sock = None

def __del__(self): # base class does not have d'tor
Expand All @@ -37,27 +34,40 @@ def __del__(self): # base class does not have d'tor
def connect(self):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
socket_path = unquote(urlparse(self.unix_socket_url).netloc)
sock.connect(socket_path)
sockpath = self.settings.urlparse(self.url).sockpath
sock.connect(sockpath)
self.sock = sock


class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):

def __init__(self, socket_path, timeout=60):
def __init__(self, socket_path, timeout=60, settings=None):
super(UnixHTTPConnectionPool, self).__init__(
'localhost', timeout=timeout)
self.socket_path = socket_path
self.timeout = timeout
self.settings = settings

def _new_conn(self):
return UnixHTTPConnection(self.socket_path, self.timeout)
return UnixHTTPConnection(
url=self.socket_path,
timeout=self.timeout,
settings=self.settings,
)


class UnixAdapter(HTTPAdapter):
class Settings(object):
class ParseResult(namedtuple('ParseResult', 'sockpath reqpath')):
pass

def __init__(self, urlparse=None):
self.urlparse = urlparse

def __init__(self, timeout=60, pool_connections=25, *args, **kwargs):
def __init__(self, timeout=60, pool_connections=25, settings=None,
*args, **kwargs):
super(UnixAdapter, self).__init__(*args, **kwargs)
self.settings = settings
self.timeout = timeout
self.pools = urllib3._collections.RecentlyUsedContainer(
pool_connections, dispose_func=lambda p: p.close()
Expand All @@ -76,13 +86,17 @@ def get_connection(self, url, proxies=None):
if pool:
return pool

pool = UnixHTTPConnectionPool(url, self.timeout)
pool = UnixHTTPConnectionPool(
socket_path=url,
settings=self.settings,
timeout=self.timeout,
)
self.pools[url] = pool

return pool

def request_url(self, request, proxies):
return request.path_url
return self.settings.urlparse(request.url).reqpath

def close(self):
self.pools.clear()
87 changes: 87 additions & 0 deletions requests_unixsocket/tests/test_requests_unixsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
"""Tests for requests_unixsocket"""

import logging
import os
import stat

import pytest
import requests
from requests.compat import urlparse

import requests_unixsocket
from requests_unixsocket.testutils import UnixSocketServerThread
Expand All @@ -15,6 +18,35 @@
logger = logging.getLogger(__name__)


def is_socket(path):
try:
mode = os.stat(path).st_mode
return stat.S_ISSOCK(mode)
except OSError:
return False


def get_sock_prefix(path):
"""Keep going up directory tree until we find a socket"""

sockpath = path
reqpath_parts = []

while not is_socket(sockpath):
sockpath, tail = os.path.split(sockpath)
reqpath_parts.append(tail)

return requests_unixsocket.UnixAdapter.Settings.ParseResult(
sockpath=sockpath,
reqpath='/' + os.path.join(*reversed(reqpath_parts)),
)


alt_settings_1 = requests_unixsocket.UnixAdapter.Settings(
urlparse=lambda url: get_sock_prefix(urlparse(url).path),
)


def test_unix_domain_adapter_ok():
with UnixSocketServerThread() as usock_thread:
session = requests_unixsocket.Session('http+unix://')
Expand All @@ -41,6 +73,34 @@ def test_unix_domain_adapter_ok():
assert r.text == 'Hello world!'


def test_unix_domain_adapter_alt_settings_1_ok():
with UnixSocketServerThread() as usock_thread:
session = requests_unixsocket.Session(
url_scheme='http+unix://',
settings=alt_settings_1,
)
url = 'http+unix://localhost%s/path/to/page' % usock_thread.usock

for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
'options']:
logger.debug('Calling session.%s(%r) ...', method, url)
r = getattr(session, method)(url)
logger.debug(
'Received response: %r with text: %r and headers: %r',
r, r.text, r.headers)
assert r.status_code == 200
assert r.headers['server'] == 'waitress'
assert r.headers['X-Transport'] == 'unix domain socket'
assert r.headers['X-Requested-Path'] == '/path/to/page'
assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
assert r.url.lower() == url.lower()
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'


def test_unix_domain_adapter_url_with_query_params():
with UnixSocketServerThread() as usock_thread:
session = requests_unixsocket.Session('http+unix://')
Expand Down Expand Up @@ -69,6 +129,33 @@ def test_unix_domain_adapter_url_with_query_params():
assert r.text == 'Hello world!'


def test_unix_domain_adapter_url_with_fragment():
with UnixSocketServerThread() as usock_thread:
session = requests_unixsocket.Session('http+unix://')
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
url = ('http+unix://%s'
'/containers/nginx/logs#some-fragment' % urlencoded_usock)

for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
'options']:
logger.debug('Calling session.%s(%r) ...', method, url)
r = getattr(session, method)(url)
logger.debug(
'Received response: %r with text: %r and headers: %r',
r, r.text, r.headers)
assert r.status_code == 200
assert r.headers['server'] == 'waitress'
assert r.headers['X-Transport'] == 'unix domain socket'
assert r.headers['X-Requested-Path'] == '/containers/nginx/logs'
assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
assert r.url.lower() == url.lower()
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'


def test_unix_domain_adapter_connection_error():
session = requests_unixsocket.Session('http+unix://')

Expand Down

0 comments on commit 49b9339

Please sign in to comment.