Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(snippetgen): generate mock input for required fields #941

Merged
merged 17 commits into from
Aug 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 67 additions & 11 deletions gapic/samplegen/samplegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def build(
f"Resource {resource_typestr} has no pattern with params: {attr_name_str}"
)

return cls(base=base, body=attrs, single=None, pattern=pattern)
return cls(base=base, body=attrs, single=None, pattern=pattern,)


@dataclasses.dataclass
Expand Down Expand Up @@ -293,17 +293,22 @@ def preprocess_sample(sample, api_schema: api.API, rpc: wrappers.Method):
sample["module_name"] = api_schema.naming.versioned_module_name
sample["module_namespace"] = api_schema.naming.module_namespace

service = api_schema.services[sample["service"]]

# Assume the gRPC transport if the transport is not specified
sample.setdefault("transport", api.TRANSPORT_GRPC)
transport = sample.setdefault("transport", api.TRANSPORT_GRPC)

if sample["transport"] == api.TRANSPORT_GRPC_ASYNC:
sample["client_name"] = api_schema.services[sample["service"]
].async_client_name
else:
sample["client_name"] = api_schema.services[sample["service"]].client_name
is_async = transport == api.TRANSPORT_GRPC_ASYNC
sample["client_name"] = service.async_client_name if is_async else service.client_name

# the type of the request object passed to the rpc e.g, `ListRequest`
sample["request_type"] = rpc.input.ident.name
# the MessageType of the request object passed to the rpc e.g, `ListRequest`
sample["request_type"] = rpc.input

# If no request was specified in the config
# Add reasonable default values as placeholders
if "request" not in sample:
sample["request"] = generate_request_object(
api_schema, service, rpc.input)

