diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index d743774596b4..ea72c3a18691 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -71,6 +71,41 @@ def _test_factory(fifo, start): assert len(set(id(obj) for obj in objects)) == len(threads) +class TestUnicodeRedirectTween: + def test_basic_redirect(self): + response = pretend.stub(location="/a/path/to/nowhere") + handler = pretend.call_recorder(lambda request: response) + registry = pretend.stub() + tween = warehouse.http.unicode_redirect_tween_factory( + handler, registry) + request = pretend.stub( + path="/A/pAtH/tO/nOwHeRe/", + ) + assert tween(request) == response + + def test_unicode_basic_redirect(self): + response = pretend.stub(location="/pypi/\u2603/json/") + handler = pretend.call_recorder(lambda request: response) + registry = pretend.stub() + tween = warehouse.http.unicode_redirect_tween_factory( + handler, registry) + request = pretend.stub( + path="/pypi/snowman/json/", + ) + assert tween(request).location == "/pypi/%E2%98%83/json/" + + def test_not_redirect(self): + response = pretend.stub(location=None) + handler = pretend.call_recorder(lambda request: response) + registry = pretend.stub() + tween = warehouse.http.unicode_redirect_tween_factory( + handler, registry) + request = pretend.stub( + path="/wu/tang/", + ) + assert tween(request) == response + + def test_includeme(): config = pretend.stub( registry=pretend.stub( @@ -79,10 +114,17 @@ def test_includeme(): add_request_method=pretend.call_recorder( lambda *args, **kwargs: None ), + add_tween=pretend.call_recorder( + lambda *args, **kwargs: None + ), ) warehouse.http.includeme(config) assert len(config.add_request_method.calls) == 1 + assert len(config.add_tween.calls) == 1 call = config.add_request_method.calls[0] assert isinstance(call.args[0], warehouse.http.ThreadLocalSessionFactory) assert call.kwargs == {"name": "http", "reify": True} + assert config.add_tween.calls == [ + pretend.call("warehouse.http.unicode_redirect_tween_factory") + ] diff --git a/warehouse/http.py b/warehouse/http.py index 8da234f5867f..c1bca4c1680c 100644 --- a/warehouse/http.py +++ b/warehouse/http.py @@ -13,6 +13,8 @@ import threading import requests +from urllib.parse import quote_plus + class ThreadLocalSessionFactory: def __init__(self, config=None): @@ -37,8 +39,25 @@ def __call__(self, request): return session +def unicode_redirect_tween_factory(handler, request): + + def unicode_redirect_tween(request): + response = handler(request) + if response.location: + try: + response.location.encode('ascii') + except UnicodeEncodeError: + response.location = '/'.join( + [quote_plus(x) for x in response.location.split('/')]) + + return response + + return unicode_redirect_tween + + def includeme(config): config.add_request_method( ThreadLocalSessionFactory(config.registry.settings.get("http")), name="http", reify=True ) + config.add_tween("warehouse.http.unicode_redirect_tween_factory")