From 33757d88c968fef65332f5ebe0b876758f978ab0 Mon Sep 17 00:00:00 2001 From: esert-g <48071655+esert-g@users.noreply.github.com> Date: Wed, 12 Jan 2022 15:01:03 -0800 Subject: [PATCH] feat: retryable resource exhausted handling (#366) BigQuery Storage Read API will start returning retryable RESOURCE_EXHAUSTED errors in 2022 when certain concurrency limits are hit, so this PR adds some code to handle them. Tested with unit tests and system tests. System tests ran successfully on a test project that intentionally returns retryable RESOURCE_EXHAUSTED errors. Co-authored-by: Tim Swast --- google/cloud/bigquery_storage_v1/client.py | 20 ++- google/cloud/bigquery_storage_v1/reader.py | 72 ++++++-- .../cloud/bigquery_storage_v1beta2/client.py | 20 ++- tests/unit/test_reader_v1.py | 164 +++++++++++++----- tests/unit/test_reader_v1_arrow.py | 47 ++--- 5 files changed, 234 insertions(+), 89 deletions(-) diff --git a/google/cloud/bigquery_storage_v1/client.py b/google/cloud/bigquery_storage_v1/client.py index 75ef3834..05f91ae9 100644 --- a/google/cloud/bigquery_storage_v1/client.py +++ b/google/cloud/bigquery_storage_v1/client.py @@ -47,6 +47,7 @@ def read_rows( retry=google.api_core.gapic_v1.method.DEFAULT, timeout=google.api_core.gapic_v1.method.DEFAULT, metadata=(), + retry_delay_callback=None, ): """ Reads rows from the table in the format prescribed by the read @@ -108,6 +109,12 @@ def read_rows( specified, the timeout applies to each individual attempt. metadata (Optional[Sequence[Tuple[str, str]]]): Additional metadata that is provided to the method. + retry_delay_callback (Optional[Callable[[float], None]]): + If the client receives a retryable error that asks the client to + delay its next attempt and retry_delay_callback is not None, + BigQueryReadClient will call retry_delay_callback with the delay + duration (in seconds) before it starts sleeping until the next + attempt. Returns: ~google.cloud.bigquery_storage_v1.reader.ReadRowsStream: @@ -122,20 +129,15 @@ def read_rows( ValueError: If the parameters are invalid. """ gapic_client = super(BigQueryReadClient, self) - stream = gapic_client.read_rows( - read_stream=name, - offset=offset, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return reader.ReadRowsStream( - stream, + stream = reader.ReadRowsStream( gapic_client, name, offset, {"retry": retry, "timeout": timeout, "metadata": metadata}, + retry_delay_callback=retry_delay_callback, ) + stream._reconnect() + return stream class BigQueryWriteClient(big_query_write.BigQueryWriteClient): diff --git a/google/cloud/bigquery_storage_v1/reader.py b/google/cloud/bigquery_storage_v1/reader.py index beb1dbb5..2e387eb2 100644 --- a/google/cloud/bigquery_storage_v1/reader.py +++ b/google/cloud/bigquery_storage_v1/reader.py @@ -17,12 +17,14 @@ import collections import io import json +import time try: import fastavro except ImportError: # pragma: NO COVER fastavro = None import google.api_core.exceptions +import google.rpc.error_details_pb2 try: import pandas @@ -79,16 +81,17 @@ class ReadRowsStream(object): If the pandas and fastavro libraries are installed, use the :func:`~google.cloud.bigquery_storage_v1.reader.ReadRowsStream.to_dataframe()` method to parse all messages into a :class:`pandas.DataFrame`. + + This object should not be created directly, but is returned by + other methods in this library. 
""" - def __init__(self, wrapped, client, name, offset, read_rows_kwargs): + def __init__( + self, client, name, offset, read_rows_kwargs, retry_delay_callback=None + ): """Construct a ReadRowsStream. Args: - wrapped (Iterable[ \ - ~google.cloud.bigquery_storage.types.ReadRowsResponse \ - ]): - The ReadRows stream to read. client ( \ ~google.cloud.bigquery_storage_v1.services. \ big_query_read.BigQueryReadClient \ @@ -106,6 +109,12 @@ def __init__(self, wrapped, client, name, offset, read_rows_kwargs): read_rows_kwargs (dict): Keyword arguments to use when reconnecting to a ReadRows stream. + retry_delay_callback (Optional[Callable[[float], None]]): + If the client receives a retryable error that asks the client to + delay its next attempt and retry_delay_callback is not None, + ReadRowsStream will call retry_delay_callback with the delay + duration (in seconds) before it starts sleeping until the next + attempt. Returns: Iterable[ \ @@ -116,11 +125,12 @@ def __init__(self, wrapped, client, name, offset, read_rows_kwargs): # Make a copy of the read position so that we can update it without # mutating the original input. - self._wrapped = wrapped self._client = client self._name = name self._offset = offset self._read_rows_kwargs = read_rows_kwargs + self._retry_delay_callback = retry_delay_callback + self._wrapped = None def __iter__(self): """An iterable of messages. @@ -131,9 +141,12 @@ def __iter__(self): ]: A sequence of row messages. """ - # Infinite loop to reconnect on reconnectable errors while processing # the row stream. + + if self._wrapped is None: + self._reconnect() + while True: try: for message in self._wrapped: @@ -152,14 +165,53 @@ def __iter__(self): except _STREAM_RESUMPTION_EXCEPTIONS: # Transient error, so reconnect to the stream. pass + except Exception as exc: + if not self._resource_exhausted_exception_is_retryable(exc): + raise self._reconnect() def _reconnect(self): """Reconnect to the ReadRows stream using the most recent offset.""" - self._wrapped = self._client.read_rows( - read_stream=self._name, offset=self._offset, **self._read_rows_kwargs - ) + while True: + try: + self._wrapped = self._client.read_rows( + read_stream=self._name, + offset=self._offset, + **self._read_rows_kwargs + ) + break + except Exception as exc: + if not self._resource_exhausted_exception_is_retryable(exc): + raise + + def _resource_exhausted_exception_is_retryable(self, exc): + if isinstance(exc, google.api_core.exceptions.ResourceExhausted): + # ResourceExhausted errors are only retried if a valid + # RetryInfo is provided with the error. + # + # TODO: Remove hasattr logic when we require google-api-core >= 2.2.0. + # ResourceExhausted added details/_details in google-api-core 2.2.0. + details = None + if hasattr(exc, "details"): + details = exc.details + elif hasattr(exc, "_details"): + details = exc._details + if details is not None: + for detail in details: + if isinstance(detail, google.rpc.error_details_pb2.RetryInfo): + retry_delay = detail.retry_delay + if retry_delay is not None: + delay = max( + 0, + float(retry_delay.seconds) + + (float(retry_delay.nanos) / 1e9), + ) + if self._retry_delay_callback: + self._retry_delay_callback(delay) + time.sleep(delay) + return True + return False def rows(self, read_session=None): """Iterate over all rows in the stream. 
diff --git a/google/cloud/bigquery_storage_v1beta2/client.py b/google/cloud/bigquery_storage_v1beta2/client.py index 00bff3ff..0dc428b9 100644 --- a/google/cloud/bigquery_storage_v1beta2/client.py +++ b/google/cloud/bigquery_storage_v1beta2/client.py @@ -48,6 +48,7 @@ def read_rows( retry=google.api_core.gapic_v1.method.DEFAULT, timeout=google.api_core.gapic_v1.method.DEFAULT, metadata=(), + retry_delay_callback=None, ): """ Reads rows from the table in the format prescribed by the read @@ -109,6 +110,12 @@ def read_rows( specified, the timeout applies to each individual attempt. metadata (Optional[Sequence[Tuple[str, str]]]): Additional metadata that is provided to the method. + retry_delay_callback (Optional[Callable[[float], None]]): + If the client receives a retryable error that asks the client to + delay its next attempt and retry_delay_callback is not None, + BigQueryReadClient will call retry_delay_callback with the delay + duration (in seconds) before it starts sleeping until the next + attempt. Returns: ~google.cloud.bigquery_storage_v1.reader.ReadRowsStream: @@ -123,20 +130,15 @@ def read_rows( ValueError: If the parameters are invalid. """ gapic_client = super(BigQueryReadClient, self) - stream = gapic_client.read_rows( - read_stream=name, - offset=offset, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return reader.ReadRowsStream( - stream, + stream = reader.ReadRowsStream( gapic_client, name, offset, {"retry": retry, "timeout": timeout, "metadata": metadata}, + retry_delay_callback=retry_delay_callback, ) + stream._reconnect() + return stream class BigQueryWriteClient(big_query_write.BigQueryWriteClient): diff --git a/tests/unit/test_reader_v1.py b/tests/unit/test_reader_v1.py index 59292843..826e8ea7 100644 --- a/tests/unit/test_reader_v1.py +++ b/tests/unit/test_reader_v1.py @@ -27,6 +27,7 @@ import google.api_core.exceptions from google.cloud.bigquery_storage import types from .helpers import SCALAR_COLUMNS, SCALAR_COLUMN_NAMES, SCALAR_BLOCKS +import google.rpc.error_details_pb2 PROJECT = "my-project" @@ -97,6 +98,29 @@ def _pages_w_resumable_internal_error(avro_blocks): ) +def _pages_w_nonresumable_resource_exhausted_error(avro_blocks): + for block in avro_blocks: + yield block + raise google.api_core.exceptions.ResourceExhausted( + "RESOURCE_EXHAUSTED: do not retry" + ) + + +def _pages_w_resumable_resource_exhausted_error( + avro_blocks, delay_seconds, delay_nanos +): + for block in avro_blocks: + yield block + retry_info = google.rpc.error_details_pb2.RetryInfo() + retry_info.retry_delay.seconds = delay_seconds + retry_info.retry_delay.nanos = delay_nanos + error = google.api_core.exceptions.ResourceExhausted( + "RESOURCE_EXHAUSTED: retry later" + ) + error._details = (retry_info,) + raise error + + def _pages_w_unavailable(pages): for page in pages: yield page @@ -144,7 +168,8 @@ def test_avro_rows_raises_import_error( monkeypatch.setattr(mut, "fastavro", None) avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) - reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "", 0, {}) rows = iter(reader.rows()) # Since session isn't passed in, reader doesn't know serialization type @@ -159,7 +184,8 @@ def test_rows_no_schema_set_raises_type_error( avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) avro_blocks[0].avro_schema = None 
- reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "", 0, {}) rows = iter(reader.rows()) # Since session isn't passed in, reader doesn't know serialization type @@ -169,7 +195,8 @@ def test_rows_no_schema_set_raises_type_error( def test_rows_w_empty_stream(class_under_test, mock_gapic_client): - reader = class_under_test([], mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = [] + reader = class_under_test(mock_gapic_client, "", 0, {}) got = reader.rows() assert tuple(got) == () @@ -177,8 +204,8 @@ def test_rows_w_empty_stream(class_under_test, mock_gapic_client): def test_rows_w_scalars(class_under_test, mock_gapic_client): avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) - - reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "", 0, {}) got = tuple(reader.rows()) expected = tuple(itertools.chain.from_iterable(SCALAR_BLOCKS)) @@ -198,22 +225,22 @@ def test_rows_w_timeout(class_under_test, mock_gapic_client): bq_blocks_2 = [[{"int_col": 567}, {"int_col": 789}], [{"int_col": 890}]] avro_blocks_2 = _bq_to_avro_blocks(bq_blocks_2, avro_schema) - mock_gapic_client.read_rows.return_value = avro_blocks_2 + mock_gapic_client.read_rows.side_effect = ( + avro_blocks_1, + avro_blocks_2, + ) reader = class_under_test( - avro_blocks_1, - mock_gapic_client, - "teststream", - 0, - {"metadata": {"test-key": "test-value"}}, + mock_gapic_client, "teststream", 0, {"metadata": {"test-key": "test-value"}}, ) with pytest.raises(google.api_core.exceptions.DeadlineExceeded): list(reader.rows()) - # Don't reconnect on DeadlineException. This allows user-specified timeouts - # to be respected. - mock_gapic_client.read_rows.assert_not_called() + # Don't reconnect on DeadlineException so user-specified timeouts + # are respected. This requires client.read_rows to be called + # exactly once which fails with DeadlineException. 
+ mock_gapic_client.read_rows.assert_called_once() def test_rows_w_nonresumable_internal_error(class_under_test, mock_gapic_client): @@ -223,15 +250,43 @@ def test_rows_w_nonresumable_internal_error(class_under_test, mock_gapic_client) avro_blocks = _pages_w_nonresumable_internal_error( _bq_to_avro_blocks(bq_blocks, avro_schema) ) - - reader = class_under_test(avro_blocks, mock_gapic_client, "teststream", 0, {}) + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "teststream", 0, {}) with pytest.raises( google.api_core.exceptions.InternalServerError, match="nonresumable error" ): list(reader.rows()) - mock_gapic_client.read_rows.assert_not_called() + mock_gapic_client.read_rows.assert_called_once() + + +def test_rows_w_nonresumable_resource_exhausted_error( + class_under_test, mock_gapic_client +): + bq_columns = [{"name": "int_col", "type": "int64"}] + avro_schema = _bq_to_avro_schema(bq_columns) + bq_blocks = [[{"int_col": 1024}, {"int_col": 512}], [{"int_col": 256}]] + avro_blocks = _pages_w_nonresumable_resource_exhausted_error( + _bq_to_avro_blocks(bq_blocks, avro_schema) + ) + + retry_delay_num_calls = 0 + + def retry_delay_callback(delay): + nonlocal retry_delay_num_calls + retry_delay_num_calls += 1 + + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "teststream", 0, {}) + + with pytest.raises( + google.api_core.exceptions.ResourceExhausted, match="do not retry" + ): + list(reader.rows()) + + mock_gapic_client.read_rows.assert_called_once() + assert retry_delay_num_calls == 0 def test_rows_w_reconnect(class_under_test, mock_gapic_client): @@ -249,20 +304,37 @@ def test_rows_w_reconnect(class_under_test, mock_gapic_client): bq_blocks_3 = [[{"int_col": -1}, {"int_col": -2}], [{"int_col": -4}]] avro_blocks_3 = _pages_w_unknown(_bq_to_avro_blocks(bq_blocks_3, avro_schema)) bq_blocks_4 = [[{"int_col": 567}, {"int_col": 789}], [{"int_col": 890}]] - avro_blocks_4 = _bq_to_avro_blocks(bq_blocks_4, avro_schema) + delay_seconds = 1 + delay_nanos = 234 + avro_blocks_4 = _pages_w_resumable_resource_exhausted_error( + _bq_to_avro_blocks(bq_blocks_4, avro_schema), delay_seconds, delay_nanos + ) + bq_blocks_5 = [[{"int_col": 859}, {"int_col": 231}], [{"int_col": 777}]] + avro_blocks_5 = _bq_to_avro_blocks(bq_blocks_5, avro_schema) mock_gapic_client.read_rows.side_effect = ( + avro_blocks_1, avro_blocks_2, avro_blocks_3, avro_blocks_4, + avro_blocks_5, ) + retry_delay_num_calls = 0 + retry_delay = 0 + + def retry_delay_callback(delay): + nonlocal retry_delay_num_calls + nonlocal retry_delay + retry_delay_num_calls += 1 + retry_delay = delay + reader = class_under_test( - avro_blocks_1, mock_gapic_client, "teststream", 0, {"metadata": {"test-key": "test-value"}}, + retry_delay_callback=retry_delay_callback, ) got = reader.rows() @@ -272,6 +344,7 @@ def test_rows_w_reconnect(class_under_test, mock_gapic_client): itertools.chain.from_iterable(bq_blocks_2), itertools.chain.from_iterable(bq_blocks_3), itertools.chain.from_iterable(bq_blocks_4), + itertools.chain.from_iterable(bq_blocks_5), ) ) @@ -282,9 +355,14 @@ def test_rows_w_reconnect(class_under_test, mock_gapic_client): mock_gapic_client.read_rows.assert_any_call( read_stream="teststream", offset=7, metadata={"test-key": "test-value"} ) - mock_gapic_client.read_rows.assert_called_with( + mock_gapic_client.read_rows.assert_any_call( read_stream="teststream", offset=10, metadata={"test-key": "test-value"} ) + 
mock_gapic_client.read_rows.assert_called_with( + read_stream="teststream", offset=13, metadata={"test-key": "test-value"} + ) + assert retry_delay_num_calls == 1 + assert retry_delay == delay_seconds + (delay_nanos / 1e9) def test_rows_w_reconnect_by_page(class_under_test, mock_gapic_client): @@ -298,14 +376,13 @@ def test_rows_w_reconnect_by_page(class_under_test, mock_gapic_client): bq_blocks_2 = [[{"int_col": 567}, {"int_col": 789}], [{"int_col": 890}]] avro_blocks_2 = _bq_to_avro_blocks(bq_blocks_2, avro_schema) - mock_gapic_client.read_rows.return_value = avro_blocks_2 + mock_gapic_client.read_rows.side_effect = ( + _pages_w_unavailable(avro_blocks_1), + avro_blocks_2, + ) reader = class_under_test( - _pages_w_unavailable(avro_blocks_1), - mock_gapic_client, - "teststream", - 0, - {"metadata": {"test-key": "test-value"}}, + mock_gapic_client, "teststream", 0, {"metadata": {"test-key": "test-value"}}, ) got = reader.rows() pages = iter(got.pages) @@ -341,7 +418,8 @@ def test_to_dataframe_no_pandas_raises_import_error( avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) - reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "", 0, {}) with pytest.raises(ImportError): reader.to_dataframe() @@ -359,7 +437,8 @@ def test_to_dataframe_no_schema_set_raises_type_error( avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) avro_blocks[0].avro_schema = None - reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "", 0, {}) rows = reader.rows() # Since session isn't passed in, reader doesn't know serialization type @@ -368,11 +447,12 @@ def test_to_dataframe_no_schema_set_raises_type_error( rows.to_dataframe() -def test_to_dataframe_w_scalars(class_under_test): +def test_to_dataframe_w_scalars(class_under_test, mock_gapic_client): avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) - reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "", 0, {}) got = reader.to_dataframe() expected = pandas.DataFrame( @@ -397,7 +477,7 @@ def test_to_dataframe_w_scalars(class_under_test): ) -def test_to_dataframe_w_dtypes(class_under_test): +def test_to_dataframe_w_dtypes(class_under_test, mock_gapic_client): avro_schema = _bq_to_avro_schema( [ {"name": "bigfloat", "type": "float64"}, @@ -410,7 +490,8 @@ def test_to_dataframe_w_dtypes(class_under_test): ] avro_blocks = _bq_to_avro_blocks(blocks, avro_schema) - reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "", 0, {}) got = reader.to_dataframe(dtypes={"lilfloat": "float16"}) expected = pandas.DataFrame( @@ -426,11 +507,12 @@ def test_to_dataframe_w_dtypes(class_under_test): ) -def test_to_dataframe_empty_w_scalars_avro(class_under_test): +def test_to_dataframe_empty_w_scalars_avro(class_under_test, mock_gapic_client): avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks([], avro_schema) - reader = class_under_test(avro_blocks, 
mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "", 0, {}) # Read session is needed to get a schema for empty streams. got = reader.to_dataframe(read_session) @@ -458,7 +540,8 @@ def test_to_dataframe_empty_w_dtypes_avro(class_under_test, mock_gapic_client): ) read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks([], avro_schema) - reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = avro_blocks + reader = class_under_test(mock_gapic_client, "", 0, {}) # Read session is needed to get a schema for empty streams. got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) @@ -490,14 +573,13 @@ def test_to_dataframe_by_page(class_under_test, mock_gapic_client): avro_blocks_1 = _bq_to_avro_blocks(bq_blocks_1, avro_schema) avro_blocks_2 = _bq_to_avro_blocks(bq_blocks_2, avro_schema) - mock_gapic_client.read_rows.return_value = avro_blocks_2 + mock_gapic_client.read_rows.side_effect = ( + _pages_w_unavailable(avro_blocks_1), + avro_blocks_2, + ) reader = class_under_test( - _pages_w_unavailable(avro_blocks_1), - mock_gapic_client, - "teststream", - 0, - {"metadata": {"test-key": "test-value"}}, + mock_gapic_client, "teststream", 0, {"metadata": {"test-key": "test-value"}}, ) got = reader.rows() pages = iter(got.pages) diff --git a/tests/unit/test_reader_v1_arrow.py b/tests/unit/test_reader_v1_arrow.py index 02c7b80a..9cecb9d2 100644 --- a/tests/unit/test_reader_v1_arrow.py +++ b/tests/unit/test_reader_v1_arrow.py @@ -131,7 +131,8 @@ def test_pyarrow_rows_raises_import_error( monkeypatch.setattr(mut, "pyarrow", None) arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) - reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = arrow_batches + reader = class_under_test(mock_gapic_client, "", 0, {}) rows = iter(reader.rows()) # Since session isn't passed in, reader doesn't know serialization type @@ -146,7 +147,8 @@ def test_to_arrow_no_pyarrow_raises_import_error( monkeypatch.setattr(mut, "pyarrow", None) arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) - reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = arrow_batches + reader = class_under_test(mock_gapic_client, "", 0, {}) with pytest.raises(ImportError): reader.to_arrow() @@ -158,10 +160,11 @@ def test_to_arrow_no_pyarrow_raises_import_error( next(reader.rows().pages).to_arrow() -def test_to_arrow_w_scalars_arrow(class_under_test): +def test_to_arrow_w_scalars_arrow(class_under_test, mock_gapic_client): arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) - reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = arrow_batches + reader = class_under_test(mock_gapic_client, "", 0, {}) actual_table = reader.to_arrow() expected_table = pyarrow.Table.from_batches( _bq_to_arrow_batch_objects(SCALAR_BLOCKS, arrow_schema) @@ -169,11 +172,11 @@ def test_to_arrow_w_scalars_arrow(class_under_test): assert actual_table == expected_table -def test_to_dataframe_w_scalars_arrow(class_under_test): +def test_to_dataframe_w_scalars_arrow(class_under_test, mock_gapic_client): arrow_schema = 
_bq_to_arrow_schema(SCALAR_COLUMNS) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) - - reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = arrow_batches + reader = class_under_test(mock_gapic_client, "", 0, {}) got = reader.to_dataframe() expected = pandas.DataFrame( @@ -187,7 +190,8 @@ def test_to_dataframe_w_scalars_arrow(class_under_test): def test_rows_w_empty_stream_arrow(class_under_test, mock_gapic_client): - reader = class_under_test([], mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = [] + reader = class_under_test(mock_gapic_client, "", 0, {}) got = reader.rows() assert tuple(got) == () @@ -195,8 +199,8 @@ def test_rows_w_empty_stream_arrow(class_under_test, mock_gapic_client): def test_rows_w_scalars_arrow(class_under_test, mock_gapic_client): arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) - - reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = arrow_batches + reader = class_under_test(mock_gapic_client, "", 0, {}) got = tuple( dict((key, value.as_py()) for key, value in row_dict.items()) for row_dict in reader.rows() @@ -206,7 +210,7 @@ def test_rows_w_scalars_arrow(class_under_test, mock_gapic_client): assert got == expected -def test_to_dataframe_w_dtypes_arrow(class_under_test): +def test_to_dataframe_w_dtypes_arrow(class_under_test, mock_gapic_client): arrow_schema = _bq_to_arrow_schema( [ {"name": "bigfloat", "type": "float64"}, @@ -218,8 +222,8 @@ def test_to_dataframe_w_dtypes_arrow(class_under_test): [{"bigfloat": 3.75, "lilfloat": 11.0}], ] arrow_batches = _bq_to_arrow_batches(blocks, arrow_schema) - - reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = arrow_batches + reader = class_under_test(mock_gapic_client, "", 0, {}) got = reader.to_dataframe(dtypes={"lilfloat": "float16"}) expected = pandas.DataFrame( @@ -235,11 +239,12 @@ def test_to_dataframe_w_dtypes_arrow(class_under_test): ) -def test_to_dataframe_empty_w_scalars_arrow(class_under_test): +def test_to_dataframe_empty_w_scalars_arrow(class_under_test, mock_gapic_client): arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) read_session = _generate_arrow_read_session(arrow_schema) arrow_batches = _bq_to_arrow_batches([], arrow_schema) - reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = arrow_batches + reader = class_under_test(mock_gapic_client, "", 0, {}) # Read session is needed to get a schema for empty streams. got = reader.to_dataframe(read_session) @@ -267,7 +272,8 @@ def test_to_dataframe_empty_w_dtypes_arrow(class_under_test, mock_gapic_client): ) read_session = _generate_arrow_read_session(arrow_schema) arrow_batches = _bq_to_arrow_batches([], arrow_schema) - reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + mock_gapic_client.read_rows.return_value = arrow_batches + reader = class_under_test(mock_gapic_client, "", 0, {}) # Read session is needed to get a schema for empty streams. 
got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) @@ -309,11 +315,12 @@ def test_to_dataframe_by_page_arrow(class_under_test, mock_gapic_client): batch_1 = _bq_to_arrow_batches(bq_blocks_1, arrow_schema) batch_2 = _bq_to_arrow_batches(bq_blocks_2, arrow_schema) - mock_gapic_client.read_rows.return_value = batch_2 - - reader = class_under_test( - _pages_w_unavailable(batch_1), mock_gapic_client, "", 0, {} + mock_gapic_client.read_rows.side_effect = ( + _pages_w_unavailable(batch_1), + batch_2, ) + + reader = class_under_test(mock_gapic_client, "", 0, {}) got = reader.rows() pages = iter(got.pages)
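For reference, a minimal sketch (not part of the diff) of the RetryInfo delay arithmetic that _resource_exhausted_exception_is_retryable applies and that test_rows_w_reconnect asserts via retry_delay == delay_seconds + (delay_nanos / 1e9); the values mirror the test's delay_seconds = 1 and delay_nanos = 234:

    import google.rpc.error_details_pb2

    retry_info = google.rpc.error_details_pb2.RetryInfo()
    retry_info.retry_delay.seconds = 1    # same values the reconnect test uses
    retry_info.retry_delay.nanos = 234

    # RetryInfo.retry_delay is a protobuf Duration: whole seconds plus nanos.
    delay = max(
        0,
        float(retry_info.retry_delay.seconds)
        + float(retry_info.retry_delay.nanos) / 1e9,
    )
    # delay == 1.000000234; the reader passes it to retry_delay_callback
    # (when one is set) and then sleeps for that long before reconnecting.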