# If no response was specified in the config
# Add reasonable defaults depending on the type of the sample
Expand Down Expand Up @@ -940,6 +945,58 @@ def parse_handwritten_specs(sample_configs: Sequence[str]) -> Generator[Dict[str
yield spec


def generate_request_object(api_schema: api.API, service: wrappers.Service, message: wrappers.MessageType, field_name_prefix: str = ""):
"""Generate dummy input for a given message.

Args:
api_schema (api.API): The schema that defines the API.
service (wrappers.Service): The service object the message belongs to.
message (wrappers.MessageType): The message to generate a request object for.
field_name_prefix (str): A prefix to attach to the field name in the request.

Returns:
List[Dict[str, Any]]: A list of dicts that can be turned into TransformedRequests.
"""
request: List[Dict[str, Any]] = []

request_fields: List[wrappers.Field] = []

# Choose the first option for each oneof
selected_oneofs: List[wrappers.Field] = [oneof_fields[0]
for oneof_fields in message.oneof_fields().values()]
request_fields = selected_oneofs + message.required_fields

for field in request_fields:
# TransformedRequest expects nested fields to be referenced like
# `destination.input_config.name`
field_name = ".".join([field_name_prefix, field.name]).lstrip('.')

# TODO(busunkim): Properly handle map fields
if field.is_primitive:
placeholder_value = field.mock_value_original_type
# If this field identifies a resource use the resource path
if service.resource_messages_dict.get(field.resource_reference):
placeholder_value = service.resource_messages_dict[
field.resource_reference].resource_path
request.append({"field": field_name, "value": placeholder_value})
busunkim96 marked this conversation as resolved.
Show resolved Hide resolved
elif field.enum:
# Choose the last enum value in the list since index 0 is often "unspecified"
request.append(
{"field": field_name, "value": field.enum.values[-1].name})
else:
# This is a message type, recurse
# TODO(busunkim): Some real world APIs have
# request objects are recursive.
# Reference `Field.mock_value` to ensure
# this always terminates.
request += generate_request_object(
api_schema, service, field.type,
field_name_prefix=field_name,
)

return request


def generate_sample_specs(api_schema: api.API, *, opts) -> Generator[Dict[str, Any], None, None]:
"""Given an API, generate basic sample specs for each method.

Expand All @@ -964,8 +1021,7 @@ def generate_sample_specs(api_schema: api.API, *, opts) -> Generator[Dict[str, A
"sample_type": "standalone",
"rpc": rpc_name,
"transport": transport,
"request": [],
# response is populated in `preprocess_sample`
# `request` and `response` is populated in `preprocess_sample`
"service": f"{api_schema.naming.proto_package}.{service_name}",
"region_tag": region_tag,
"description": f"Snippet for {utils.to_snake_case(rpc_name)}"
Expand Down
112 changes: 97 additions & 15 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import dataclasses
import re
from itertools import chain
from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping,
from typing import (Any, cast, Dict, FrozenSet, Iterable, List, Mapping,
ClassVar, Optional, Sequence, Set, Tuple, Union)
from google.api import annotations_pb2 # type: ignore
from google.api import client_pb2
Expand Down Expand Up @@ -89,6 +89,17 @@ def map(self) -> bool:
"""Return True if this field is a map, False otherwise."""
return bool(self.repeated and self.message and self.message.map)

@utils.cached_property
def mock_value_original_type(self) -> Union[bool, str, bytes, int, float, List[Any], None]:
answer = self.primitive_mock() or None

# If this is a repeated field, then the mock answer should
# be a list.
if self.repeated:
answer = [answer]

return answer

@utils.cached_property
def mock_value(self) -> str:
visited_fields: Set["Field"] = set()
Expand All @@ -100,25 +111,13 @@ def mock_value(self) -> str:

return answer

def inner_mock(self, stack, visited_fields):
def inner_mock(self, stack, visited_fields) -> str:
"""Return a repr of a valid, usually truthy mock value."""
# For primitives, send a truthy value computed from the
# field name.
answer = 'None'
if isinstance(self.type, PrimitiveType):
if self.type.python_type == bool:
answer = 'True'
elif self.type.python_type == str:
answer = f"'{self.name}_value'"
elif self.type.python_type == bytes:
answer = f"b'{self.name}_blob'"
elif self.type.python_type == int:
answer = f'{sum([ord(i) for i in self.name])}'
elif self.type.python_type == float:
answer = f'0.{sum([ord(i) for i in self.name])}'
else: # Impossible; skip coverage checks.
raise TypeError('Unrecognized PrimitiveType. This should '
'never happen; please file an issue.')
answer = self.primitive_mock_as_str()

# If this is an enum, select the first truthy value (or the zero
# value if nothing else exists).
Expand Down Expand Up @@ -158,6 +157,45 @@ def inner_mock(self, stack, visited_fields):
# Done; return the mock value.
return answer

def primitive_mock(self) -> Union[bool, str, bytes, int, float, List[Any], None]:
"""Generate a valid mock for a primitive type. This function
returns the original (Python) type.
"""
answer: Union[bool, str, bytes, int, float, List[Any], None] = None

if not isinstance(self.type, PrimitiveType):
raise TypeError(f"'inner_mock_as_original_type' can only be used for"
f"PrimitiveType, but type is {self.type}")

else:
if self.type.python_type == bool:
answer = True
elif self.type.python_type == str:
answer = f"{self.name}_value"
elif self.type.python_type == bytes:
answer = bytes(f"{self.name}_blob", encoding="utf-8")
elif self.type.python_type == int:
answer = sum([ord(i) for i in self.name])
elif self.type.python_type == float:
name_sum = sum([ord(i) for i in self.name])
answer = name_sum * pow(10, -1 * len(str(name_sum)))
else: # Impossible; skip coverage checks.
raise TypeError('Unrecognized PrimitiveType. This should '
'never happen; please file an issue.')

return answer

def primitive_mock_as_str(self) -> str:
"""Like primitive mock, but return the mock as a string."""
answer = self.primitive_mock()

if isinstance(answer, str):
answer = f"'{answer}'"
else:
answer = str(answer)

return answer

@property
def proto_type(self) -> str:
"""Return the proto type constant to be used in templates."""
Expand Down Expand Up @@ -186,6 +224,17 @@ def required(self) -> bool:
return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') in
self.options.Extensions[field_behavior_pb2.field_behavior])

@property
def resource_reference(self) -> Optional[str]:
"""Return a resource reference type if it exists.

This is only applicable for string fields.
Example: "translate.googleapis.com/Glossary"
"""
return (self.options.Extensions[resource_pb2.resource_reference].type
or self.options.Extensions[resource_pb2.resource_reference].child_type
or None)

@utils.cached_property
def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']:
"""Return the type of this field."""
Expand Down Expand Up @@ -286,6 +335,13 @@ def oneof_fields(self, include_optional=False):

return oneof_fields

@utils.cached_property
def required_fields(self) -> Sequence['Field']:
required_fields = [
field for field in self.fields.values() if field.required]

return required_fields

@utils.cached_property
def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]:
answer = tuple(
Expand Down Expand Up @@ -353,6 +409,11 @@ def resource_type(self) -> Optional[str]:
resource = self.options.Extensions[resource_pb2.resource]
return resource.type[resource.type.find('/') + 1:] if resource else None

@property
def resource_type_full_path(self) -> Optional[str]:
resource = self.options.Extensions[resource_pb2.resource]
return resource.type if resource else None

@property
def resource_path_args(self) -> Sequence[str]:
return self.PATH_ARG_RE.findall(self.resource_path or '')
Expand Down Expand Up @@ -1199,6 +1260,27 @@ def gen_indirect_resources_used(message):
)
)

@utils.cached_property
def resource_messages_dict(self) -> Dict[str, MessageType]:
"""Returns a dict from resource reference to
the message type. This *includes* the common resource messages.

Returns:
Dict[str, MessageType]: A mapping from resource path
string to the corresponding MessageType.
`{"locations.googleapis.com/Location": MessageType(...)}`
"""
service_resource_messages = {
r.resource_type_full_path: r for r in self.resource_messages}

# Add common resources
service_resource_messages.update(
(resource_path, resource.message_type)
for resource_path, resource in self.common_resources.items()
)

return service_resource_messages

@utils.cached_property
def any_client_streaming(self) -> bool:
return any(m.client_streaming for m in self.methods.values())
Expand Down
18 changes: 10 additions & 8 deletions gapic/templates/examples/feature_fragments.j2
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,19 @@ with open({{ print_string_formatting(statement["filename"])|trim }}, "wb") as f:
{% macro render_request_attr(base_name, attr) %}
{# Note: python code will have manipulated the value #}
{# to be the correct enum from the right module, if necessary. #}
{# Python is also responsible for verifying that each input parameter is unique,#}
{# Python is also responsible for verifying that each input parameter is unique, #}
{# no parameter is a reserved keyword #}
{% if attr.input_parameter %}

# {{ attr.input_parameter }} = {{ attr.value }}
{% if attr.value_is_file %}
with open({{ attr.input_parameter }}, "rb") as f:
{{ base_name }}["{{ attr.field }}"] = f.read()
{{ base_name }}.{{ attr.field }} = f.read()
{% else %}
{{ base_name }}["{{ attr.field }}"] = {{ attr.input_parameter }}
{{ base_name }}.{{ attr.field }} = {{ attr.input_parameter }}
{% endif %}
{% else %}
{{ base_name }}["{{ attr.field }}"] = {{ attr.value }}
{{ base_name }}.{{ attr.field }} = {{ attr.value }}
{% endif %}
{% endmacro %}

Expand All @@ -159,16 +160,17 @@ client = {{ module_name }}.{{ client_name }}()
{{ parameter_block.base }} = "{{parameter_block.pattern }}".format({{ formals|join(", ") }})
{% endwith %}
{% else %}{# End resource name construction #}
{{ parameter_block.base }} = {}
{{ parameter_block.base }} = {{ module_name }}.{{ request_type.get_field(parameter_block.base).type.name }}()
{% for attr in parameter_block.body %}
{{ render_request_attr(parameter_block.base, attr) }}
{{ render_request_attr(parameter_block.base, attr) -}}
{% endfor %}

{% endif %}
{% endfor %}
{% if not full_request.flattenable %}
request = {{ module_name }}.{{ request_type }}(
request = {{ module_name }}.{{ request_type.ident.name }}(
{% for parameter in full_request.request_list %}
{{ parameter.base }}={{ parameter.base if parameter.body else parameter.single }},
{{ parameter.base }}={{ parameter.base if parameter.body else parameter.single.value }},
{% endfor %}
)
{% endif %}
Expand Down
4 changes: 2 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def unit(session):
"--cov-report=term",
"--cov-fail-under=100",
path.join("tests", "unit"),
]
]
),
)

Expand Down Expand Up @@ -308,7 +308,7 @@ def snippetgen(session):

session.run(
"py.test",
"--quiet",
"-vv",
"tests/snippetgen"
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ async def sample_analyze_iam_policy():
client = asset_v1.AssetServiceAsyncClient()

# Initialize request argument(s)
analysis_query = asset_v1.IamPolicyAnalysisQuery()
analysis_query.scope = "scope_value"

request = asset_v1.AnalyzeIamPolicyRequest(
analysis_query=analysis_query,
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ async def sample_analyze_iam_policy_longrunning():
client = asset_v1.AssetServiceAsyncClient()

# Initialize request argument(s)
analysis_query = asset_v1.IamPolicyAnalysisQuery()
analysis_query.scope = "scope_value"

output_config = asset_v1.IamPolicyAnalysisOutputConfig()
output_config.gcs_destination.uri = "uri_value"

request = asset_v1.AnalyzeIamPolicyLongrunningRequest(
analysis_query=analysis_query,
output_config=output_config,
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ def sample_analyze_iam_policy_longrunning():
client = asset_v1.AssetServiceClient()

# Initialize request argument(s)
analysis_query = asset_v1.IamPolicyAnalysisQuery()
analysis_query.scope = "scope_value"

output_config = asset_v1.IamPolicyAnalysisOutputConfig()
output_config.gcs_destination.uri = "uri_value"

request = asset_v1.AnalyzeIamPolicyLongrunningRequest(
analysis_query=analysis_query,
output_config=output_config,
)

# Make the request
Expand Down
Loading