From 283fb63b83691b744d28d7631fe038488f50064c Mon Sep 17 00:00:00 2001 From: Nipunn Koorapati Date: Thu, 15 Jul 2021 23:58:31 -0500 Subject: [PATCH] Upper bound protobuf containers based on scalar/composite --- .../google/protobuf/internal/containers.pyi | 57 ++++++++++--------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/stubs/protobuf/google/protobuf/internal/containers.pyi b/stubs/protobuf/google/protobuf/internal/containers.pyi index 4d7b019eb0d5..5e54051b6afa 100644 --- a/stubs/protobuf/google/protobuf/internal/containers.pyi +++ b/stubs/protobuf/google/protobuf/internal/containers.pyi @@ -8,6 +8,7 @@ from typing import ( MutableMapping as MutableMapping, Optional, Sequence, + Text, TypeVar, Union, overload, @@ -16,10 +17,12 @@ from typing import ( from google.protobuf.descriptor import Descriptor from google.protobuf.internal.message_listener import MessageListener from google.protobuf.internal.python_message import GeneratedProtocolMessageType +from google.protobuf.message import Message _T = TypeVar("_T") -_K = TypeVar("_K") -_V = TypeVar("_V") +_K = TypeVar("_K", bound=Union[bool, int, Text]) +_ScalarV = TypeVar("_ScalarV", bound=Union[bool, int, float, Text, bytes]) +_MessageV = TypeVar("_MessageV", bound=Message) _M = TypeVar("_M") class BaseContainer(Sequence[_T]): @@ -34,42 +37,42 @@ class BaseContainer(Sequence[_T]): @overload def __getitem__(self, key: slice) -> List[_T]: ... -class RepeatedScalarFieldContainer(BaseContainer[_T]): +class RepeatedScalarFieldContainer(BaseContainer[_ScalarV]): def __init__(self, message_listener: MessageListener, message_descriptor: Descriptor) -> None: ... - def append(self, value: _T) -> None: ... - def insert(self, key: int, value: _T) -> None: ... - def extend(self, elem_seq: Optional[Iterable[_T]]) -> None: ... + def append(self, value: _ScalarV) -> None: ... + def insert(self, key: int, value: _ScalarV) -> None: ... + def extend(self, elem_seq: Optional[Iterable[_ScalarV]]) -> None: ... def MergeFrom(self: _M, other: _M) -> None: ... - def remove(self, elem: _T) -> None: ... - def pop(self, key: int = ...) -> _T: ... + def remove(self, elem: _ScalarV) -> None: ... + def pop(self, key: int = ...) -> _ScalarV: ... @overload - def __setitem__(self, key: int, value: _T) -> None: ... + def __setitem__(self, key: int, value: _ScalarV) -> None: ... @overload - def __setitem__(self, key: slice, value: Iterable[_T]) -> None: ... - def __getslice__(self, start: int, stop: int) -> List[_T]: ... - def __setslice__(self, start: int, stop: int, values: Iterable[_T]) -> None: ... + def __setitem__(self, key: slice, value: Iterable[_ScalarV]) -> None: ... + def __getslice__(self, start: int, stop: int) -> List[_ScalarV]: ... + def __setslice__(self, start: int, stop: int, values: Iterable[_ScalarV]) -> None: ... def __delitem__(self, key: Union[int, slice]) -> None: ... def __delslice__(self, start: int, stop: int) -> None: ... def __eq__(self, other: object) -> bool: ... -class RepeatedCompositeFieldContainer(BaseContainer[_T]): +class RepeatedCompositeFieldContainer(BaseContainer[_MessageV]): def __init__(self, message_listener: MessageListener, type_checker: Any) -> None: ... - def add(self, **kwargs: Any) -> _T: ... - def append(self, value: _T) -> None: ... - def insert(self, key: int, value: _T) -> None: ... - def extend(self, elem_seq: Iterable[_T]) -> None: ... + def add(self, **kwargs: Any) -> _MessageV: ... + def append(self, value: _MessageV) -> None: ... + def insert(self, key: int, value: _MessageV) -> None: ... + def extend(self, elem_seq: Iterable[_MessageV]) -> None: ... def MergeFrom(self: _M, other: _M) -> None: ... - def remove(self, elem: _T) -> None: ... - def pop(self, key: int = ...) -> _T: ... - def __getslice__(self, start: int, stop: int) -> List[_T]: ... + def remove(self, elem: _MessageV) -> None: ... + def pop(self, key: int = ...) -> _MessageV: ... + def __getslice__(self, start: int, stop: int) -> List[_MessageV]: ... def __delitem__(self, key: Union[int, slice]) -> None: ... def __delslice__(self, start: int, stop: int) -> None: ... def __eq__(self, other: object) -> bool: ... -class ScalarMap(MutableMapping[_K, _V]): - def __setitem__(self, k: _K, v: _V) -> None: ... +class ScalarMap(MutableMapping[_K, _ScalarV]): + def __setitem__(self, k: _K, v: _ScalarV) -> None: ... def __delitem__(self, v: _K) -> None: ... - def __getitem__(self, k: _K) -> _V: ... + def __getitem__(self, k: _K) -> _ScalarV: ... def __len__(self) -> int: ... def __iter__(self) -> Iterator[_K]: ... def __eq__(self, other: object) -> bool: ... @@ -77,14 +80,14 @@ class ScalarMap(MutableMapping[_K, _V]): def InvalidateIterators(self) -> None: ... def GetEntryClass(self) -> GeneratedProtocolMessageType: ... -class MessageMap(MutableMapping[_K, _V]): - def __setitem__(self, k: _K, v: _V) -> None: ... +class MessageMap(MutableMapping[_K, _MessageV]): + def __setitem__(self, k: _K, v: _MessageV) -> None: ... def __delitem__(self, v: _K) -> None: ... - def __getitem__(self, k: _K) -> _V: ... + def __getitem__(self, k: _K) -> _MessageV: ... def __len__(self) -> int: ... def __iter__(self) -> Iterator[_K]: ... def __eq__(self, other: object) -> bool: ... - def get_or_create(self, key: _K) -> _V: ... + def get_or_create(self, key: _K) -> _MessageV: ... def MergeFrom(self: _M, other: _M): ... def InvalidateIterators(self) -> None: ... def GetEntryClass(self) -> GeneratedProtocolMessageType: ...