Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept header parsing #2200

Merged
merged 15 commits into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sanic/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,11 @@ class HeaderNotFound(InvalidUsage):
**Status**: 400 Bad Request
"""

status_code = 400
quiet = True

class InvalidHeader(InvalidUsage):
    """
    **Status**: 400 Bad Request

    Raised when a request header carries a value that cannot be parsed,
    for example a malformed media range or an out-of-range ``q`` value in
    the ``Accept`` header.
    """


class ContentRangeError(SanicException):
Expand Down
144 changes: 144 additions & 0 deletions sanic/headers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import re

from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from urllib.parse import unquote

from sanic.exceptions import InvalidHeader
from sanic.helpers import STATUS_CODES


Expand Down Expand Up @@ -30,6 +33,121 @@
# For more information, consult ../tests/test_requests.py


def parse_arg_as_accept(f):
    """Decorate a binary ``Accept`` method so ``other`` may be a raw string.

    When the decorated method is called with an ``other`` that is not
    already an ``Accept`` instance, ``other`` is first converted with
    ``Accept.parse``. This is what makes expressions such as
    ``accept > "foo/bar; q=0.1"`` work.

    :param f: Method with the signature ``(self, other)``.
    :return: A wrapper that carries the metadata (name, docstring,
        ``__wrapped__``) of ``f``.
    :raises: Whatever ``Accept.parse`` raises for an unparsable ``other``.
    """
    # Local import: keeps this helper self-contained without touching the
    # module-level import block.
    from functools import wraps

    @wraps(f)
    def func(self, other):
        if not isinstance(other, Accept):
            other = Accept.parse(other)
        return f(self, other)

    return func


class MediaType(str):
    """One component (type or subtype) of a media range, e.g. ``text``.

    Equality is wildcard-aware: ``==`` is true when both values are the
    same string **or** when either side is the wildcard ``"*"``.

    Two consequences of only overriding ``__eq__``:

    * ``!=`` is still the plain ``str`` inequality (``str`` provides its own
      ``__ne__``), so ``MediaType("foo") != "*"`` is ``True`` even though
      ``MediaType("foo") == "*"`` is ``True`` as well.
    * Instances are unhashable, because a class that defines ``__eq__``
      without ``__hash__`` gets ``__hash__`` set to ``None``.
    """

    def __new__(cls, value: str):
        return str.__new__(cls, value)

    def __init__(self, value: str) -> None:
        self.value = value
        self.is_wildcard = self.check_if_wildcard(value)

    def __eq__(self, other):
        # A wildcard on our side matches anything.
        if self.is_wildcard:
            return True
        if isinstance(other, MediaType):
            return other.is_wildcard or self.value == other.value
        return self.check_if_wildcard(other) or self.value == other

    @staticmethod
    def check_if_wildcard(value):
        """Return whether ``value`` is the wildcard ``"*"``."""
        return value == "*"


class Accept(str):
    """A single parsed entry (media range) of an ``Accept`` header.

    The instance is itself the stripped raw entry (it is a ``str``) and
    additionally exposes:

    * ``value``: the same stripped raw entry
    * ``type_`` / ``subtype``: :class:`MediaType` parts of ``type/subtype``
    * ``qvalue``: the ``q`` parameter as ``float`` (defaults to ``1.0``)
    * ``params``: every other media-range parameter as a ``dict``

    .. note::
        All rich comparisons, **including** ``==`` and ``!=``, compare only
        the ``qvalue`` of both operands. Use :meth:`match` to test whether
        two media ranges are compatible. Non-``Accept`` operands are parsed
        with :meth:`parse` first. Because ``__eq__`` is defined without
        ``__hash__``, instances are unhashable.
    """

    # Parameter names that would collide with the positional arguments of
    # ``__init__`` when ``parse`` forwards parameters as keyword arguments.
    _RESERVED_PARAMS = ("value", "type_", "subtype")

    def __new__(cls, value: str, *args, **kwargs):
        # Only the raw value becomes the string content; everything else is
        # consumed by ``__init__``.
        return str.__new__(cls, value)

    def __init__(
        self,
        value: str,
        type_: MediaType,
        subtype: MediaType,
        *,
        q: str = "1.0",
        **kwargs: str,
    ) -> None:
        """
        :param value: Raw (stripped) media range, e.g. ``text/html; q=0.8``.
        :param type_: Parsed type component.
        :param subtype: Parsed subtype component.
        :param q: Quality value; must convert to a float within ``[0, 1]``.
        :param kwargs: Remaining media-range parameters.
        :raises InvalidHeader: If ``q`` is not a number or is out of range.
        """
        try:
            qvalue = float(q)
        except (TypeError, ValueError) as exc:
            # A non-numeric q is a client error, not a server error.
            raise InvalidHeader(
                f"Accept header qvalue must be between 0 and 1, not: {q}"
            ) from exc
        # Chained comparison on purpose: it also rejects NaN, which would
        # slip through ``qvalue > 1 or qvalue < 0``.
        if not 0 <= qvalue <= 1:
            raise InvalidHeader(
                f"Accept header qvalue must be between 0 and 1, not: {qvalue}"
            )
        self.value = value
        self.type_ = type_
        self.subtype = subtype
        self.qvalue = qvalue
        self.params = kwargs

    def _compare(self, other, method):
        """Apply ``method`` to both q-values.

        Returns ``NotImplemented`` when ``other`` has no usable ``qvalue``.
        """
        try:
            return method(self.qvalue, other.qvalue)
        except (AttributeError, TypeError):
            return NotImplemented

    @parse_arg_as_accept
    def __lt__(self, other: Union[str, Accept]):
        return self._compare(other, lambda s, o: s < o)

    @parse_arg_as_accept
    def __le__(self, other: Union[str, Accept]):
        return self._compare(other, lambda s, o: s <= o)

    @parse_arg_as_accept
    def __eq__(self, other: Union[str, Accept]):  # type: ignore
        return self._compare(other, lambda s, o: s == o)

    @parse_arg_as_accept
    def __ge__(self, other: Union[str, Accept]):
        return self._compare(other, lambda s, o: s >= o)

    @parse_arg_as_accept
    def __gt__(self, other: Union[str, Accept]):
        return self._compare(other, lambda s, o: s > o)

    @parse_arg_as_accept
    def __ne__(self, other: Union[str, Accept]):  # type: ignore
        return self._compare(other, lambda s, o: s != o)

    @parse_arg_as_accept
    def match(self, other) -> bool:
        """Return whether type and subtype of both ranges are compatible.

        Comparison is wildcard-aware (see :class:`MediaType`); q-values and
        parameters are ignored.
        """
        return self.type_ == other.type_ and self.subtype == other.subtype

    @classmethod
    def parse(cls, raw: str) -> Accept:
        """Parse one comma-separated element of an ``Accept`` header.

        :param raw: A single media range such as ``text/html; q=0.8``.
        :return: The parsed :class:`Accept` instance.
        :raises InvalidHeader: If the media range is not ``type/subtype``,
            a parameter is not of the form ``key=value``, a parameter name
            clashes with a constructor argument, or ``q`` is invalid.
        """
        mtype = raw.strip()
        media, *raw_params = mtype.split(";")

        try:
            # Strip each part so ``text/html ;q=1`` yields ``html`` and not
            # ``html `` (which could never match anything).
            type_, subtype = (part.strip() for part in media.split("/"))
        except ValueError:
            type_ = subtype = ""

        if not type_ or not subtype:
            raise InvalidHeader(f"Header contains invalid Accept value: {raw}")

        params = {}
        for raw_param in raw_params:
            if not raw_param.strip():
                # Tolerate empty segments such as a trailing ";".
                continue
            key, separator, param_value = raw_param.partition("=")
            key = key.strip()
            if not separator or not key or key in cls._RESERVED_PARAMS:
                raise InvalidHeader(
                    f"Header contains invalid Accept value: {raw}"
                )
            params[key] = param_value.strip()

        return cls(mtype, MediaType(type_), MediaType(subtype), **params)


def parse_content_header(value: str) -> Tuple[str, Options]:
"""Parse content-type and content-disposition header values.

