Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: pass metadata to pagers #470

Merged
merged 10 commits into from
Jul 7, 2020
5 changes: 5 additions & 0 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,11 @@ def has_lro(self) -> bool:
"""Return whether the service has a long-running method."""
return any([m.lro for m in self.methods.values()])

@property
def has_pagers(self) -> bool:
"""Return whether the service has paged methods."""
return any([m.paged_result_field for m in self.methods.values()])

software-dov marked this conversation as resolved.
Show resolved Hide resolved
@property
def host(self) -> str:
"""Return the hostname for this service, if specified.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ class {{ service.async_client_name }}:
method=rpc,
request=request,
response=response,
metadata=metadata,
)
{%- endif %}
{%- if not method.void %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
method=rpc,
request=request,
response=response,
metadata=metadata,
)
{%- endif %}
{%- if not method.void %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
{# This lives within the loop in order to ensure that this template
is empty if there are no paged methods.
-#}
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple

{% filter sort_lines -%}
{% for method in service.methods.values() | selectattr('paged_result_field') -%}
Expand Down Expand Up @@ -35,10 +35,11 @@ class {{ method.name }}Pager:
the most recent response is retained, and thus used for attribute lookup.
"""
def __init__(self,
method: Callable[[{{ method.input.ident }}],
{{ method.output.ident }}],
method: Callable[..., {{ method.output.ident }}],
request: {{ method.input.ident }},
response: {{ method.output.ident }}):
response: {{ method.output.ident }},
*,
metadata: Sequence[Tuple[str, str]] = ()):
"""Instantiate the pager.

Args:
Expand All @@ -48,10 +49,13 @@ class {{ method.name }}Pager:
The initial request object.
response (:class:`{{ method.output.ident.sphinx }}`):
The initial response object.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
"""
self._method = method
self._request = {{ method.input.ident }}(request)
self._response = response
self._metadata = metadata

def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
Expand All @@ -61,7 +65,7 @@ class {{ method.name }}Pager:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request)
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}:
Expand Down Expand Up @@ -90,10 +94,11 @@ class {{ method.name }}AsyncPager:
the most recent response is retained, and thus used for attribute lookup.
"""
def __init__(self,
method: Callable[[{{ method.input.ident }}],
Awaitable[{{ method.output.ident }}]],
method: Callable[..., Awaitable[{{ method.output.ident }}]],
request: {{ method.input.ident }},
response: {{ method.output.ident }}):
response: {{ method.output.ident }},
*,
metadata: Sequence[Tuple[str, str]] = ()):
"""Instantiate the pager.

Args:
Expand All @@ -103,10 +108,13 @@ class {{ method.name }}AsyncPager:
The initial request object.
response (:class:`{{ method.output.ident.sphinx }}`):
The initial response object.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
"""
self._method = method
self._request = {{ method.input.ident }}(request)
self._response = response
self._metadata = metadata

def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
Expand All @@ -116,7 +124,7 @@ class {{ method.name }}AsyncPager:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request)
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response

def __aiter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'AsyncIterable') }}:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ from google.api_core import future
from google.api_core import operations_v1
from google.longrunning import operations_pb2
{% endif -%}
{% if service.has_pagers -%}
from google.api_core import gapic_v1
{% endif -%}
{% for method in service.methods.values() -%}
{% for ref_type in method.ref_types
if not ((ref_type.ident.python_import.package == ('google', 'api_core') and ref_type.ident.python_import.module == 'operation')
Expand Down Expand Up @@ -635,9 +638,24 @@ def test_{{ method.name|snake_case }}_pager():
),
RuntimeError,
)
results = [i for i in client.{{ method.name|snake_case }}(
request={},
)]

metadata = ()
{% if method.field_headers -%}
metadata = tuple(metadata) + (
gapic_v1.routing_header.to_grpc_metadata((
{%- for field_header in method.field_headers %}
{%- if not method.client_streaming %}
('{{ field_header }}', ''),
{%- endif %}
{%- endfor %}
)),
)
{% endif -%}
pager = client.{{ method.name|snake_case }}(request={})

assert pager._metadata == metadata

results = [i for i in pager]
assert len(results) == 6
assert all(isinstance(i, {{ method.paged_result_field.message.ident }})
for i in results)
Expand Down