Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add interceptor-like functionality to REST transport #1142

Merged
merged 5 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from .grpc import {{ service.name }}GrpcTransport
{% endif %}
{% if 'rest' in opts.transport %}
from .rest import {{ service.name }}RestTransport
from .rest import {{ service.name }}RestInterceptor
{% endif %}

# Compile a registry of transports.
Expand All @@ -29,6 +30,7 @@ __all__ = (
{% endif %}
{% if 'rest' in opts.transport %}
'{{ service.name }}RestTransport',
'{{ service.name }}RestInterceptor',
{% endif %}
)
{% endblock %}
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
rest_version=requests_version,
)


class {{ service.name }}RestInterceptor:
software-dov marked this conversation as resolved.
Show resolved Hide resolved
"""Interceptor for {{ service.name }}.

Interceptors are used to manipulate requests, request metadata, and responses
in arbitrary ways.
Example use cases include:
* Logging
* Verifying requests according to service or custom semantics
* Stripping extraneous information from responses

These use cases and more can be enabled by injecting an
instance of a custom subclass when constructing the {{ service.name }}RestTransport.

.. code-block:
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(request, metadata):
logging.log(f"Received request: {request}")
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(response):
logging.log(f"Received response: {response}")
{% endif %}

{% endfor %}
transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor())
client = {{ service.client_name }}(transport=transport)


"""
{% for method in service.methods.values()|sort(attribute="name") if not(method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
"""Pre-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the request or metadata
before they are sent to the {{ service.name }} server.
"""
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
"""Post-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the response
after it is returned by the {{ service.name }} server but before
it is returned to user code.
"""
return response
{% endif %}

{% endfor %}


@dataclasses.dataclass
class {{service.name}}RestStub:
_session: AuthorizedSession
_host: str
_interceptor: {{ service.name }}RestInterceptor


class {{service.name}}RestTransport({{service.name}}Transport):
"""REST backend transport for {{ service.name }}.
Expand Down Expand Up @@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool]=False,
url_scheme: str='https',
interceptor: Optional[{{ service.name }}RestInterceptor] = None,
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endif %}
if client_cert_source_for_mtls:
self._session.configure_mtls_channel(client_cert_source_for_mtls)
self._interceptor = interceptor or {{ service.name }}RestInterceptor()
self._prep_wrapped_messages(client_info)

{% if service.has_lro %}
Expand Down Expand Up @@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
},
{% endfor %}{# rule in method.http_options #}
]

request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At first I imagined we'd be able to pass a list of interceptors that could be chained together, but understand that requires additional architecture here. We can handle the chaining within the body of the pre_/post_ functions.

Copy link
Contributor Author

@software-dov software-dov Jan 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly, that's easily done by something like this:

class InterceptorChainer:
    def __init__(self, chain):
        assert all(isinstance(i, RestInterceptor) for i in chain)
        # Make our own copy to prevent external modification
        self.chain = list(chain)
        
    def __getattr__(self, name):
        if name.startswith("pre_"):
            def pre(request, metadata):
                for i in self.chain:
                    request, metadata = getattr(i, name)(request, metadata)
                return request, metadata
                
            return pre
                
        elif name.startswith("post_"):
            def post(response):
                for i in self.chain:
                    response = getattr(i, name)(response)
                return response
            
            return post
            
        else:
            raise AttributeError(f"No such attribute: {name}")

request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)
Expand Down Expand Up @@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if not method.void %}
# Return the response
{% if method.lro %}
return_op = operations_pb2.Operation()
json_format.Parse(response.content, return_op, ignore_unknown_fields=True)
return return_op
resp = operations_pb2.Operation()
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
{% else %}
return {{method.output.ident}}.from_json(
resp = {{method.output.ident}}.from_json(
response.content,
ignore_unknown_fields=True
)

{% endif %}{# method.lro #}
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
return resp
{% endif %}{# method.void #}
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
{% if not method.http_options %}
Expand All @@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{{method.output.ident}}]:
stub = self._STUBS.get("{{method.name | snake_case}}")
if not stub:
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host)
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)

return stub

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ from google.api_core import grpc_helpers
from google.api_core import path_template
{% if service.has_lro %}
from google.api_core import future
from google.api_core import operation
from google.api_core import operations_v1
from google.longrunning import operations_pb2
{% if "rest" in opts.transport %}
Expand Down Expand Up @@ -1113,6 +1114,55 @@ def test_{{ method_name }}_rest_unset_required_fields():

{% endif %}{# required_fields #}

{% if not (method.server_streaming or method.client_streaming) %}
@pytest.mark.parametrize("null_interceptor", [True, False])
def test_{{ method_name }}_rest_interceptors(null_interceptor):
transport = transports.{{ service.name }}RestTransport(
credentials=ga_credentials.AnonymousCredentials(),
interceptor=None if null_interceptor else transports.{{ service.name}}RestInterceptor(),
)
client = {{ service.client_name }}(transport=transport)
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
{% if method.lro %}
mock.patch.object(operation.Operation, "_set_result_from_operation"), \
{% endif %}
{% if not method.void %}
mock.patch.object(transports.{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \
{% endif %}
mock.patch.object(transports.{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre:
pre.assert_not_called()
{% if not method.void %}
post.assert_not_called()
{% endif %}

transcode.return_value = {"method": "post", "uri": "my_uri", "body": None, "query_params": {},}

req.return_value = Response()
req.return_value.status_code = 200
req.return_value.request = PreparedRequest()
{% if not method.void %}
req.return_value._content = {% if method.output.ident.package == method.ident.package %}{{ method.output.ident }}.to_json({{ method.output.ident }}()){% else %}json_format.MessageToJson({{ method.output.ident }}()){% endif %}
{% endif %}

request = {{ method.input.ident }}()
metadata =[
("key", "val"),
("cephalopod", "squid"),
]
pre.return_value = request, metadata
{% if not method.void %}
post.return_value = {{ method.output.ident }}
{% endif %}

client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
{% if not method.void %}
post.assert_called_once()
{% endif %}
{% endif %}{# streaming #}


def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from .grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport
{% endif %}
{% if 'rest' in opts.transport %}
from .rest import {{ service.name }}RestTransport
from .rest import {{ service.name }}RestInterceptor
{% endif %}


Expand All @@ -34,6 +35,7 @@ __all__ = (
{% endif %}
{% if 'rest' in opts.transport %}
'{{ service.name }}RestTransport',
'{{ service.name }}RestInterceptor',
{% endif %}
)
{% endblock %}
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
rest_version=requests_version,
)


class {{ service.name }}RestInterceptor:
"""Interceptor for {{ service.name }}.

Interceptors are used to manipulate requests, request metadata, and responses
in arbitrary ways.
Example use cases include:
* Logging
* Verifying requests according to service or custom semantics
* Stripping extraneous information from responses

These use cases and more can be enabled by injecting an
instance of a custom subclass when constructing the {{ service.name }}RestTransport.

.. code-block:
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(request, metadata):
logging.log(f"Received request: {request}")
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(response):
logging.log(f"Received response: {response}")
{% endif %}

{% endfor %}
transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor())
client = {{ service.client_name }}(transport=transport)


"""
{% for method in service.methods.values()|sort(attribute="name") if not (method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
"""Pre-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the request or metadata
before they are sent to the {{ service.name }} server.
"""
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
"""Post-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the response
after it is returned by the {{ service.name }} server but before
it is returned to user code.
"""
return response
{% endif %}

{% endfor %}


@dataclasses.dataclass
class {{service.name}}RestStub:
_session: AuthorizedSession
_host: str
_interceptor: {{ service.name }}RestInterceptor


class {{service.name}}RestTransport({{service.name}}Transport):
"""REST backend transport for {{ service.name }}.
Expand Down Expand Up @@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool]=False,
url_scheme: str='https',
interceptor: Optional[{{ service.name }}RestInterceptor] = None,
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endif %}
if client_cert_source_for_mtls:
self._session.configure_mtls_channel(client_cert_source_for_mtls)
self._interceptor = interceptor or {{ service.name }}RestInterceptor()
self._prep_wrapped_messages(client_info)

{% if service.has_lro %}
Expand Down Expand Up @@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
},
{% endfor %}{# rule in method.http_options #}
]

request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata)
request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)
Expand Down Expand Up @@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if not method.void %}
# Return the response
{% if method.lro %}
return_op = operations_pb2.Operation()
json_format.Parse(response.content, return_op, ignore_unknown_fields=True)
return return_op
resp = operations_pb2.Operation()
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
{% else %}
return {{method.output.ident}}.from_json(
resp = {{method.output.ident}}.from_json(
response.content,
ignore_unknown_fields=True
)

{% endif %}{# method.lro #}
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
return resp
{% endif %}{# method.void #}
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
{% if not method.http_options %}
Expand All @@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{{method.output.ident}}]:
stub = self._STUBS.get("{{method.name | snake_case}}")
if not stub:
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host)
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)

return stub

Expand Down
Loading