diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py index dca9f7962..70e601714 100644 --- a/google/cloud/bigquery/table.py +++ b/google/cloud/bigquery/table.py @@ -2997,9 +2997,9 @@ def _rows_page_start(iterator, page, response): page._columns = _row_iterator_page_columns(iterator._schema, response) total_rows = response.get("totalRows") + # Don't reset total_rows if it's not present in the next API response. if total_rows is not None: - total_rows = int(total_rows) - iterator._total_rows = total_rows + iterator._total_rows = int(total_rows) # pylint: enable=unused-argument diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py index 85f335dd1..9b3d4fe84 100644 --- a/tests/unit/test_table.py +++ b/tests/unit/test_table.py @@ -2201,9 +2201,18 @@ def test_iterate_with_cached_first_page(self): path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) row_iterator = self._make_one( - _mock_client(), api_request, path, schema, first_page_response=first_page + _mock_client(), + api_request, + path, + schema, + first_page_response=first_page, + total_rows=4, ) + self.assertEqual(row_iterator.total_rows, 4) rows = list(row_iterator) + # Total rows should be maintained, even though subsequent API calls + # don't include it. + self.assertEqual(row_iterator.total_rows, 4) self.assertEqual(len(rows), 4) self.assertEqual(rows[0].age, 27) self.assertEqual(rows[1].age, 28)