Skip to content

Commit

Permalink
Consistent multidict methods (#1089)
Browse files Browse the repository at this point in the history
* Consistent multidict methods

* Consistent multidict methods and behaviour

* Update httpx/_models.py

Co-authored-by: Florimond Manca <[email protected]>

* Update httpx/_models.py

Co-authored-by: Florimond Manca <[email protected]>

Co-authored-by: Florimond Manca <[email protected]>
  • Loading branch information
tomchristie and florimondmanca authored Jul 31, 2020
1 parent dba83d4 commit 2ba9c1e
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 24 deletions.
85 changes: 70 additions & 15 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json as jsonlib
import typing
import urllib.request
import warnings
from collections.abc import MutableMapping
from http.cookiejar import Cookie, CookieJar
from urllib.parse import parse_qsl, urlencode
Expand Down Expand Up @@ -240,26 +241,40 @@ def __init__(self, *args: QueryParamTypes, **kwargs: typing.Any) -> None:
self._list = [(str(k), str_query_param(v)) for k, v in items]
self._dict = {str(k): str_query_param(v) for k, v in items}

def getlist(self, key: typing.Any) -> typing.List[str]:
return [item_value for item_key, item_value in self._list if item_key == key]

def keys(self) -> typing.KeysView:
return self._dict.keys()

def values(self) -> typing.ValuesView:
return self._dict.values()

def items(self) -> typing.ItemsView:
"""
Return all items in the query params. If a key occurs more than once
only the first item for that key is returned.
"""
return self._dict.items()

def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
"""
Return all items in the query params. Allow duplicate keys to occur.
"""
return list(self._list)

def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
"""
Get a value from the query param for a given key. If the key occurs
more than once, then only the first value is returned.
"""
if key in self._dict:
return self._dict[key]
return default

def get_list(self, key: typing.Any) -> typing.List[str]:
"""
Get all values from the query param for a given key.
"""
return [item_value for item_key, item_value in self._list if item_key == key]

def update(self, params: QueryParamTypes = None) -> None:
if not params:
return
Expand Down Expand Up @@ -315,6 +330,13 @@ def __repr__(self) -> str:
query_string = str(self)
return f"{class_name}({query_string!r})"

def getlist(self, key: typing.Any) -> typing.List[str]:
message = (
"QueryParams.getlist() is pending deprecation. Use QueryParams.get_list()"
)
warnings.warn(message, PendingDeprecationWarning)
return self.get_list(key)


class Headers(typing.MutableMapping[str, str]):
"""
Expand All @@ -336,6 +358,14 @@ def __init__(self, headers: HeaderTypes = None, encoding: str = None) -> None:
(normalize_header_key(k, encoding), normalize_header_value(v, encoding))
for k, v in headers
]

self._dict = {} # type: typing.Dict[bytes, bytes]
for key, value in self._list:
if key in self._dict:
self._dict[key] = self._dict[key] + b", " + value
else:
self._dict[key] = value

self._encoding = encoding

@property
Expand Down Expand Up @@ -376,26 +406,47 @@ def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
return self._list

def keys(self) -> typing.List[str]: # type: ignore
return [key.decode(self.encoding) for key, value in self._list]
return [key.decode(self.encoding) for key in self._dict.keys()]

def values(self) -> typing.List[str]: # type: ignore
return [value.decode(self.encoding) for key, value in self._list]
return [value.decode(self.encoding) for value in self._dict.values()]

def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
"""
Return a list of `(key, value)` pairs of headers. Concatenate headers
into a single comma seperated value when a key occurs multiple times.
"""
return [
(key.decode(self.encoding), value.decode(self.encoding))
for key, value in self._dict.items()
]

def multi_items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
"""
Return a list of `(key, value)` pairs of headers. Allow multiple
occurences of the same key without concatenating into a single
comma seperated value.
"""
return [
(key.decode(self.encoding), value.decode(self.encoding))
for key, value in self._list
]

def get(self, key: str, default: typing.Any = None) -> typing.Any:
"""
Return a header value. If multiple occurences of the header occur
then concatenate them together with commas.
"""
try:
return self[key]
except KeyError:
return default

def getlist(self, key: str, split_commas: bool = False) -> typing.List[str]:
def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]:
"""
Return multiple header values.
Return a list of all header values for a given key.
If `split_commas=True` is passed, then any comma seperated header
values are split into multiple return strings.
"""
get_header_key = key.lower().encode(self.encoding)

Expand Down Expand Up @@ -448,6 +499,8 @@ def __setitem__(self, key: str, value: str) -> None:
set_key = key.lower().encode(self._encoding or "utf-8")
set_value = value.encode(self._encoding or "utf-8")

self._dict[set_key] = set_value

found_indexes = []
for idx, (item_key, _) in enumerate(self._list):
if item_key == set_key:
Expand All @@ -468,22 +521,19 @@ def __delitem__(self, key: str) -> None:
"""
del_key = key.lower().encode(self.encoding)

del self._dict[del_key]

pop_indexes = []
for idx, (item_key, _) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)
if not pop_indexes:
raise KeyError(key)

