From 34704c322df0dc0899be5bd67278e121c7e68a90 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Thu, 28 Mar 2019 17:05:15 -0700 Subject: [PATCH] Make total_rows available on RowIterator before iteration After running a query, the total number of rows is available from the call to the getQueryResults API. This commit plumbs the total rows through to the faux Table created in QueryJob.results and then on through to the RowIterator created by list_rows. Also, call get_table in list_rows... TODO: split that out to a separate PR --- bigquery/docs/snippets.py | 3 +- bigquery/google/cloud/bigquery/job.py | 1 + bigquery/google/cloud/bigquery/table.py | 8 +- bigquery/tests/unit/test_client.py | 27 ++++--- bigquery/tests/unit/test_job.py | 21 +++++- bigquery/tests/unit/test_table.py | 98 +++++++++++++++---------- 6 files changed, 102 insertions(+), 56 deletions(-) diff --git a/bigquery/docs/snippets.py b/bigquery/docs/snippets.py index 00569c40af189..5d0847e5b8923 100644 --- a/bigquery/docs/snippets.py +++ b/bigquery/docs/snippets.py @@ -2246,8 +2246,7 @@ def test_client_query_total_rows(client, capsys): location="US", ) # API request - starts the query - results = query_job.result() # Waits for query to complete. - next(iter(results)) # Fetch the first page of results, which contains total_rows. + results = query_job.result() # Wait for query to complete. print("Got {} rows.".format(results.total_rows)) # [END bigquery_query_total_rows] diff --git a/bigquery/google/cloud/bigquery/job.py b/bigquery/google/cloud/bigquery/job.py index bc87f109a4843..553905140755a 100644 --- a/bigquery/google/cloud/bigquery/job.py +++ b/bigquery/google/cloud/bigquery/job.py @@ -2808,6 +2808,7 @@ def result(self, timeout=None, retry=DEFAULT_RETRY): schema = self._query_results.schema dest_table_ref = self.destination dest_table = Table(dest_table_ref, schema=schema) + dest_table._properties["numRows"] = self._query_results.total_rows return self._client.list_rows(dest_table, retry=retry) def to_dataframe(self, bqstorage_client=None, dtypes=None, progress_bar_type=None): diff --git a/bigquery/google/cloud/bigquery/table.py b/bigquery/google/cloud/bigquery/table.py index 7b924817eaca7..36ccf96d12866 100644 --- a/bigquery/google/cloud/bigquery/table.py +++ b/bigquery/google/cloud/bigquery/table.py @@ -1300,7 +1300,11 @@ def __init__( ) self._schema = schema self._field_to_index = _helpers._field_to_index_mapping(schema) + self._total_rows = None + if table is not None and hasattr(table, "num_rows"): + self._total_rows = table.num_rows + self._page_size = page_size self._table = table self._selected_fields = selected_fields @@ -1422,9 +1426,7 @@ def _get_progress_bar(self, progress_bar_type): desc=description, total=self.total_rows, unit=unit ) elif progress_bar_type == "tqdm_gui": - return tqdm.tqdm_gui( - desc=description, total=self.total_rows, unit=unit - ) + return tqdm.tqdm_gui(desc=description, total=self.total_rows, unit=unit) except (KeyError, TypeError): # Protect ourselves from any tqdm errors. In case of # unexpected tqdm behavior, just fall back to showing diff --git a/bigquery/tests/unit/test_client.py b/bigquery/tests/unit/test_client.py index 671bbdf297787..73125eaddd33f 100644 --- a/bigquery/tests/unit/test_client.py +++ b/bigquery/tests/unit/test_client.py @@ -4115,18 +4115,21 @@ def test_list_rows_empty_table(self): client._connection = _make_connection(response, response) # Table that has no schema because it's an empty table. 
- rows = tuple( - client.list_rows( - # Test with using a string for the table ID. - "{}.{}.{}".format( - self.TABLE_REF.project, - self.TABLE_REF.dataset_id, - self.TABLE_REF.table_id, - ), - selected_fields=[], - ) + rows = client.list_rows( + # Test with using a string for the table ID. + "{}.{}.{}".format( + self.TABLE_REF.project, + self.TABLE_REF.dataset_id, + self.TABLE_REF.table_id, + ), + selected_fields=[], ) - self.assertEqual(rows, ()) + + # When a table reference / string and selected_fields is provided, + # total_rows can't be populated until iteration starts. + self.assertIsNone(rows.total_rows) + self.assertEqual(tuple(rows), ()) + self.assertEqual(rows.total_rows, 0) def test_list_rows_query_params(self): from google.cloud.bigquery.table import Table, SchemaField @@ -4329,7 +4332,7 @@ def test_list_rows_with_missing_schema(self): conn.api_request.assert_called_once_with(method="GET", path=table_path) conn.api_request.reset_mock() - self.assertIsNone(row_iter.total_rows, msg=repr(table)) + self.assertEqual(row_iter.total_rows, 2, msg=repr(table)) rows = list(row_iter) conn.api_request.assert_called_once_with( diff --git a/bigquery/tests/unit/test_job.py b/bigquery/tests/unit/test_job.py index 833081ce066da..c199fb8c7d209 100644 --- a/bigquery/tests/unit/test_job.py +++ b/bigquery/tests/unit/test_job.py @@ -4008,21 +4008,37 @@ def test_estimated_bytes_processed(self): self.assertEqual(job.estimated_bytes_processed, est_bytes) def test_result(self): + from google.cloud.bigquery.table import RowIterator + query_resource = { "jobComplete": True, "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, + "totalRows": "2", } - connection = _make_connection(query_resource, query_resource) + tabledata_resource = { + "totalRows": "1", + "pageToken": None, + "rows": [{"f": [{"v": "abc"}]}], + } + connection = _make_connection(query_resource, tabledata_resource) client = _make_client(self.PROJECT, connection=connection) resource = self._make_resource(ended=True) job = self._get_target_class().from_api_repr(resource, client) result = job.result() - self.assertEqual(list(result), []) + self.assertIsInstance(result, RowIterator) + self.assertEqual(result.total_rows, 2) + + rows = list(result) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0].col1, "abc") + self.assertEqual(result.total_rows, 1) def test_result_w_empty_schema(self): + from google.cloud.bigquery.table import _EmptyRowIterator + # Destination table may have no schema for some DDL and DML queries. 
query_resource = { "jobComplete": True, @@ -4036,6 +4052,7 @@ def test_result_w_empty_schema(self): result = job.result() + self.assertIsInstance(result, _EmptyRowIterator) self.assertEqual(list(result), []) def test_result_invokes_begins(self): diff --git a/bigquery/tests/unit/test_table.py b/bigquery/tests/unit/test_table.py index 4500856ec2a48..44eae538b1bde 100644 --- a/bigquery/tests/unit/test_table.py +++ b/bigquery/tests/unit/test_table.py @@ -1282,51 +1282,85 @@ def test_row(self): class Test_EmptyRowIterator(unittest.TestCase): - @mock.patch("google.cloud.bigquery.table.pandas", new=None) - def test_to_dataframe_error_if_pandas_is_none(self): + def _make_one(self): from google.cloud.bigquery.table import _EmptyRowIterator - row_iterator = _EmptyRowIterator() + return _EmptyRowIterator() + + def test_total_rows_eq_zero(self): + row_iterator = self._make_one() + self.assertEqual(row_iterator.total_rows, 0) + + @mock.patch("google.cloud.bigquery.table.pandas", new=None) + def test_to_dataframe_error_if_pandas_is_none(self): + row_iterator = self._make_one() with self.assertRaises(ValueError): row_iterator.to_dataframe() @unittest.skipIf(pandas is None, "Requires `pandas`") def test_to_dataframe(self): - from google.cloud.bigquery.table import _EmptyRowIterator - - row_iterator = _EmptyRowIterator() + row_iterator = self._make_one() df = row_iterator.to_dataframe() self.assertIsInstance(df, pandas.DataFrame) self.assertEqual(len(df), 0) # verify the number of rows class TestRowIterator(unittest.TestCase): - def test_constructor(self): + def _make_one( + self, client=None, api_request=None, path=None, schema=None, **kwargs + ): from google.cloud.bigquery.table import RowIterator + + if client is None: + client = _mock_client() + + if api_request is None: + api_request = mock.sentinel.api_request + + if path is None: + path = "/foo" + + if schema is None: + schema = [] + + return RowIterator(client, api_request, path, schema, **kwargs) + + def test_constructor(self): from google.cloud.bigquery.table import _item_to_row from google.cloud.bigquery.table import _rows_page_start client = _mock_client() - api_request = mock.sentinel.api_request - path = "/foo" - schema = [] - iterator = RowIterator(client, api_request, path, schema) + path = "/some/path" + iterator = self._make_one(client=client, path=path) - self.assertFalse(iterator._started) + # Objects are set without copying. self.assertIs(iterator.client, client) - self.assertEqual(iterator.path, path) self.assertIs(iterator.item_to_value, _item_to_row) + self.assertIs(iterator._page_start, _rows_page_start) + # Properties have the expect value. + self.assertEqual(iterator.extra_params, {}) self.assertEqual(iterator._items_key, "rows") self.assertIsNone(iterator.max_results) - self.assertEqual(iterator.extra_params, {}) - self.assertIs(iterator._page_start, _rows_page_start) + self.assertEqual(iterator.path, path) + self.assertFalse(iterator._started) + self.assertIsNone(iterator.total_rows) # Changing attributes. 
self.assertEqual(iterator.page_number, 0) self.assertIsNone(iterator.next_page_token) self.assertEqual(iterator.num_results, 0) + def test_constructor_with_table(self): + from google.cloud.bigquery.table import Table + + table = Table("proj.dset.tbl") + table._properties["numRows"] = 100 + + iterator = self._make_one(table=table) + + self.assertIs(iterator._table, table) + self.assertEqual(iterator.total_rows, 100) + def test_iterate(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1339,7 +1373,7 @@ def test_iterate(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) self.assertEqual(row_iterator.num_results, 0) rows_iter = iter(row_iterator) @@ -1358,7 +1392,6 @@ def test_iterate(self): api_request.assert_called_once_with(method="GET", path=path, query_params={}) def test_page_size(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1372,7 +1405,7 @@ def test_page_size(self): path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator( + row_iterator = self._make_one( _mock_client(), api_request, path, schema, page_size=4 ) row_iterator._get_next_page_response() @@ -1385,7 +1418,6 @@ def test_page_size(self): @unittest.skipIf(pandas is None, "Requires `pandas`") def test_to_dataframe(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1400,7 +1432,7 @@ def test_to_dataframe(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe() @@ -1418,7 +1450,6 @@ def test_to_dataframe(self): def test_to_dataframe_progress_bar( self, tqdm_mock, tqdm_notebook_mock, tqdm_gui_mock ): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1441,7 +1472,7 @@ def test_to_dataframe_progress_bar( ) for progress_bar_type, progress_bar_mock in progress_bars: - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe(progress_bar_type=progress_bar_type) progress_bar_mock.assert_called() @@ -1451,7 +1482,6 @@ def test_to_dataframe_progress_bar( @unittest.skipIf(pandas is None, "Requires `pandas`") @mock.patch("google.cloud.bigquery.table.tqdm", new=None) def test_to_dataframe_no_tqdm_no_progress_bar(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1466,7 +1496,7 @@ def test_to_dataframe_no_tqdm_no_progress_bar(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) with warnings.catch_warnings(record=True) as warned: df = row_iterator.to_dataframe() @@ -1477,7 +1507,6 @@ def test_to_dataframe_no_tqdm_no_progress_bar(self): @unittest.skipIf(pandas is None, "Requires `pandas`") @mock.patch("google.cloud.bigquery.table.tqdm", new=None) def 
test_to_dataframe_no_tqdm(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1492,7 +1521,7 @@ def test_to_dataframe_no_tqdm(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) with warnings.catch_warnings(record=True) as warned: df = row_iterator.to_dataframe(progress_bar_type="tqdm") @@ -1511,7 +1540,6 @@ def test_to_dataframe_no_tqdm(self): @mock.patch("tqdm.tqdm_notebook", new=None) # will raise TypeError on call @mock.patch("tqdm.tqdm", new=None) # will raise TypeError on call def test_to_dataframe_tqdm_error(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1528,14 +1556,13 @@ def test_to_dataframe_tqdm_error(self): for progress_bar_type in ("tqdm", "tqdm_notebook", "tqdm_gui"): api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe(progress_bar_type=progress_bar_type) self.assertEqual(len(df), 4) # all should be well @unittest.skipIf(pandas is None, "Requires `pandas`") def test_to_dataframe_w_empty_results(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1544,7 +1571,7 @@ def test_to_dataframe_w_empty_results(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": []}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe() @@ -1555,7 +1582,6 @@ def test_to_dataframe_w_empty_results(self): @unittest.skipIf(pandas is None, "Requires `pandas`") def test_to_dataframe_w_various_types_nullable(self): import datetime - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1575,7 +1601,7 @@ def test_to_dataframe_w_various_types_nullable(self): rows = [{"f": [{"v": field} for field in row]} for row in row_data] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe() @@ -1596,7 +1622,6 @@ def test_to_dataframe_w_various_types_nullable(self): @unittest.skipIf(pandas is None, "Requires `pandas`") def test_to_dataframe_column_dtypes(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = [ @@ -1616,7 +1641,7 @@ def test_to_dataframe_column_dtypes(self): rows = [{"f": [{"v": field} for field in row]} for row in row_data] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) df = row_iterator.to_dataframe(dtypes={"km": "float16"}) @@ -1635,7 +1660,6 @@ def test_to_dataframe_column_dtypes(self): @mock.patch("google.cloud.bigquery.table.pandas", new=None) def test_to_dataframe_error_if_pandas_is_none(self): - from google.cloud.bigquery.table import RowIterator from google.cloud.bigquery.table import SchemaField schema = 
[ @@ -1648,7 +1672,7 @@ def test_to_dataframe_error_if_pandas_is_none(self): ] path = "/foo" api_request = mock.Mock(return_value={"rows": rows}) - row_iterator = RowIterator(_mock_client(), api_request, path, schema) + row_iterator = self._make_one(_mock_client(), api_request, path, schema) with self.assertRaises(ValueError): row_iterator.to_dataframe()
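
Usage sketch: with this patch, RowIterator.total_rows is seeded from the totalRows value
returned by getQueryResults (via the destination Table built in QueryJob.result), so the
row count is available before any rows are fetched, and it is updated again as pages are
read via tabledata.list. The example below mirrors the updated snippet in
bigquery/docs/snippets.py; the query text, table IDs, and field names are illustrative
assumptions, not taken from this patch.

    from google.cloud import bigquery

    client = bigquery.Client()

    query_job = client.query(
        "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013` LIMIT 10",
        location="US",
    )  # API request - starts the query

    results = query_job.result()  # Wait for query to complete.

    # total_rows is already populated; no need to fetch the first page of results.
    print("Got {} rows.".format(results.total_rows))

The same early population applies to Client.list_rows when it is given a Table whose
numRows is known, whereas a table ID string combined with selected_fields leaves
total_rows unset until iteration starts, matching the comment added in test_client.py
(the table ID and schema below are hypothetical):

    table = client.get_table("my-project.my_dataset.my_table")  # numRows comes back from the API.
    rows = client.list_rows(table)
    print(rows.total_rows)  # Known before iteration starts.

    rows = client.list_rows(
        "my-project.my_dataset.my_table",
        selected_fields=[bigquery.SchemaField("name", "STRING")],
    )
    print(rows.total_rows)  # None until the first page is fetched.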