Expand Down Expand Up @@ -194,3 +312,29 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes:
ret += b"%b: %b\r\n" % h
ret += b"\r\n"
return ret


def _sort_accept_value(accept: Accept):
    """Build the sort key used by :func:`parse_accept`.

    The key orders by, in decreasing priority: q-value, number of
    media-range parameters, concrete (non-``*``) subtype, concrete type.
    ``parse_accept`` sorts in reverse, so larger keys come first.
    """
    return (
        accept.qvalue,
        len(accept.params),
        accept.subtype != "*",
        accept.type_ != "*",
    )


def parse_accept(accept: str) -> List[Accept]:
    """Parse an Accept header and order the acceptable media types
    according to RFC 7231, s. 5.3.2
    https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2

    :param accept: Raw value of the ``Accept`` header; may be empty.
    :return: Parsed media ranges, most preferred first. Elements that
        compare equal on the sort key keep their original header order.
    :raises InvalidHeader: If any non-empty element cannot be parsed.
    """
    # Empty *and* whitespace-only list elements (e.g. from a trailing
    # ", ") are skipped instead of being reported as invalid.
    accept_list: List[Accept] = [
        Accept.parse(mtype) for mtype in accept.split(",") if mtype.strip()
    ]

    return sorted(accept_list, key=_sort_accept_value, reverse=True)
