Skip to content

Commit

Permalink
Handle data={"key": [None|int|float|bool]} cases. (#1539)
Browse files Browse the repository at this point in the history
* Fix Content-Length for unicode file contents with multipart

* Handle bool and None cases for URLEncoded data

* Handle int, float, bool, and None for multipart or urlencoded data

* Update httpx/_utils.py

Co-authored-by: Florimond Manca <[email protected]>

Co-authored-by: Florimond Manca <[email protected]>
  • Loading branch information
tomchristie and florimondmanca authored Mar 26, 2021
1 parent c75ddc2 commit c26425a
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 11 deletions.
9 changes: 8 additions & 1 deletion httpx/_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RequestFiles,
ResponseContent,
)
from ._utils import primitive_value_to_str


class PlainByteStream:
Expand Down Expand Up @@ -106,7 +107,13 @@ def encode_content(
def encode_urlencoded_data(
data: dict,
) -> Tuple[Dict[str, str], ByteStream]:
body = urlencode(data, doseq=True).encode("utf-8")
plain_data = []
for key, value in data.items():
if isinstance(value, (list, tuple)):
plain_data.extend([(key, primitive_value_to_str(item)) for item in value])
else:
plain_data.append((key, primitive_value_to_str(value)))
body = urlencode(plain_data, doseq=True).encode("utf-8")
content_length = str(len(body))
content_type = "application/x-www-form-urlencoded"
headers = {"Content-Length": content_length, "Content-Type": content_type}
Expand Down
6 changes: 3 additions & 3 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
normalize_header_value,
obfuscate_sensitive_headers,
parse_header_links,
str_query_param,
primitive_value_to_str,
)


Expand Down Expand Up @@ -450,8 +450,8 @@ def __init__(self, *args: QueryParamTypes, **kwargs: typing.Any) -> None:
else:
items = flatten_queryparams(value)

self._list = [(str(k), str_query_param(v)) for k, v in items]
self._dict = {str(k): str_query_param(v) for k, v in items}
self._list = [(str(k), primitive_value_to_str(v)) for k, v in items]
self._dict = {str(k): primitive_value_to_str(v) for k, v in items}

def keys(self) -> typing.KeysView:
return self._dict.keys()
Expand Down
13 changes: 9 additions & 4 deletions httpx/_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
format_form_param,
guess_content_type,
peek_filelike_length,
primitive_value_to_str,
to_bytes,
)

Expand All @@ -17,17 +18,21 @@ class DataField:
A single form field item, within a multipart form field.
"""

def __init__(self, name: str, value: typing.Union[str, bytes]) -> None:
def __init__(
self, name: str, value: typing.Union[str, bytes, int, float, None]
) -> None:
if not isinstance(name, str):
raise TypeError(
f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
)
if not isinstance(value, (str, bytes)):
if value is not None and not isinstance(value, (str, bytes, int, float)):
raise TypeError(
f"Invalid type for value. Expected str or bytes, got {type(value)}: {value!r}"
f"Invalid type for value. Expected primitive type, got {type(value)}: {value!r}"
)
self.name = name
self.value = value
self.value: typing.Union[str, bytes] = (
value if isinstance(value, bytes) else primitive_value_to_str(value)
)

def render_headers(self) -> bytes:
if not hasattr(self, "_headers"):
Expand Down
4 changes: 2 additions & 2 deletions httpx/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def normalize_header_value(
return value.encode(encoding or "ascii")


def str_query_param(value: "PrimitiveData") -> str:
def primitive_value_to_str(value: "PrimitiveData") -> str:
"""
Coerce a primitive data type into a string value for query params.
Coerce a primitive data type into a string value.
Note that we prefer JSON-style 'true'/'false' for boolean values here.
"""
Expand Down
51 changes: 51 additions & 0 deletions tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,57 @@ async def test_urlencoded_content():
assert async_content == b"Hello=world%21"


@pytest.mark.asyncio
async def test_urlencoded_boolean():
headers, stream = encode_request(data={"example": True})
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)

sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])

assert headers == {
"Content-Length": "12",
"Content-Type": "application/x-www-form-urlencoded",
}
assert sync_content == b"example=true"
assert async_content == b"example=true"


@pytest.mark.asyncio
async def test_urlencoded_none():
headers, stream = encode_request(data={"example": None})
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)

sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])

assert headers == {
"Content-Length": "8",
"Content-Type": "application/x-www-form-urlencoded",
}
assert sync_content == b"example="
assert async_content == b"example="


@pytest.mark.asyncio
async def test_urlencoded_list():
headers, stream = encode_request(data={"example": ["a", 1, True]})
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)

sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])

assert headers == {
"Content-Length": "32",
"Content-Type": "application/x-www-form-urlencoded",
}
assert sync_content == b"example=a&example=1&example=true"
assert async_content == b"example=a&example=1&example=true"


@pytest.mark.asyncio
async def test_multipart_files_content():
files = {"file": io.BytesIO(b"<file content>")}
Expand Down
6 changes: 5 additions & 1 deletion tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_multipart_invalid_key(key):
assert repr(key) in str(e.value)


@pytest.mark.parametrize(("value"), (1, 2.3, None, [None, "abc"], {None: "abc"}))
@pytest.mark.parametrize(("value"), (object(), {"key": "value"}))
def test_multipart_invalid_value(value):
client = httpx.Client(transport=httpx.MockTransport(echo_request_content))

Expand Down Expand Up @@ -104,6 +104,8 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
"b": b"C",
"c": ["11", "22", "33"],
"d": "",
"e": True,
"f": "",
}
files = {"file": ("name.txt", open(path, "rb"))}

Expand All @@ -120,6 +122,8 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
'--{0}\r\nContent-Disposition: form-data; name="c"\r\n\r\n22\r\n'
'--{0}\r\nContent-Disposition: form-data; name="c"\r\n\r\n33\r\n'
'--{0}\r\nContent-Disposition: form-data; name="d"\r\n\r\n\r\n'
'--{0}\r\nContent-Disposition: form-data; name="e"\r\n\r\ntrue\r\n'
'--{0}\r\nContent-Disposition: form-data; name="f"\r\n\r\n\r\n'
'--{0}\r\nContent-Disposition: form-data; name="file";'
' filename="name.txt"\r\n'
"Content-Type: text/plain\r\n\r\n<file content>\r\n"
Expand Down

0 comments on commit c26425a

Please sign in to comment.