Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow int Subclasses in Query #492

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/492.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow for int and float subclasses in query, while still denying bool.
72 changes: 58 additions & 14 deletions tests/test_update_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,34 +155,78 @@ def test_with_query_sequence_invalid_use(query):
url.with_query(query)


def test_with_query_non_str():
url = URL("http://example.com")
with pytest.raises(TypeError):
url.with_query({"a": 1.1})
class _CStr(str):
pass


class _EmptyStrEr:
def __str__(self):
return ""


class _CInt(int, _EmptyStrEr):
pass


def test_with_query_bool():
class _CFloat(float, _EmptyStrEr):
pass


@pytest.mark.parametrize(
("value", "expected"),
[
pytest.param("1", "1", id="str"),
pytest.param(_CStr("1"), "1", id="custom str"),
pytest.param(1, "1", id="int"),
pytest.param(_CInt(1), "1", id="custom int"),
pytest.param(1.1, "1.1", id="float"),
pytest.param(_CFloat(1.1), "1.1", id="custom float"),
],
)
def test_with_query_valid_type(value, expected):
url = URL("http://example.com")
with pytest.raises(TypeError):
url.with_query({"a": True})
expected = "http://example.com/?a={expected}".format_map(locals())
assert str(url.with_query({"a": value})) == expected


def test_with_query_none():
@pytest.mark.parametrize(
("value"),
[
pytest.param(True, id="bool"),
pytest.param(None, id="none"),
pytest.param(float("inf"), id="non-finite float"),
],
)
def test_with_query_invalid_type(value):
url = URL("http://example.com")
with pytest.raises(TypeError):
url.with_query({"a": None})
url.with_query({"a": value})


def test_with_query_list_non_str():
@pytest.mark.parametrize(
("value", "expected"),
[
pytest.param("1", "1", id="str"),
pytest.param(_CStr("1"), "1", id="custom str"),
pytest.param(1, "1", id="int"),
pytest.param(_CInt(1), "1", id="custom int"),
pytest.param(1.1, "1.1", id="float"),
pytest.param(_CFloat(1.1), "1.1", id="custom float"),
],
)
def test_with_query_list_valid_type(value, expected):
url = URL("http://example.com")
with pytest.raises(TypeError):
url.with_query([("a", 1.0)])
expected = "http://example.com/?a={expected}".format_map(locals())
assert str(url.with_query([("a", value)])) == expected


def test_with_query_list_bool():
@pytest.mark.parametrize(
("value"), [pytest.param(True, id="bool"), pytest.param(None, id="none")]
)
def test_with_query_list_invalid_type(value):
url = URL("http://example.com")
with pytest.raises(TypeError):
url.with_query([("a", False)])
url.with_query([("a", value)])


def test_with_query_multidict():
Expand Down
15 changes: 10 additions & 5 deletions yarl/_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from multidict import MultiDict, MultiDictProxy
import idna

import math


from ._quoting import _Quoter, _Unquoter

Expand Down Expand Up @@ -902,14 +904,17 @@ def _query_seq_pairs(cls, quoter, pairs):

@staticmethod
def _query_var(v):
if isinstance(v, str):
cls = type(v)
if issubclass(cls, str):
return v
if type(v) is int: # no subclasses like bool
return str(v)
if issubclass(cls, (int, float)) and cls is not bool:
if not math.isfinite(v):
raise TypeError("Value should be finite")
return int.__str__(v) # same as float.__str__
raise TypeError(
"Invalid variable type: value "
"should be str or int, got {!r} "
"of type {}".format(v, type(v))
"should be str, int or float, got {!r} "
"of type {}".format(v, cls)
)

def _get_str_query(self, *args, **kwargs):
Expand Down