diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/__init__.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/__init__.py.j2 index 1241886b63..88d196a7c2 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/__init__.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/__init__.py.j2 @@ -11,6 +11,7 @@ from .grpc import {{ service.name }}GrpcTransport {% endif %} {% if 'rest' in opts.transport %} from .rest import {{ service.name }}RestTransport +from .rest import {{ service.name }}RestInterceptor {% endif %} # Compile a registry of transports. @@ -29,6 +30,7 @@ __all__ = ( {% endif %} {% if 'rest' in opts.transport %} '{{ service.name }}RestTransport', + '{{ service.name }}RestInterceptor', {% endif %} ) {% endblock %} diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 index 488646be12..a4dc7c61e3 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 @@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( rest_version=requests_version, ) + +class {{ service.name }}RestInterceptor: + """Interceptor for {{ service.name }}. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the {{ service.name }}RestTransport. + + .. code-block: + class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor): +{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %} + def pre_{{ method.name|snake_case }}(request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + {% if not method.void %} + def post_{{ method.name|snake_case }}(response): + logging.log(f"Received response: {response}") + {% endif %} + +{% endfor %} + transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor()) + client = {{ service.client_name }}(transport=transport) + + + """ + {% for method in service.methods.values()|sort(attribute="name") if not(method.server_streaming or method.client_streaming) %} + def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for {{ method.name|snake_case }} + + Override in a subclass to manipulate the request or metadata + before they are sent to the {{ service.name }} server. + """ + return request, metadata + + {% if not method.void %} + def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}: + """Post-rpc interceptor for {{ method.name|snake_case }} + + Override in a subclass to manipulate the response + after it is returned by the {{ service.name }} server but before + it is returned to user code. + """ + return response + {% endif %} + + {% endfor %} + + @dataclasses.dataclass class {{service.name}}RestStub: _session: AuthorizedSession _host: str + _interceptor: {{ service.name }}RestInterceptor + class {{service.name}}RestTransport({{service.name}}Transport): """REST backend transport for {{ service.name }}. @@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO, always_use_jwt_access: Optional[bool]=False, url_scheme: str='https', + interceptor: Optional[{{ service.name }}RestInterceptor] = None, ) -> None: """Instantiate the transport. @@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% endif %} if client_cert_source_for_mtls: self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or {{ service.name }}RestInterceptor() self._prep_wrapped_messages(client_info) {% if service.has_lro %} @@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): }, {% endfor %}{# rule in method.http_options #} ] - + request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata) request_kwargs = {{method.input.ident}}.to_dict(request) transcoded_request = path_template.transcode( http_options, **request_kwargs) @@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% if not method.void %} # Return the response {% if method.lro %} - return_op = operations_pb2.Operation() - json_format.Parse(response.content, return_op, ignore_unknown_fields=True) - return return_op + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) {% else %} - return {{method.output.ident}}.from_json( + resp = {{method.output.ident}}.from_json( response.content, ignore_unknown_fields=True ) - {% endif %}{# method.lro #} + resp = self._interceptor.post_{{ method.name|snake_case }}(resp) + return resp {% endif %}{# method.void #} {% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #} {% if not method.http_options %} @@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): {{method.output.ident}}]: stub = self._STUBS.get("{{method.name | snake_case}}") if not stub: - stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host) + stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor) return stub diff --git a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index af7d28335d..a7934d84e8 100644 --- a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -35,6 +35,7 @@ from google.api_core import grpc_helpers from google.api_core import path_template {% if service.has_lro %} from google.api_core import future +from google.api_core import operation from google.api_core import operations_v1 from google.longrunning import operations_pb2 {% if "rest" in opts.transport %} @@ -1113,6 +1114,55 @@ def test_{{ method_name }}_rest_unset_required_fields(): {% endif %}{# required_fields #} +{% if not (method.server_streaming or method.client_streaming) %} +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_{{ method_name }}_rest_interceptors(null_interceptor): + transport = transports.{{ service.name }}RestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None if null_interceptor else transports.{{ service.name}}RestInterceptor(), + ) + client = {{ service.client_name }}(transport=transport) + with mock.patch.object(type(client.transport._session), "request") as req, \ + mock.patch.object(path_template, "transcode") as transcode, \ + {% if method.lro %} + mock.patch.object(operation.Operation, "_set_result_from_operation"), \ + {% endif %} + {% if not method.void %} + mock.patch.object(transports.{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \ + {% endif %} + mock.patch.object(transports.{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre: + pre.assert_not_called() + {% if not method.void %} + post.assert_not_called() + {% endif %} + + transcode.return_value = {"method": "post", "uri": "my_uri", "body": None, "query_params": {},} + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + {% if not method.void %} + req.return_value._content = {% if method.output.ident.package == method.ident.package %}{{ method.output.ident }}.to_json({{ method.output.ident }}()){% else %}json_format.MessageToJson({{ method.output.ident }}()){% endif %} + {% endif %} + + request = {{ method.input.ident }}() + metadata =[ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + {% if not method.void %} + post.return_value = {{ method.output.ident }} + {% endif %} + + client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + + pre.assert_called_once() + {% if not method.void %} + post.assert_called_once() + {% endif %} +{% endif %}{# streaming #} + def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}): client = {{ service.client_name }}( diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 index 107e2bd4e8..66be2e5c29 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 @@ -12,6 +12,7 @@ from .grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport {% endif %} {% if 'rest' in opts.transport %} from .rest import {{ service.name }}RestTransport +from .rest import {{ service.name }}RestInterceptor {% endif %} @@ -34,6 +35,7 @@ __all__ = ( {% endif %} {% if 'rest' in opts.transport %} '{{ service.name }}RestTransport', + '{{ service.name }}RestInterceptor', {% endif %} ) {% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index 488646be12..b208c0940f 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( rest_version=requests_version, ) + +class {{ service.name }}RestInterceptor: + """Interceptor for {{ service.name }}. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the {{ service.name }}RestTransport. + + .. code-block: + class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor): + {% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %} + def pre_{{ method.name|snake_case }}(request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + {% if not method.void %} + def post_{{ method.name|snake_case }}(response): + logging.log(f"Received response: {response}") + {% endif %} + +{% endfor %} + transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor()) + client = {{ service.client_name }}(transport=transport) + + + """ + {% for method in service.methods.values()|sort(attribute="name") if not (method.server_streaming or method.client_streaming) %} + def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for {{ method.name|snake_case }} + + Override in a subclass to manipulate the request or metadata + before they are sent to the {{ service.name }} server. + """ + return request, metadata + + {% if not method.void %} + def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}: + """Post-rpc interceptor for {{ method.name|snake_case }} + + Override in a subclass to manipulate the response + after it is returned by the {{ service.name }} server but before + it is returned to user code. + """ + return response + {% endif %} + + {% endfor %} + + @dataclasses.dataclass class {{service.name}}RestStub: _session: AuthorizedSession _host: str + _interceptor: {{ service.name }}RestInterceptor + class {{service.name}}RestTransport({{service.name}}Transport): """REST backend transport for {{ service.name }}. @@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO, always_use_jwt_access: Optional[bool]=False, url_scheme: str='https', + interceptor: Optional[{{ service.name }}RestInterceptor] = None, ) -> None: """Instantiate the transport. @@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% endif %} if client_cert_source_for_mtls: self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or {{ service.name }}RestInterceptor() self._prep_wrapped_messages(client_info) {% if service.has_lro %} @@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): }, {% endfor %}{# rule in method.http_options #} ] - + request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata) request_kwargs = {{method.input.ident}}.to_dict(request) transcoded_request = path_template.transcode( http_options, **request_kwargs) @@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% if not method.void %} # Return the response {% if method.lro %} - return_op = operations_pb2.Operation() - json_format.Parse(response.content, return_op, ignore_unknown_fields=True) - return return_op + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) {% else %} - return {{method.output.ident}}.from_json( + resp = {{method.output.ident}}.from_json( response.content, ignore_unknown_fields=True ) - {% endif %}{# method.lro #} + resp = self._interceptor.post_{{ method.name|snake_case }}(resp) + return resp {% endif %}{# method.void #} {% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #} {% if not method.http_options %} @@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): {{method.output.ident}}]: stub = self._STUBS.get("{{method.name | snake_case}}") if not stub: - stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host) + stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor) return stub diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 56cdbc6287..cdee5b7697 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -39,6 +39,7 @@ from google.api_core import grpc_helpers_async from google.api_core import path_template {% if service.has_lro %} from google.api_core import future +from google.api_core import operation from google.api_core import operations_v1 from google.longrunning import operations_pb2 {% if "rest" in opts.transport %} @@ -1515,6 +1516,57 @@ def test_{{ method_name }}_rest_unset_required_fields(): {% endif %}{# required_fields #} +{% if not (method.server_streaming or method.client_streaming) %} +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_{{ method_name }}_rest_interceptors(null_interceptor): + transport = transports.{{ service.name }}RestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None if null_interceptor else transports.{{ service.name}}RestInterceptor(), + ) + client = {{ service.client_name }}(transport=transport) + with mock.patch.object(type(client.transport._session), "request") as req, \ + mock.patch.object(path_template, "transcode") as transcode, \ + {% if method.lro %} + mock.patch.object(operation.Operation, "_set_result_from_operation"), \ + {% endif %} + {% if not method.void %} + mock.patch.object(transports.{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \ + {% endif %} + mock.patch.object(transports.{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre: + pre.assert_not_called() + {% if not method.void %} + post.assert_not_called() + {% endif %} + + transcode.return_value = {"method": "post", "uri": "my_uri", "body": None, "query_params": {},} + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + {% if not method.void %} + req.return_value._content = {% if method.output.ident.package == method.ident.package %}{{ method.output.ident }}.to_json({{ method.output.ident }}()){% else %}json_format.MessageToJson({{ method.output.ident }}()){% endif %} + {% endif %} + + request = {{ method.input.ident }}() + metadata =[ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + {% if not method.void %} + post.return_value = {{ method.output.ident }} + {% endif %} + + client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + + pre.assert_called_once() + {% if not method.void %} + post.assert_called_once() + {% endif %} + +{% endif %}{# streaming #} + + def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}): client = {{ service.client_name }}( credentials=ga_credentials.AnonymousCredentials(), @@ -1829,7 +1881,7 @@ def test_credentials_transport_error(): client_options={"credentials_file": "credentials.json"}, transport=transport, ) - + # It is an error to provide an api_key and a transport instance. transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport( credentials=ga_credentials.AnonymousCredentials(), @@ -1841,7 +1893,7 @@ def test_credentials_transport_error(): client_options=options, transport=transport, ) - + # It is an error to provide an api_key and a credential. options = mock.Mock() options.api_key = "api_key" @@ -2141,6 +2193,8 @@ def test_{{ service.name|snake_case }}_rest_lro_client(): # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client {%- endif %} + + {% endif %} {# rest #} def test_{{ service.name|snake_case }}_host_no_port(): diff --git a/noxfile.py b/noxfile.py index 6154ea94bb..a9df7d65d4 100644 --- a/noxfile.py +++ b/noxfile.py @@ -310,12 +310,17 @@ def run_showcase_unit_tests(session, fail_under=100): # Run the tests. session.run( "py.test", - "-n=auto", - "--quiet", - "--cov=google", - "--cov-append", - f"--cov-fail-under={str(fail_under)}", - *(session.posargs or [path.join("tests", "unit")]), + *( + session.posargs + or [ + "-n=auto", + "--quiet", + "--cov=google", + "--cov-append", + f"--cov-fail-under={str(fail_under)}", + path.join("tests", "unit"), + ] + ), ) diff --git a/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py b/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py index a2747714b8..dd4f527b9c 100644 --- a/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py +++ b/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template diff --git a/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py b/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py index e89e2e73fd..b189511ab7 100644 --- a/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py +++ b/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template