7 changes: 7 additions & 0 deletions sanic/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.exceptions import InvalidUsage
from sanic.headers import (
Accept,
Options,
parse_accept,
parse_content_header,
parse_forwarded,
parse_host,
Expand Down Expand Up @@ -296,6 +298,11 @@ def load_json(self, loads=json_loads):

return self.parsed_json

@property
def accept(self) -> List[Accept]:
    """Parsed ``Accept`` request header.

    Reads the ``accept`` header (an empty string when it is absent) and
    hands it to ``parse_accept``. The header is parsed again on every
    access.

    :return: The list produced by ``parse_accept`` for the header value.
    """
    return parse_accept(self.headers.getone("accept", ""))

ahopkins marked this conversation as resolved.
Show resolved Hide resolved
@property
def token(self):
"""Attempt to return the auth header token.
Expand Down
101 changes: 100 additions & 1 deletion tests/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from sanic import headers, text
from sanic.exceptions import PayloadTooLarge
from sanic.exceptions import InvalidHeader, PayloadTooLarge
from sanic.http import Http


Expand Down Expand Up @@ -182,3 +182,102 @@ def test_request_line(app):
)

assert request.request_line == b"GET / HTTP/1.1"


@pytest.mark.parametrize(
    "raw",
    (
        "show/first, show/second",
        "show/*, show/first",
        "*/*, show/first",
        "*/*, show/*",
        "other/*; q=0.1, show/*; q=0.2",
        "show/first; q=0.5, show/second; q=0.5",
        "show/first; foo=bar, show/second; foo=bar",
        "show/second, show/first; foo=bar",
        "show/second; q=0.5, show/first; foo=bar; q=0.5",
        "show/second; q=0.5, show/first; q=1.0",
        "show/first, show/second; q=1.0",
    ),
)
def test_parse_accept_ordered_okay(raw):
    """The most preferred entry must be ``show/first``, or ``show/*`` when
    every entry of the header has a wildcard subtype."""
    parsed = headers.parse_accept(raw)
    best = parsed[0]

    if all(entry.subtype.is_wildcard for entry in parsed):
        expected_subtype = "*"
    else:
        expected_subtype = "first"

    assert best.type_ == "show"
    assert best.subtype == expected_subtype


@pytest.mark.parametrize("raw", ["missing", "missing/", "/missing"])
def test_bad_accept(raw):
    """A media range lacking a type or a subtype is rejected."""
    with pytest.raises(InvalidHeader):
        headers.parse_accept(raw)


def test_empty_accept():
    """An empty header value yields an empty list."""
    parsed = headers.parse_accept("")
    assert parsed == []


def test_wildcard_accept_set_ok():
    """``is_wildcard`` is set for ``*`` components and unset otherwise."""
    (wildcard,) = headers.parse_accept("*/*")[:1]
    assert wildcard.type_.is_wildcard
    assert wildcard.subtype.is_wildcard

    (concrete,) = headers.parse_accept("foo/bar")[:1]
    assert not concrete.type_.is_wildcard
    assert not concrete.subtype.is_wildcard


def test_accept_parsed_against_str():
    """A plain string operand is parsed before the q-values are compared."""
    preferred = headers.Accept.parse("foo/bar")
    lower_quality = "foo/bar; q=0.1"
    assert preferred > lower_quality


def test_media_type_equality():
    """``MediaType`` equality against instances, strings and wildcards."""
    # Same value: instance vs. instance and instance vs. plain string.
    assert headers.MediaType("foo") == headers.MediaType("foo")
    assert headers.MediaType("foo") == "foo"

    # A wildcard equals anything, and equals the literal "*".
    assert headers.MediaType("foo") == headers.MediaType("*")
    assert headers.MediaType("*") == "*"

    # Different concrete values are unequal.
    assert headers.MediaType("foo") != headers.MediaType("bar")
    assert headers.MediaType("foo") != "bar"


@pytest.mark.parametrize(
    "value,other",
    [
        (value, candidate)
        for value in ("foo/bar", "foo/*", "*/*")
        for raw_other in ("foo/bar", "foo/*", "*/*")
        # Each pairing is checked against both the raw string and the
        # already parsed ``Accept`` instance.
        for candidate in (raw_other, headers.Accept.parse(raw_other))
    ],
)
def test_accept_matching(value, other):
    """Every combination of concrete and wildcard ranges must match."""
    parsed = headers.Accept.parse(value)
    assert parsed.match(other)


@pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*"))
def test_value_in_accept(value):
    """Plain strings can be used with ``in`` on the parsed list.

    NOTE: membership goes through ``Accept.__eq__``, which compares the
    q-values of both sides rather than the media types.
    """
    acceptable = headers.parse_accept(value)
    for candidate in ("foo/bar", "foo/*", "*/*"):
        assert candidate in acceptable
36 changes: 36 additions & 0 deletions tests/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,39 @@ async def get(request):
assert resp.json["client"] == "[::1]"
assert resp.json["client_ip"] == "::1"
assert request.ip == "::1"


def test_request_accept():
    """``request.accept`` returns the parsed header, most preferred first."""
    app = Sanic("req-generator")

    @app.get("/")
    async def get(request):
        return response.empty()

    request, _ = app.test_client.get(
        "/",
        headers={
            "Accept": "text/*, text/plain, text/plain;format=flowed, */*"
        },
    )
    expected = [
        "text/plain;format=flowed",
        "text/plain",
        "text/*",
        "*/*",
    ]
    assert request.accept == expected
    # ``Accept.__eq__`` only compares q-values, so the assertion above cannot
    # detect wrongly ordered media types. Compare the raw strings as well.
    assert [str(item) for item in request.accept] == expected

    request, _ = app.test_client.get(
        "/",
        headers={
            "Accept": (
                "text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c"
            )
        },
    )
    expected = [
        "text/html",
        "text/x-c",
        "text/x-dvi; q=0.8",
        "text/plain; q=0.5",
    ]
    assert request.accept == expected
    assert [str(item) for item in request.accept] == expected