Skip to content

Commit

Permalink
PYTHON-3821 use overload pattern for _DocumentType (#1352)
Browse files Browse the repository at this point in the history
  • Loading branch information
sleepyStick authored Aug 10, 2023
1 parent c1d3383 commit 0d44783
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 24 deletions.
52 changes: 46 additions & 6 deletions bson/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,9 +1106,21 @@ def _decode_all(
_decode_all = _cbson._decode_all # noqa: F811


@overload
def decode_all(data: "_ReadableBuffer", codec_options: None = None) -> "List[Dict[str, Any]]":
...


@overload
def decode_all(
data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None
data: "_ReadableBuffer", codec_options: "CodecOptions[_DocumentType]"
) -> "List[_DocumentType]":
...


def decode_all(
data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None
) -> "Union[List[Dict[str, Any]], List[_DocumentType]]":
"""Decode BSON data to multiple documents.
`data` must be a bytes-like object implementing the buffer protocol that
Expand All @@ -1131,11 +1143,13 @@ def decode_all(
Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with
`codec_options`.
"""
opts = codec_options or DEFAULT_CODEC_OPTIONS
if not isinstance(opts, CodecOptions):
if codec_options is None:
return _decode_all(data, DEFAULT_CODEC_OPTIONS)

if not isinstance(codec_options, CodecOptions):
raise _CODEC_OPTIONS_TYPE_ERROR

return _decode_all(data, opts) # type:ignore[arg-type]
return _decode_all(data, codec_options)


def _decode_selective(rawdoc: Any, fields: Any, codec_options: Any) -> Mapping[Any, Any]:
Expand Down Expand Up @@ -1242,9 +1256,21 @@ def _decode_all_selective(data: Any, codec_options: CodecOptions, fields: Any) -
]


@overload
def decode_iter(data: bytes, codec_options: None = None) -> "Iterator[Dict[str, Any]]":
...


@overload
def decode_iter(
data: bytes, codec_options: "Optional[CodecOptions[_DocumentType]]" = None
data: bytes, codec_options: "CodecOptions[_DocumentType]"
) -> "Iterator[_DocumentType]":
...


def decode_iter(
data: bytes, codec_options: "Optional[CodecOptions[_DocumentType]]" = None
) -> "Union[Iterator[Dict[str, Any]], Iterator[_DocumentType]]":
"""Decode BSON data to multiple documents as a generator.
Works similarly to the decode_all function, but yields one document at a
Expand Down Expand Up @@ -1278,9 +1304,23 @@ def decode_iter(
yield _bson_to_dict(elements, opts)


@overload
def decode_file_iter(
file_obj: Union[BinaryIO, IO], codec_options: "Optional[CodecOptions[_DocumentType]]" = None
file_obj: Union[BinaryIO, IO], codec_options: None = None
) -> "Iterator[Dict[str, Any]]":
...


@overload
def decode_file_iter(
file_obj: Union[BinaryIO, IO], codec_options: "CodecOptions[_DocumentType]"
) -> "Iterator[_DocumentType]":
...


def decode_file_iter(
file_obj: Union[BinaryIO, IO], codec_options: "Optional[CodecOptions[_DocumentType]]" = None
) -> "Union[Iterator[Dict[str, Any]], Iterator[_DocumentType]]":
"""Decode bson data from a file to multiple documents as a generator.
Works similarly to the decode_all function, but reads from the file object
Expand Down
18 changes: 0 additions & 18 deletions pymongo/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,24 +427,6 @@ def database(self) -> Database[_DocumentType]:
"""
return self.__database

# @overload
# def with_options(
# self,
# codec_options: None = None,
# read_preference: Optional[_ServerMode] = None,
# write_concern: Optional[WriteConcern] = None,
# read_concern: Optional[ReadConcern] = None,
# ) -> Collection[Dict[str, Any]]: ...

# @overload
# def with_options(
# self,
# codec_options: bson.CodecOptions[_DocumentType],
# read_preference: Optional[_ServerMode] = None,
# write_concern: Optional[WriteConcern] = None,
# read_concern: Optional[ReadConcern] = None,
# ) -> Collection[_DocumentType]: ...

def with_options(
self,
codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None,
Expand Down
25 changes: 25 additions & 0 deletions test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ def foo(self):
rt_document3 = decode(bsonbytes2, codec_options=codec_options2)
assert rt_document3.raw

def test_bson_decode_no_codec_option(self) -> None:
doc = decode_all(encode({"a": 1}))
assert doc
doc[0]["a"] = 2

def test_bson_decode_all(self) -> None:
doc = {"_id": 1}
bsonbytes = encode(doc)
Expand All @@ -266,6 +271,15 @@ def foo(self):
rt_documents3 = decode_all(bsonbytes3, codec_options3)
assert rt_documents3[0].raw

def test_bson_decode_all_no_codec_option(self) -> None:
docs = decode_all(b"")
docs.append({"new": 1})

docs = decode_all(encode({"a": 1}))
assert docs
docs[0]["a"] = 2
docs.append({"new": 1})

def test_bson_decode_iter(self) -> None:
doc = {"_id": 1}
bsonbytes = encode(doc)
Expand All @@ -290,6 +304,11 @@ def foo(self):
rt_documents3 = decode_iter(bsonbytes3, codec_options3)
assert next(rt_documents3).raw

def test_bson_decode_iter_no_codec_option(self) -> None:
doc = next(decode_iter(encode({"a": 1})))
assert doc
doc["a"] = 2

def make_tempfile(self, content: bytes) -> Any:
fileobj = tempfile.TemporaryFile()
fileobj.write(content)
Expand Down Expand Up @@ -324,6 +343,12 @@ def foo(self):
rt_documents3 = decode_file_iter(fileobj3, codec_options3)
assert next(rt_documents3).raw

def test_bson_decode_file_iter_none_codec_option(self) -> None:
fileobj = self.make_tempfile(encode({"new": 1}))
doc = next(decode_file_iter(fileobj))
assert doc
doc["a"] = 2


class TestDocumentType(unittest.TestCase):
@only_type_check
Expand Down

0 comments on commit 0d44783

Please sign in to comment.