for idx in reversed(pop_indexes):
del self._list[idx]

def __contains__(self, key: typing.Any) -> bool:
get_header_key = key.lower().encode(self.encoding)
for header_key, _ in self._list:
if header_key == get_header_key:
return True
return False
header_key = key.lower().encode(self.encoding)
return header_key in self._dict

def __iter__(self) -> typing.Iterator[typing.Any]:
return iter(self.keys())
Expand All @@ -503,14 +553,19 @@ def __repr__(self) -> str:
if self.encoding != "ascii":
encoding_str = f", encoding={self.encoding!r}"

as_list = list(obfuscate_sensitive_headers(self.items()))
as_list = list(obfuscate_sensitive_headers(self.multi_items()))
as_dict = dict(as_list)

no_duplicate_keys = len(as_dict) == len(as_list)
if no_duplicate_keys:
return f"{class_name}({as_dict!r}{encoding_str})"
return f"{class_name}({as_list!r}{encoding_str})"

def getlist(self, key: str, split_commas: bool = False) -> typing.List[str]:
message = "Headers.getlist() is pending deprecation. Use Headers.get_list()"
warnings.warn(message, PendingDeprecationWarning)
return self.get_list(key, split_commas=split_commas)


USER_AGENT = f"python-httpx/{__version__}"
ACCEPT_ENCODING = ", ".join(
Expand Down
17 changes: 9 additions & 8 deletions tests/models/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ def test_headers():
assert h["a"] == "123, 456"
assert h.get("a") == "123, 456"
assert h.get("nope", default=None) is None
assert h.getlist("a") == ["123", "456"]
assert h.keys() == ["a", "a", "b"]
assert h.values() == ["123", "456", "789"]
assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")]
assert list(h) == ["a", "a", "b"]
assert h.get_list("a") == ["123", "456"]
assert h.keys() == ["a", "b"]
assert h.values() == ["123, 456", "789"]
assert h.items() == [("a", "123, 456"), ("b", "789")]
assert h.multi_items() == [("a", "123"), ("a", "456"), ("b", "789")]
assert list(h) == ["a", "b"]
assert dict(h) == {"a": "123, 456", "b": "789"}
assert repr(h) == "Headers([('a', '123'), ('a', '456'), ('b', '789')])"
assert h == httpx.Headers([("a", "123"), ("b", "789"), ("a", "456")])
Expand Down Expand Up @@ -153,13 +154,13 @@ def test_headers_decode_explicit_encoding():

def test_multiple_headers():
"""
Most headers should split by commas for `getlist`, except 'Set-Cookie'.
`Headers.get_list` should support both split_commas=False and split_commas=True.
"""
h = httpx.Headers([("set-cookie", "a, b"), ("set-cookie", "c")])
h.getlist("Set-Cookie") == ["a, b", "b"]
assert h.get_list("Set-Cookie") == ["a, b", "c"]

h = httpx.Headers([("vary", "a, b"), ("vary", "c")])
h.getlist("Vary") == ["a", "b", "c"]
assert h.get_list("Vary", split_commas=True) == ["a", "b", "c"]


@pytest.mark.parametrize("header", ["authorization", "proxy-authorization"])
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_queryparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_queryparams(source):
assert q["a"] == "456"
assert q.get("a") == "456"
assert q.get("nope", default=None) is None
assert q.getlist("a") == ["123", "456"]
assert q.get_list("a") == ["123", "456"]
assert list(q.keys()) == ["a", "b"]
assert list(q.values()) == ["456", "789"]
assert list(q.items()) == [("a", "456"), ("b", "789")]
Expand Down

0 comments on commit 2ba9c1e

Please sign in to comment.