diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 19403c442f..7a5d907da9 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """Module containing wrapper classes around meta-descriptors. This module contains dataclasses which wrap the descriptor protos @@ -30,13 +31,13 @@ import dataclasses import re from itertools import chain -from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping, ClassVar, - Optional, Sequence, Set, Tuple, Union) -from google.api import annotations_pb2 # type: ignore +from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping, + ClassVar, Optional, Sequence, Set, Tuple, Union) +from google.api import annotations_pb2 # type: ignore from google.api import client_pb2 from google.api import field_behavior_pb2 from google.api import resource_pb2 -from google.api_core import exceptions # type: ignore +from google.api_core import exceptions # type: ignore from google.protobuf import descriptor_pb2 # type: ignore from google.protobuf.json_format import MessageToDict # type: ignore @@ -51,7 +52,8 @@ class Field: message: Optional['MessageType'] = None enum: Optional['EnumType'] = None meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) oneof: Optional[str] = None def __getattr__(self, name): @@ -67,7 +69,7 @@ def __hash__(self): def name(self) -> str: """Used to prevent collisions with python keywords""" name = self.field_pb.name - return name + '_' if name in utils.RESERVED_NAMES else name + return name + "_" if name in utils.RESERVED_NAMES else name @utils.cached_property def ident(self) -> metadata.FieldIdentifier: @@ -89,9 +91,9 @@ def map(self) -> bool: @utils.cached_property def mock_value(self) -> str: - visited_fields: Set['Field'] = set() + visited_fields: Set["Field"] = set() stack = [self] - answer = '{}' + answer = "{}" while stack: expr = stack.pop() answer = answer.format(expr.inner_mock(stack, visited_fields)) @@ -127,10 +129,13 @@ def inner_mock(self, stack, visited_fields): answer = f'{self.type.ident}.{mock_value.name}' # If this is another message, set one value on the message. - if (not self.map # Maps are handled separately - and isinstance(self.type, MessageType) and len(self.type.fields) - # Nested message types need to terminate eventually - and self not in visited_fields): + if ( + not self.map # Maps are handled separately + and isinstance(self.type, MessageType) + and len(self.type.fields) + # Nested message types need to terminate eventually + and self not in visited_fields + ): sub = next(iter(self.type.fields.values())) stack.append(sub) visited_fields.add(self) @@ -142,8 +147,8 @@ def inner_mock(self, stack, visited_fields): # Maps are a special case beacuse they're represented internally as # a list of a generated type with two fields: 'key' and 'value'. answer = '{{{}: {}}}'.format( - self.type.fields['key'].mock_value, - self.type.fields['value'].mock_value, + self.type.fields["key"].mock_value, + self.type.fields["value"].mock_value, ) elif self.repeated: # If this is a repeated field, then the mock answer should @@ -156,17 +161,17 @@ def inner_mock(self, stack, visited_fields): @property def proto_type(self) -> str: """Return the proto type constant to be used in templates.""" - return cast( - str, descriptor_pb2.FieldDescriptorProto.Type.Name( - self.field_pb.type,))[len('TYPE_'):] + return cast(str, descriptor_pb2.FieldDescriptorProto.Type.Name( + self.field_pb.type, + ))[len('TYPE_'):] @property def repeated(self) -> bool: """Return True if this is a repeated field, False otherwise. - Returns: - bool: Whether this field is repeated. - """ + Returns: + bool: Whether this field is repeated. + """ return self.label == \ descriptor_pb2.FieldDescriptorProto.Label.Value( 'LABEL_REPEATED') # type: ignore @@ -175,11 +180,11 @@ def repeated(self) -> bool: def required(self) -> bool: """Return True if this is a required field, False otherwise. - Returns: - bool: Whether this field is required. - """ - return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') - in self.options.Extensions[field_behavior_pb2.field_behavior]) + Returns: + bool: Whether this field is required. + """ + return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') in + self.options.Extensions[field_behavior_pb2.field_behavior]) @utils.cached_property def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']: @@ -216,17 +221,17 @@ def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']: 'This code should not be reachable; please file a bug.') def with_context( - self, - *, - collisions: FrozenSet[str], - visited_messages: FrozenSet['MessageType'], + self, + *, + collisions: FrozenSet[str], + visited_messages: FrozenSet["MessageType"], ) -> 'Field': """Return a derivative of this field with the provided context. - This method is used to address naming collisions. The returned - ``Field`` object aliases module names to avoid naming collisions - in the file being written. - """ + This method is used to address naming collisions. The returned + ``Field`` object aliases module names to avoid naming collisions + in the file being written. + """ return dataclasses.replace( self, message=self.message.with_context( @@ -234,8 +239,8 @@ def with_context( skip_fields=self.message in visited_messages, visited_messages=visited_messages, ) if self.message else None, - enum=self.enum.with_context( - collisions=collisions) if self.enum else None, + enum=self.enum.with_context(collisions=collisions) + if self.enum else None, meta=self.meta.with_context(collisions=collisions), ) @@ -261,7 +266,8 @@ class MessageType: nested_enums: Mapping[str, 'EnumType'] nested_messages: Mapping[str, 'MessageType'] meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) oneofs: Optional[Mapping[str, 'Oneof']] = None def __getattr__(self, name): @@ -282,14 +288,18 @@ def oneof_fields(self, include_optional=False): @utils.cached_property def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: - answer = tuple(field.type - for field in self.fields.values() - if field.message or field.enum) + answer = tuple( + field.type + for field in self.fields.values() + if field.message or field.enum + ) return answer @utils.cached_property - def recursive_field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: + def recursive_field_types(self) -> Sequence[ + Union['MessageType', 'EnumType'] + ]: """Return all composite fields used in this proto's messages.""" types: Set[Union['MessageType', 'EnumType']] = set() @@ -308,13 +318,16 @@ def recursive_field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: def recursive_resource_fields(self) -> FrozenSet[Field]: all_fields = chain( self.fields.values(), - (field for t in self.recursive_field_types - if isinstance(t, MessageType) for field in t.fields.values()), + (field + for t in self.recursive_field_types if isinstance(t, MessageType) + for field in t.fields.values()), ) return frozenset( - f for f in all_fields + f + for f in all_fields if (f.options.Extensions[resource_pb2.resource_reference].type or - f.options.Extensions[resource_pb2.resource_reference].child_type)) + f.options.Extensions[resource_pb2.resource_reference].child_type) + ) @property def map(self) -> bool: @@ -329,11 +342,11 @@ def ident(self) -> metadata.Address: @property def resource_path(self) -> Optional[str]: """If this message describes a resource, return the path to the resource. - - If there are multiple paths, returns the first one. - """ + If there are multiple paths, returns the first one.""" return next( - iter(self.options.Extensions[resource_pb2.resource].pattern), None) + iter(self.options.Extensions[resource_pb2.resource].pattern), + None + ) @property def resource_type(self) -> Optional[str]: @@ -354,38 +367,41 @@ def path_regex_str(self) -> str: # becomes the regex # ^kingdoms/(?P.+?)/phyla/(?P.+?)$ parsing_regex_str = ( - '^' + self.PATH_ARG_RE.sub( + "^" + + self.PATH_ARG_RE.sub( # We can't just use (?P[^/]+) because segments may be # separated by delimiters other than '/'. # Multiple delimiter characters within one schema are allowed, # e.g. # as/{a}-{b}/cs/{c}%{d}_{e} # This is discouraged but permitted by AIP4231 - lambda m: '(?P<{name}>.+?)'.format(name=m.groups()[0]), - self.resource_path or '') + '$') + lambda m: "(?P<{name}>.+?)".format(name=m.groups()[0]), + self.resource_path or '' + ) + + "$" + ) return parsing_regex_str - def get_field( - self, *field_path: str, - collisions: FrozenSet[str] = frozenset()) -> Field: + def get_field(self, *field_path: str, + collisions: FrozenSet[str] = frozenset()) -> Field: """Return a field arbitrarily deep in this message's structure. - This method recursively traverses the message tree to return the - requested inner-field. + This method recursively traverses the message tree to return the + requested inner-field. - Traversing through repeated fields is not supported; a repeated field - may be specified if and only if it is the last field in the path. + Traversing through repeated fields is not supported; a repeated field + may be specified if and only if it is the last field in the path. - Args: - field_path (Sequence[str]): The field path. + Args: + field_path (Sequence[str]): The field path. - Returns: - ~.Field: A field object. + Returns: + ~.Field: A field object. - Raises: - KeyError: If a repeated field is used in the non-terminal position - in the path. - """ + Raises: + KeyError: If a repeated field is used in the non-terminal position + in the path. + """ # If collisions are not explicitly specified, retrieve them # from this message's address. # This ensures that calls to `get_field` will return a field with @@ -415,42 +431,43 @@ def get_field( '`get_field` to retrieve its children.\n' 'This exception usually indicates that a ' 'google.api.method_signature annotation uses a repeated field ' - 'in the fields list in a position other than the end.',) + 'in the fields list in a position other than the end.', + ) # Sanity check: If this cursor has no message, there is a problem. if not cursor.message: raise KeyError( f'Field {".".join(field_path)} could not be resolved from ' - f'{cursor.name}.',) + f'{cursor.name}.', + ) # Recursion case: Pass the remainder of the path to the sub-field's # message. return cursor.message.get_field(*field_path[1:], collisions=collisions) - def with_context( - self, - *, - collisions: FrozenSet[str], - skip_fields: bool = False, - visited_messages: FrozenSet['MessageType'] = frozenset(), - ) -> 'MessageType': + def with_context(self, *, + collisions: FrozenSet[str], + skip_fields: bool = False, + visited_messages: FrozenSet["MessageType"] = frozenset(), + ) -> 'MessageType': """Return a derivative of this message with the provided context. - This method is used to address naming collisions. The returned - ``MessageType`` object aliases module names to avoid naming collisions - in the file being written. + This method is used to address naming collisions. The returned + ``MessageType`` object aliases module names to avoid naming collisions + in the file being written. - The ``skip_fields`` argument will omit applying the context to the - underlying fields. This provides for an "exit" in the case of circular - references. - """ + The ``skip_fields`` argument will omit applying the context to the + underlying fields. This provides for an "exit" in the case of circular + references. + """ visited_messages = visited_messages | {self} return dataclasses.replace( self, fields={ k: v.with_context( - collisions=collisions, visited_messages=visited_messages) - for k, v in self.fields.items() + collisions=collisions, + visited_messages=visited_messages + ) for k, v in self.fields.items() } if not skip_fields else self.fields, nested_enums={ k: v.with_context(collisions=collisions) @@ -461,7 +478,8 @@ def with_context( collisions=collisions, skip_fields=skip_fields, visited_messages=visited_messages, - ) for k, v in self.nested_messages.items() + ) + for k, v in self.nested_messages.items() }, meta=self.meta.with_context(collisions=collisions), ) @@ -472,7 +490,8 @@ class EnumValueType: """Description of an enum value.""" enum_value_pb: descriptor_pb2.EnumValueDescriptorProto meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) def __getattr__(self, name): return getattr(self.enum_value_pb, name) @@ -484,7 +503,8 @@ class EnumType: enum_pb: descriptor_pb2.EnumDescriptorProto values: List[EnumValueType] meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) def __hash__(self): # Identity is sufficiently unambiguous. @@ -508,10 +528,10 @@ def ident(self) -> metadata.Address: def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType': """Return a derivative of this enum with the provided context. - This method is used to address naming collisions. The returned - ``EnumType`` object aliases module names to avoid naming collisions in - the file being written. - """ + This method is used to address naming collisions. The returned + ``EnumType`` object aliases module names to avoid naming collisions in + the file being written. + """ return dataclasses.replace( self, meta=self.meta.with_context(collisions=collisions), @@ -521,20 +541,23 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType': def options_dict(self) -> Dict: """Return the EnumOptions (if present) as a dict. - This is a hack to support a pythonic structure representation for - the generator templates. - """ - return MessageToDict(self.enum_pb.options, preserving_proto_field_name=True) + This is a hack to support a pythonic structure representation for + the generator templates. + """ + return MessageToDict( + self.enum_pb.options, + preserving_proto_field_name=True + ) @dataclasses.dataclass(frozen=True) class PythonType: """Wrapper class for Python types. - This exists for interface consistency, so that methods like - :meth:`Field.type` can return an object and the caller can be confident - that a ``name`` property will be present. - """ + This exists for interface consistency, so that methods like + :meth:`Field.type` can return an object and the caller can be confident + that a ``name`` property will be present. + """ meta: metadata.Metadata def __eq__(self, other): @@ -566,22 +589,19 @@ class PrimitiveType(PythonType): def build(cls, primitive_type: Optional[type]): """Return a PrimitiveType object for the given Python primitive type. - Args: - primitive_type (cls): A Python primitive type, such as :class:`int` - or :class:`str`. Despite not being a type, ``None`` is also - accepted here. + Args: + primitive_type (cls): A Python primitive type, such as + :class:`int` or :class:`str`. Despite not being a type, + ``None`` is also accepted here. - Returns: - ~.PrimitiveType: The instantiated PrimitiveType object. - """ + Returns: + ~.PrimitiveType: The instantiated PrimitiveType object. + """ # Primitives have no import, and no module to reference, so the # address just uses the name of the class (e.g. "int", "str"). - return cls( - meta=metadata.Metadata( - address=metadata.Address( - name='None' if primitive_type is None else primitive_type - .__name__,)), - python_type=primitive_type) + return cls(meta=metadata.Metadata(address=metadata.Address( + name='None' if primitive_type is None else primitive_type.__name__, + )), python_type=primitive_type) def __eq__(self, other): # If we are sent the actual Python type (not the PrimitiveType object), @@ -600,17 +620,18 @@ class OperationInfo: def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo': """Return a derivative of this OperationInfo with the provided context. - This method is used to address naming collisions. The returned - ``OperationInfo`` object aliases module names to avoid naming - collisions - in the file being written. - """ + This method is used to address naming collisions. The returned + ``OperationInfo`` object aliases module names to avoid naming collisions + in the file being written. + """ return dataclasses.replace( self, response_type=self.response_type.with_context( - collisions=collisions), + collisions=collisions + ), metadata_type=self.metadata_type.with_context( - collisions=collisions), + collisions=collisions + ), ) @@ -634,7 +655,8 @@ class Method: retry: Optional[RetryInfo] = dataclasses.field(default=None) timeout: Optional[float] = None meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) def __getattr__(self, name): return getattr(self.method_pb, name) @@ -659,13 +681,13 @@ def flattened_oneof_fields(self, include_optional=False): def _client_output(self, enable_asyncio: bool): """Return the output from the client layer. - This takes into account transformations made by the outer GAPIC - client to transform the output from the transport. + This takes into account transformations made by the outer GAPIC + client to transform the output from the transport. - Returns: - Union[~.MessageType, ~.PythonType]: - A description of the return type. - """ + Returns: + Union[~.MessageType, ~.PythonType]: + A description of the return type. + """ # Void messages ultimately return None. if self.void: return PrimitiveType.build(None) @@ -673,44 +695,41 @@ def _client_output(self, enable_asyncio: bool): # If this method is an LRO, return a PythonType instance representing # that. if self.lro: - return PythonType( - meta=metadata.Metadata( - address=metadata.Address( - name='AsyncOperation' if enable_asyncio else 'Operation', - module='operation_async' if enable_asyncio else 'operation', - package=('google', 'api_core'), - collisions=self.lro.response_type.ident.collisions, + return PythonType(meta=metadata.Metadata( + address=metadata.Address( + name='AsyncOperation' if enable_asyncio else 'Operation', + module='operation_async' if enable_asyncio else 'operation', + package=('google', 'api_core'), + collisions=self.lro.response_type.ident.collisions, + ), + documentation=utils.doc( + 'An object representing a long-running operation. \n\n' + 'The result type for the operation will be ' + ':class:`{ident}` {doc}'.format( + doc=self.lro.response_type.meta.doc, + ident=self.lro.response_type.ident.sphinx, ), - documentation=utils.doc( - 'An object representing a long-running operation. \n\n' - 'The result type for the operation will be ' - ':class:`{ident}` {doc}'.format( - doc=self.lro.response_type.meta.doc, - ident=self.lro.response_type.ident.sphinx, - ),), - )) + ), + )) # If this method is paginated, return that method's pager class. if self.paged_result_field: - return PythonType( - meta=metadata.Metadata( - address=metadata.Address( - name=f'{self.name}AsyncPager' - if enable_asyncio else f'{self.name}Pager', - package=self.ident.api_naming.module_namespace + - (self.ident.api_naming.versioned_module_name,) + - self.ident.subpackage + ( - 'services', - utils.to_snake_case(self.ident.parent[-1]), - ), - module='pagers', - collisions=self.input.ident.collisions, + return PythonType(meta=metadata.Metadata( + address=metadata.Address( + name=f'{self.name}AsyncPager' if enable_asyncio else f'{self.name}Pager', + package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + ( + 'services', + utils.to_snake_case(self.ident.parent[-1]), ), - documentation=utils.doc( - f'{self.output.meta.doc}\n\n' - 'Iterating over this object will yield results and ' - 'resolve additional pages automatically.',), - )) + module='pagers', + collisions=self.input.ident.collisions, + ), + documentation=utils.doc( + f'{self.output.meta.doc}\n\n' + 'Iterating over this object will yield results and ' + 'resolve additional pages automatically.', + ), + )) # Return the usual output. return self.output @@ -720,6 +739,7 @@ def is_deprecated(self) -> bool: """Returns true if the method is deprecated, false otherwise.""" return descriptor_pb2.MethodOptions.HasField(self.options, 'deprecated') + # TODO(yon-mg): remove or rewrite: don't think it performs as intended # e.g. doesn't work with basic case of gRPC transcoding @property @@ -738,18 +758,17 @@ def field_headers(self) -> Sequence[str]: http.custom.path, ] - return next( - (tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ()) + return next((tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ()) @property def http_opt(self) -> Optional[Dict[str, str]]: """Return the http option for this method. - e.g. {'verb': 'post' - 'url': '/some/path' - 'body': '*'} + e.g. {'verb': 'post' + 'url': '/some/path' + 'body': '*'} - """ + """ http: List[Tuple[descriptor_pb2.FieldDescriptorProto, str]] http = self.options.Extensions[annotations_pb2.http].ListFields() @@ -817,8 +836,10 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: signatures = self.options.Extensions[client_pb2.method_signature] answer: Dict[str, Field] = collections.OrderedDict( - name_and_field for sig in signatures - for name_and_field in filter_fields(sig)) + name_and_field + for sig in signatures + for name_and_field in filter_fields(sig) + ) return answer @@ -829,13 +850,11 @@ def flattened_field_to_key(self): @utils.cached_property def legacy_flattened_fields(self) -> Mapping[str, Field]: """Return the legacy flattening interface: top level fields only, - - required fields first - """ + required fields first""" required, optional = utils.partition(lambda f: f.required, self.input.fields.values()) - return collections.OrderedDict( - (f.name, f) for f in chain(required, optional)) + return collections.OrderedDict((f.name, f) + for f in chain(required, optional)) @property def grpc_stub_type(self) -> str: @@ -850,10 +869,10 @@ def grpc_stub_type(self) -> str: def idempotent(self) -> bool: """Return True if we know this method is idempotent, False otherwise. - Note: We are intentionally conservative here. It is far less bad - to falsely believe an idempotent method is non-idempotent than - the converse. - """ + Note: We are intentionally conservative here. It is far less bad + to falsely believe an idempotent method is non-idempotent than + the converse. + """ return bool(self.options.Extensions[annotations_pb2.http].get) @property @@ -877,7 +896,8 @@ def paged_result_field(self) -> Optional[Field]: # The request must have max_results or page_size page_fields = (self.input.fields.get('max_results', None), self.input.fields.get('page_size', None)) - page_field_size = next((field for field in page_fields if field), None) + page_field_size = next( + (field for field in page_fields if field), None) if not page_field_size or page_field_size.type != int: return None @@ -897,14 +917,18 @@ def ref_types(self) -> Sequence[Union[MessageType, EnumType]]: def flat_ref_types(self) -> Sequence[Union[MessageType, EnumType]]: return self._ref_types(False) - def _ref_types(self, - recursive: bool) -> Sequence[Union[MessageType, EnumType]]: + def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]: """Return types referenced by this method.""" # Begin with the input (request) and output (response) messages. answer: List[Union[MessageType, EnumType]] = [self.input] types: Iterable[Union[MessageType, EnumType]] = ( - self.input.recursive_field_types if recursive else - (f.type for f in self.flattened_fields.values() if f.message or f.enum)) + self.input.recursive_field_types if recursive + else ( + f.type + for f in self.flattened_fields.values() + if f.message or f.enum + ) + ) answer.extend(types) if not self.void: @@ -935,14 +959,15 @@ def void(self) -> bool: def with_context(self, *, collisions: FrozenSet[str]) -> 'Method': """Return a derivative of this method with the provided context. - This method is used to address naming collisions. The returned - ``Method`` object aliases module names to avoid naming collisions - in the file being written. - """ + This method is used to address naming collisions. The returned + ``Method`` object aliases module names to avoid naming collisions + in the file being written. + """ maybe_lro = None if self.lro: maybe_lro = self.lro.with_context( - collisions=collisions) if collisions else self.lro + collisions=collisions + ) if collisions else self.lro return dataclasses.replace( self, @@ -960,7 +985,10 @@ class CommonResource: @classmethod def build(cls, resource: resource_pb2.ResourceDescriptor): - return cls(type_name=resource.type, pattern=next(iter(resource.pattern))) + return cls( + type_name=resource.type, + pattern=next(iter(resource.pattern)) + ) @utils.cached_property def message_type(self): @@ -987,35 +1015,31 @@ class Service: # This is represented by a types.MappingProxyType instance. visible_resources: Mapping[str, MessageType] meta: metadata.Metadata = dataclasses.field( - default_factory=metadata.Metadata,) + default_factory=metadata.Metadata, + ) common_resources: ClassVar[Mapping[str, CommonResource]] = dataclasses.field( default={ - 'cloudresourcemanager.googleapis.com/Project': - CommonResource( - 'cloudresourcemanager.googleapis.com/Project', - 'projects/{project}', - ), - 'cloudresourcemanager.googleapis.com/Organization': - CommonResource( - 'cloudresourcemanager.googleapis.com/Organization', - 'organizations/{organization}', - ), - 'cloudresourcemanager.googleapis.com/Folder': - CommonResource( - 'cloudresourcemanager.googleapis.com/Folder', - 'folders/{folder}', - ), - 'cloudbilling.googleapis.com/BillingAccount': - CommonResource( - 'cloudbilling.googleapis.com/BillingAccount', - 'billingAccounts/{billing_account}', - ), - 'locations.googleapis.com/Location': - CommonResource( - 'locations.googleapis.com/Location', - 'projects/{project}/locations/{location}', - ), + "cloudresourcemanager.googleapis.com/Project": CommonResource( + "cloudresourcemanager.googleapis.com/Project", + "projects/{project}", + ), + "cloudresourcemanager.googleapis.com/Organization": CommonResource( + "cloudresourcemanager.googleapis.com/Organization", + "organizations/{organization}", + ), + "cloudresourcemanager.googleapis.com/Folder": CommonResource( + "cloudresourcemanager.googleapis.com/Folder", + "folders/{folder}", + ), + "cloudbilling.googleapis.com/BillingAccount": CommonResource( + "cloudbilling.googleapis.com/BillingAccount", + "billingAccounts/{billing_account}", + ), + "locations.googleapis.com/Location": CommonResource( + "locations.googleapis.com/Location", + "projects/{project}/locations/{location}", + ), }, init=False, compare=False, @@ -1027,28 +1051,28 @@ def __getattr__(self, name): @property def client_name(self) -> str: """Returns the name of the generated client class""" - return self.name + 'Client' + return self.name + "Client" @property def async_client_name(self) -> str: """Returns the name of the generated AsyncIO client class""" - return self.name + 'AsyncClient' + return self.name + "AsyncClient" @property def transport_name(self): - return self.name + 'Transport' + return self.name + "Transport" @property def grpc_transport_name(self): - return self.name + 'GrpcTransport' + return self.name + "GrpcTransport" @property def grpc_asyncio_transport_name(self): - return self.name + 'GrpcAsyncIOTransport' + return self.name + "GrpcAsyncIOTransport" @property def rest_transport_name(self): - return self.name + 'RestTransport' + return self.name + "RestTransport" @property def has_lro(self) -> bool: @@ -1064,61 +1088,61 @@ def has_pagers(self) -> bool: def host(self) -> str: """Return the hostname for this service, if specified. - Returns: - str: The hostname, with no protocol and no trailing ``/``. - """ + Returns: + str: The hostname, with no protocol and no trailing ``/``. + """ if self.options.Extensions[client_pb2.default_host]: return self.options.Extensions[client_pb2.default_host] return '' @property def shortname(self) -> str: - """Return the API short name. + """Return the API short name. DRIFT uses this to identify + APIs. - DRIFT uses this to identify - APIs. - - Returns: - str: The api shortname. - """ + Returns: + str: The api shortname. + """ # Get the shortname from the host # Real APIs are expected to have format: # "{api_shortname}.googleapis.com" - return self.host.split('.')[0] + return self.host.split(".")[0] @property def oauth_scopes(self) -> Sequence[str]: """Return a sequence of oauth scopes, if applicable. - Returns: - Sequence[str]: A sequence of OAuth scopes. - """ + Returns: + Sequence[str]: A sequence of OAuth scopes. + """ # Return the OAuth scopes, split on comma. return tuple( i.strip() for i in self.options.Extensions[client_pb2.oauth_scopes].split(',') - if i) + if i + ) @property def module_name(self) -> str: """Return the appropriate module name for this service. - Returns: - str: The service name, in snake case. - """ + Returns: + str: The service name, in snake case. + """ return utils.to_snake_case(self.name) @utils.cached_property def names(self) -> FrozenSet[str]: """Return a set of names used in this service. - This is used for detecting naming collisions in the module names - used for imports. - """ + This is used for detecting naming collisions in the module names + used for imports. + """ # Put together a set of the service and method names. answer = {self.name, self.client_name, self.async_client_name} - answer.update(utils.to_snake_case(i.name) - for i in self.methods.values()) + answer.update( + utils.to_snake_case(i.name) for i in self.methods.values() + ) # Identify any import module names where the same module name is used # from distinct packages. @@ -1127,8 +1151,11 @@ def names(self) -> FrozenSet[str]: for t in m.ref_types: modules[t.ident.module].add(t.ident.package) - answer.update(module_name for module_name, packages in modules.items() - if len(packages) > 1) + answer.update( + module_name + for module_name, packages in modules.items() + if len(packages) > 1 + ) # Done; return the answer. return frozenset(answer) @@ -1136,10 +1163,7 @@ def names(self) -> FrozenSet[str]: @utils.cached_property def resource_messages(self) -> FrozenSet[MessageType]: """Returns all the resource message types used in all - - request and response fields in the service. - """ - + request and response fields in the service.""" def gen_resources(message): if message.resource_path: yield message @@ -1150,7 +1174,8 @@ def gen_resources(message): def gen_indirect_resources_used(message): for field in message.recursive_resource_fields: - resource = field.options.Extensions[resource_pb2.resource_reference] + resource = field.options.Extensions[ + resource_pb2.resource_reference] resource_type = resource.type or resource.child_type # The resource may not be visible if the resource type is one of # the common_resources (see the class var in class definition) @@ -1159,14 +1184,20 @@ def gen_indirect_resources_used(message): if resource: yield resource - return frozenset(msg for method in self.methods.values() for msg in chain( - gen_resources(method.input), - gen_resources( - method.lro.response_type if method.lro else method.output), - gen_indirect_resources_used(method.input), - gen_indirect_resources_used( - method.lro.response_type if method.lro else method.output), - )) + return frozenset( + msg + for method in self.methods.values() + for msg in chain( + gen_resources(method.input), + gen_resources( + method.lro.response_type if method.lro else method.output + ), + gen_indirect_resources_used(method.input), + gen_indirect_resources_used( + method.lro.response_type if method.lro else method.output + ), + ) + ) @utils.cached_property def any_client_streaming(self) -> bool: @@ -1179,10 +1210,10 @@ def any_server_streaming(self) -> bool: def with_context(self, *, collisions: FrozenSet[str]) -> 'Service': """Return a derivative of this service with the provided context. - This method is used to address naming collisions. The returned - ``Service`` object aliases module names to avoid naming collisions - in the file being written. - """ + This method is used to address naming collisions. The returned + ``Service`` object aliases module names to avoid naming collisions + in the file being written. + """ return dataclasses.replace( self, methods={ diff --git a/test_utils/test_utils.py b/test_utils/test_utils.py index 69c3b7cf07..a499606f49 100644 --- a/test_utils/test_utils.py +++ b/test_utils/test_utils.py @@ -25,12 +25,13 @@ def make_service( - name: str = 'Placeholder', - host: str = '', + name: str = "Placeholder", + host: str = "", methods: typing.Tuple[wrappers.Method] = (), scopes: typing.Tuple[str] = (), - visible_resources: typing.Optional[typing.Mapping[ - str, wrappers.CommonResource]] = None, + visible_resources: typing.Optional[ + typing.Mapping[str, wrappers.CommonResource] + ] = None, ) -> wrappers.Service: visible_resources = visible_resources or {} # Define a service descriptor, and set a host and oauth scopes if @@ -55,8 +56,7 @@ def make_service_with_method_options( http_rule: http_pb2.HttpRule = None, method_signature: str = '', in_fields: typing.Tuple[desc.FieldDescriptorProto] = (), - visible_resources: typing.Optional[typing.Mapping[ - str, wrappers.CommonResource]] = None, + visible_resources: typing.Optional[typing.Mapping[str, wrappers.CommonResource]] = None, ) -> wrappers.Service: # Declare a method with options enabled for long-running operations and # field headers. @@ -82,17 +82,15 @@ def make_service_with_method_options( ) -def get_method( - name: str, - in_type: str, - out_type: str, - lro_response_type: str = '', - lro_metadata_type: str = '', - *, - in_fields: typing.Tuple[desc.FieldDescriptorProto] = (), - http_rule: http_pb2.HttpRule = None, - method_signature: str = '', -) -> wrappers.Method: +def get_method(name: str, + in_type: str, + out_type: str, + lro_response_type: str = '', + lro_metadata_type: str = '', *, + in_fields: typing.Tuple[desc.FieldDescriptorProto] = (), + http_rule: http_pb2.HttpRule = None, + method_signature: str = '', + ) -> wrappers.Method: input_ = get_message(in_type, fields=in_fields) output = get_message(out_type) lro = None @@ -124,11 +122,9 @@ def get_method( ) -def get_message( - dot_path: str, - *, - fields: typing.Tuple[desc.FieldDescriptorProto] = (), -) -> wrappers.MessageType: +def get_message(dot_path: str, *, + fields: typing.Tuple[desc.FieldDescriptorProto] = (), + ) -> wrappers.MessageType: # Pass explicit None through (for lro_metadata). if dot_path is None: return None @@ -143,33 +139,32 @@ def get_message( pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] return wrappers.MessageType( - fields={ - i.name: wrappers.Field( - field_pb=i, - enum=get_enum(i.type_name) if i.type_name else None, - ) for i in fields - }, + fields={i.name: wrappers.Field( + field_pb=i, + enum=get_enum(i.type_name) if i.type_name else None, + ) for i in fields}, nested_messages={}, nested_enums={}, message_pb=desc.DescriptorProto(name=name, field=fields), - meta=metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(pkg), - module=module, - )), + meta=metadata.Metadata(address=metadata.Address( + name=name, + package=tuple(pkg), + module=module, + )), ) -def make_method(name: str, - input_message: wrappers.MessageType = None, - output_message: wrappers.MessageType = None, - package: typing.Union[typing.Tuple[str], str] = 'foo.bar.v1', - module: str = 'baz', - http_rule: http_pb2.HttpRule = None, - signatures: typing.Sequence[str] = (), - is_deprecated: bool = False, - **kwargs) -> wrappers.Method: +def make_method( + name: str, + input_message: wrappers.MessageType = None, + output_message: wrappers.MessageType = None, + package: typing.Union[typing.Tuple[str], str] = 'foo.bar.v1', + module: str = 'baz', + http_rule: http_pb2.HttpRule = None, + signatures: typing.Sequence[str] = (), + is_deprecated: bool = False, + **kwargs +) -> wrappers.Method: # Use default input and output messages if they are not provided. input_message = input_message or make_message('MethodInput') output_message = output_message or make_message('MethodOutput') @@ -179,7 +174,8 @@ def make_method(name: str, name=name, input_type=str(input_message.meta.address), output_type=str(output_message.meta.address), - **kwargs) + **kwargs + ) # If there is an HTTP rule, process it. if http_rule: @@ -202,31 +198,32 @@ def make_method(name: str, method_pb=method_pb, input=input_message, output=output_message, - meta=metadata.Metadata( - address=metadata.Address( - name=name, - package=package, - module=module, - parent=(f'{name}Service',), - )), + meta=metadata.Metadata(address=metadata.Address( + name=name, + package=package, + module=module, + parent=(f'{name}Service',), + )), ) -def make_field(name: str = 'my_field', - number: int = 1, - repeated: bool = False, - message: wrappers.MessageType = None, - enum: wrappers.EnumType = None, - meta: metadata.Metadata = None, - oneof: str = None, - **kwargs) -> wrappers.Field: +def make_field( + name: str = 'my_field', + number: int = 1, + repeated: bool = False, + message: wrappers.MessageType = None, + enum: wrappers.EnumType = None, + meta: metadata.Metadata = None, + oneof: str = None, + **kwargs +) -> wrappers.Field: T = desc.FieldDescriptorProto.Type if message: kwargs.setdefault('type_name', str(message.meta.address)) kwargs['type'] = 'TYPE_MESSAGE' elif enum: - kwargs.setdefault('type_name', str(enum.meta.address)) + kwargs.setdefault('type_name', str(enum.meta.address)) kwargs['type'] = 'TYPE_ENUM' else: kwargs.setdefault('type', T.Value('TYPE_BOOL')) @@ -236,7 +233,11 @@ def make_field(name: str = 'my_field', label = kwargs.pop('label', 3 if repeated else 1) field_pb = desc.FieldDescriptorProto( - name=name, label=label, number=number, **kwargs) + name=name, + label=label, + number=number, + **kwargs + ) return wrappers.Field( field_pb=field_pb, @@ -265,12 +266,11 @@ def make_message( fields=collections.OrderedDict((i.name, i) for i in fields), nested_messages={}, nested_enums={}, - meta=meta or metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(package.split('.')), - module=module, - )), + meta=meta or metadata.Metadata(address=metadata.Address( + name=name, + package=tuple(package.split('.')), + module=module, + )), ) @@ -279,12 +279,11 @@ def get_enum(dot_path: str) -> wrappers.EnumType: pkg, module, name = pieces[:-2], pieces[-2], pieces[-1] return wrappers.EnumType( enum_pb=desc.EnumDescriptorProto(name=name), - meta=metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(pkg), - module=module, - )), + meta=metadata.Metadata(address=metadata.Address( + name=name, + package=tuple(pkg), + module=module, + )), values=[], ) @@ -298,7 +297,8 @@ def make_enum( options: desc.EnumOptions = None, ) -> wrappers.EnumType: enum_value_pbs = [ - desc.EnumValueDescriptorProto(name=i[0], number=i[1]) for i in values + desc.EnumValueDescriptorProto(name=i[0], number=i[1]) + for i in values ] enum_pb = desc.EnumDescriptorProto( name=name, @@ -307,15 +307,13 @@ def make_enum( ) return wrappers.EnumType( enum_pb=enum_pb, - values=[ - wrappers.EnumValueType(enum_value_pb=evpb) for evpb in enum_value_pbs - ], - meta=meta or metadata.Metadata( - address=metadata.Address( - name=name, - package=tuple(package.split('.')), - module=module, - )), + values=[wrappers.EnumValueType(enum_value_pb=evpb) + for evpb in enum_value_pbs], + meta=meta or metadata.Metadata(address=metadata.Address( + name=name, + package=tuple(package.split('.')), + module=module, + )), ) @@ -327,31 +325,33 @@ def make_naming(**kwargs) -> naming.Naming: return naming.NewNaming(**kwargs) -def make_enum_pb2(name: str, *values: typing.Sequence[str], - **kwargs) -> desc.EnumDescriptorProto: +def make_enum_pb2( + name: str, + *values: typing.Sequence[str], + **kwargs +) -> desc.EnumDescriptorProto: enum_value_pbs = [ desc.EnumValueDescriptorProto(name=n, number=i) for i, n in enumerate(values) ] - enum_pb = desc.EnumDescriptorProto( - name=name, value=enum_value_pbs, **kwargs) + enum_pb = desc.EnumDescriptorProto(name=name, value=enum_value_pbs, **kwargs) return enum_pb -def make_message_pb2(name: str, - fields: tuple = (), - oneof_decl: tuple = (), - **kwargs) -> desc.DescriptorProto: - return desc.DescriptorProto( - name=name, field=fields, oneof_decl=oneof_decl, **kwargs) - - -def make_field_pb2( +def make_message_pb2( name: str, - number: int, - type: int = 11, # 11 == message - type_name: str = None, - oneof_index: int = None) -> desc.FieldDescriptorProto: + fields: tuple = (), + oneof_decl: tuple = (), + **kwargs +) -> desc.DescriptorProto: + return desc.DescriptorProto(name=name, field=fields, oneof_decl=oneof_decl, **kwargs) + + +def make_field_pb2(name: str, number: int, + type: int = 11, # 11 == message + type_name: str = None, + oneof_index: int = None + ) -> desc.FieldDescriptorProto: return desc.FieldDescriptorProto( name=name, number=number, @@ -360,20 +360,18 @@ def make_field_pb2( oneof_index=oneof_index, ) - def make_oneof_pb2(name: str) -> desc.OneofDescriptorProto: - return desc.OneofDescriptorProto(name=name,) + return desc.OneofDescriptorProto( + name=name, + ) -def make_file_pb2( - name: str = 'my_proto.proto', - package: str = 'example.v1', - *, - messages: typing.Sequence[desc.DescriptorProto] = (), - enums: typing.Sequence[desc.EnumDescriptorProto] = (), - services: typing.Sequence[desc.ServiceDescriptorProto] = (), - locations: typing.Sequence[desc.SourceCodeInfo.Location] = (), -) -> desc.FileDescriptorProto: +def make_file_pb2(name: str = 'my_proto.proto', package: str = 'example.v1', *, + messages: typing.Sequence[desc.DescriptorProto] = (), + enums: typing.Sequence[desc.EnumDescriptorProto] = (), + services: typing.Sequence[desc.ServiceDescriptorProto] = (), + locations: typing.Sequence[desc.SourceCodeInfo.Location] = (), + ) -> desc.FileDescriptorProto: return desc.FileDescriptorProto( name=name, package=package, @@ -385,14 +383,15 @@ def make_file_pb2( def make_doc_meta( - *, - leading: str = '', - trailing: str = '', - detached: typing.List[str] = [], + *, + leading: str = '', + trailing: str = '', + detached: typing.List[str] = [], ) -> desc.SourceCodeInfo.Location: return metadata.Metadata( documentation=desc.SourceCodeInfo.Location( leading_comments=leading, trailing_comments=trailing, leading_detached_comments=detached, - ),) + ), + ) diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index 6168f58564..c13a9afb28 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -33,8 +33,8 @@ def test_method_types(): input_msg = make_message(name='Input', module='baz') output_msg = make_message(name='Output', module='baz') - method = make_method( - 'DoSomething', input_msg, output_msg, package='foo.bar', module='bacon') + method = make_method('DoSomething', input_msg, output_msg, + package='foo.bar', module='bacon') assert method.name == 'DoSomething' assert method.input.name == 'Input' assert method.output.name == 'Output' @@ -71,22 +71,19 @@ def test_method_client_output_empty(): def test_method_client_output_paged(): paged = make_field(name='foos', message=make_message('Foo'), repeated=True) - parent = make_field(name='parent', type=9) # str - page_size = make_field(name='page_size', type=5) # int + parent = make_field(name='parent', type=9) # str + page_size = make_field(name='page_size', type=5) # int page_token = make_field(name='page_token', type=9) # str - input_msg = make_message( - name='ListFoosRequest', fields=( - parent, - page_size, - page_token, - )) - output_msg = make_message( - name='ListFoosResponse', - fields=( - paged, - make_field(name='next_page_token', type=9), # str - )) + input_msg = make_message(name='ListFoosRequest', fields=( + parent, + page_size, + page_token, + )) + output_msg = make_message(name='ListFoosResponse', fields=( + paged, + make_field(name='next_page_token', type=9), # str + )) method = make_method( 'ListFoos', input_message=input_msg, @@ -96,12 +93,11 @@ def test_method_client_output_paged(): assert method.client_output.ident.name == 'ListFoosPager' max_results = make_field(name='max_results', type=5) # int - input_msg = make_message( - name='ListFoosRequest', fields=( - parent, - max_results, - page_token, - )) + input_msg = make_message(name='ListFoosRequest', fields=( + parent, + max_results, + page_token, + )) method = make_method( 'ListFoos', input_message=input_msg, @@ -119,47 +115,36 @@ def test_method_client_output_async_empty(): def test_method_paged_result_field_not_first(): paged = make_field(name='foos', message=make_message('Foo'), repeated=True) - input_msg = make_message( - name='ListFoosRequest', - fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - )) - output_msg = make_message( - name='ListFoosResponse', - fields=( - make_field(name='next_page_token', type=9), # str - paged, - )) - method = make_method( - 'ListFoos', - input_message=input_msg, - output_message=output_msg, - ) + input_msg = make_message(name='ListFoosRequest', fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + )) + output_msg = make_message(name='ListFoosResponse', fields=( + make_field(name='next_page_token', type=9), # str + paged, + )) + method = make_method('ListFoos', + input_message=input_msg, + output_message=output_msg, + ) assert method.paged_result_field == paged def test_method_paged_result_field_no_page_field(): - input_msg = make_message( - name='ListFoosRequest', - fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str - )) - output_msg = make_message( - name='ListFoosResponse', - fields=( - make_field(name='foos', message=make_message( - 'Foo'), repeated=False), - make_field(name='next_page_token', type=9), # str - )) - method = make_method( - 'ListFoos', - input_message=input_msg, - output_message=output_msg, - ) + input_msg = make_message(name='ListFoosRequest', fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + )) + output_msg = make_message(name='ListFoosResponse', fields=( + make_field(name='foos', message=make_message('Foo'), repeated=False), + make_field(name='next_page_token', type=9), # str + )) + method = make_method('ListFoos', + input_message=input_msg, + output_message=output_msg, + ) assert method.paged_result_field is None method = make_method( @@ -171,7 +156,8 @@ def test_method_paged_result_field_no_page_field(): output_message=make_message( name='FooResponse', fields=(make_field(name='next_page_token', type=9),) # str - )) + ) + ) assert method.paged_result_field is None @@ -179,8 +165,8 @@ def test_method_paged_result_ref_types(): input_msg = make_message( name='ListSquidsRequest', fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int make_field(name='page_token', type=9), # str ), module='squid', @@ -192,12 +178,14 @@ def test_method_paged_result_ref_types(): make_field(name='molluscs', message=mollusc_msg, repeated=True), make_field(name='next_page_token', type=9) # str ), - module='mollusc') + module='mollusc' + ) method = make_method( 'ListSquids', input_message=input_msg, output_message=output_msg, - module='squid') + module='squid' + ) ref_type_names = {t.name for t in method.ref_types} assert ref_type_names == { @@ -233,7 +221,12 @@ def test_flattened_ref_types(): ), ), ), - make_field('stratum', enum=make_enum('Stratum',)), + make_field( + 'stratum', + enum=make_enum( + 'Stratum', + ) + ), ), ), signatures=('cephalopod.squid,stratum',), @@ -251,21 +244,19 @@ def test_flattened_ref_types(): def test_method_paged_result_primitive(): - paged = make_field(name='squids', type=9, repeated=True) # str + paged = make_field(name='squids', type=9, repeated=True) # str input_msg = make_message( name='ListSquidsRequest', fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int make_field(name='page_token', type=9), # str ), ) - output_msg = make_message( - name='ListFoosResponse', - fields=( - paged, - make_field(name='next_page_token', type=9), # str - )) + output_msg = make_message(name='ListFoosResponse', fields=( + paged, + make_field(name='next_page_token', type=9), # str + )) method = make_method( 'ListSquids', input_message=input_msg, @@ -297,15 +288,15 @@ def test_method_field_headers_present(): def test_method_http_opt(): http_rule = http_pb2.HttpRule( - post='/v1/{parent=projects/*}/topics', body='*') + post='/v1/{parent=projects/*}/topics', + body='*' + ) method = make_method('DoSomething', http_rule=http_rule) assert method.http_opt == { 'verb': 'post', 'url': '/v1/{parent=projects/*}/topics', 'body': '*' } - - # TODO(yon-mg) to test: grpc transcoding, # correct handling of path/query params # correct handling of body & additional binding @@ -339,13 +330,20 @@ def test_method_path_params_no_http_rule(): def test_method_query_params(): # tests only the basic case of grpc transcoding - http_rule = http_pb2.HttpRule(post='/v1/{project}/topics', body='address') + http_rule = http_pb2.HttpRule( + post='/v1/{project}/topics', + body='address' + ) input_message = make_message( 'MethodInput', - fields=(make_field('region'), make_field('project'), - make_field('address'))) - method = make_method( - 'DoSomething', http_rule=http_rule, input_message=input_message) + fields=( + make_field('region'), + make_field('project'), + make_field('address') + ) + ) + method = make_method('DoSomething', http_rule=http_rule, + input_message=input_message) assert method.query_params == {'region'} @@ -353,12 +351,14 @@ def test_method_query_params_no_body(): # tests only the basic case of grpc transcoding http_rule = http_pb2.HttpRule(post='/v1/{project}/topics') input_message = make_message( - 'MethodInput', fields=( + 'MethodInput', + fields=( make_field('region'), make_field('project'), - )) - method = make_method( - 'DoSomething', http_rule=http_rule, input_message=input_message) + ) + ) + method = make_method('DoSomething', http_rule=http_rule, + input_message=input_message) assert method.query_params == {'region'} @@ -432,22 +432,18 @@ def test_method_flattened_fields_different_package_non_primitive(): # directly to its fields, which complicates request construction. # The easiest solution in this case is to just prohibit these fields # in the method flattening. - message = make_message( - 'Mantle', package='mollusc.cephalopod.v1', module='squid') - mantle = make_field( - 'mantle', type=11, type_name='Mantle', message=message, meta=message.meta) + message = make_message('Mantle', + package="mollusc.cephalopod.v1", module="squid") + mantle = make_field('mantle', type=11, type_name='Mantle', + message=message, meta=message.meta) arms_count = make_field('arms_count', type=5, meta=message.meta) input_message = make_message( - 'Squid', - fields=(mantle, arms_count), - package='.'.join(message.meta.address.package), - module=message.meta.address.module) - method = make_method( - 'PutSquid', - input_message=input_message, - package='remote.package.v1', - module='module', - signatures=('mantle,arms_count',)) + 'Squid', fields=(mantle, arms_count), + package=".".join(message.meta.address.package), + module=message.meta.address.module + ) + method = make_method('PutSquid', input_message=input_message, + package="remote.package.v1", module="module", signatures=("mantle,arms_count",)) assert set(method.flattened_fields) == {'arms_count'} @@ -463,63 +459,75 @@ def test_method_include_flattened_message_fields(): def test_method_legacy_flattened_fields(): required_options = descriptor_pb2.FieldOptions() required_options.Extensions[field_behavior_pb2.field_behavior].append( - field_behavior_pb2.FieldBehavior.Value('REQUIRED')) + field_behavior_pb2.FieldBehavior.Value("REQUIRED")) # Cephalopods are required. - squid = make_field(name='squid', options=required_options) + squid = make_field(name="squid", options=required_options) octopus = make_field( - name='octopus', + name="octopus", message=make_message( - name='Octopus', - fields=[make_field(name='mass', options=required_options)]), + name="Octopus", + fields=[make_field(name="mass", options=required_options)] + ), options=required_options) # Bivalves are optional. - clam = make_field(name='clam') + clam = make_field(name="clam") oyster = make_field( - name='oyster', + name="oyster", message=make_message( - name='Oyster', fields=[make_field(name='has_pearl')])) + name="Oyster", + fields=[make_field(name="has_pearl")] + ) + ) # Interleave required and optional fields to make sure # that, in the legacy flattening, required fields are always first. - request = make_message('request', fields=[squid, clam, octopus, oyster]) + request = make_message("request", fields=[squid, clam, octopus, oyster]) method = make_method( - name='CreateMolluscs', + name="CreateMolluscs", input_message=request, # Signatures should be ignored. - signatures=['squid,octopus.mass', 'squid,octopus,oyster.has_pearl']) + signatures=[ + "squid,octopus.mass", + "squid,octopus,oyster.has_pearl" + ] + ) # Use an ordered dict because ordering is important: # required fields should come first. - expected = collections.OrderedDict([('squid', squid), ('octopus', octopus), - ('clam', clam), ('oyster', oyster)]) + expected = collections.OrderedDict([ + ("squid", squid), + ("octopus", octopus), + ("clam", clam), + ("oyster", oyster) + ]) assert method.legacy_flattened_fields == expected def test_flattened_oneof_fields(): - mass_kg = make_field(name='mass_kg', oneof='mass', type=5) - mass_lbs = make_field(name='mass_lbs', oneof='mass', type=5) + mass_kg = make_field(name="mass_kg", oneof="mass", type=5) + mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5) - length_m = make_field(name='length_m', oneof='length', type=5) - length_f = make_field(name='length_f', oneof='length', type=5) + length_m = make_field(name="length_m", oneof="length", type=5) + length_f = make_field(name="length_f", oneof="length", type=5) - color = make_field(name='color', type=5) + color = make_field(name="color", type=5) mantle = make_field( - name='mantle', + name="mantle", message=make_message( - name='Mantle', + name="Mantle", fields=( - make_field(name='color', type=5), + make_field(name="color", type=5), mass_kg, mass_lbs, ), ), ) request = make_message( - name='CreateMolluscReuqest', + name="CreateMolluscReuqest", fields=( length_m, length_f, @@ -528,27 +536,28 @@ def test_flattened_oneof_fields(): ), ) method = make_method( - name='CreateMollusc', + name="CreateMollusc", input_message=request, signatures=[ - 'length_m,', - 'length_f,', - 'mantle.mass_kg,', - 'mantle.mass_lbs,', - 'color', - ]) - - expected = {'mass': [mass_kg, mass_lbs], 'length': [length_m, length_f]} + "length_m,", + "length_f,", + "mantle.mass_kg,", + "mantle.mass_lbs,", + "color", + ] + ) + + expected = {"mass": [mass_kg, mass_lbs], "length": [length_m, length_f]} actual = method.flattened_oneof_fields() assert expected == actual # Check this method too becasue the setup is a lot of work. expected = { - 'color': 'color', - 'length_m': 'length_m', - 'length_f': 'length_f', - 'mass_kg': 'mantle.mass_kg', - 'mass_lbs': 'mantle.mass_lbs', + "color": "color", + "length_m": "length_m", + "length_f": "length_f", + "mass_kg": "mantle.mass_kg", + "mass_lbs": "mantle.mass_lbs", } actual = method.flattened_field_to_key assert expected == actual