Skip to content

Commit

Permalink
Make total_rows available on RowIterator before iteration
Browse files Browse the repository at this point in the history
After running a query, the total number of rows is available from the
call to the getQueryResults API. This commit plumbs the total rows
through to the faux Table created in QueryJob.result and then on
through to the RowIterator created by list_rows.

Also, call get_table in list_rows. TODO: split that change out into a
separate PR.
  • Loading branch information
tswast committed Mar 30, 2019
1 parent 1c2cee7 commit 34704c3
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 56 deletions.
3 changes: 1 addition & 2 deletions bigquery/docs/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2246,8 +2246,7 @@ def test_client_query_total_rows(client, capsys):
location="US",
) # API request - starts the query

results = query_job.result() # Waits for query to complete.
next(iter(results)) # Fetch the first page of results, which contains total_rows.
results = query_job.result() # Wait for query to complete.
print("Got {} rows.".format(results.total_rows))
# [END bigquery_query_total_rows]

Expand Down
1 change: 1 addition & 0 deletions bigquery/google/cloud/bigquery/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2808,6 +2808,7 @@ def result(self, timeout=None, retry=DEFAULT_RETRY):
schema = self._query_results.schema
dest_table_ref = self.destination
dest_table = Table(dest_table_ref, schema=schema)
dest_table._properties["numRows"] = self._query_results.total_rows
return self._client.list_rows(dest_table, retry=retry)

def to_dataframe(self, bqstorage_client=None, dtypes=None, progress_bar_type=None):
Expand Down
8 changes: 5 additions & 3 deletions bigquery/google/cloud/bigquery/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,7 +1300,11 @@ def __init__(
)
self._schema = schema
self._field_to_index = _helpers._field_to_index_mapping(schema)

self._total_rows = None
if table is not None and hasattr(table, "num_rows"):
self._total_rows = table.num_rows

self._page_size = page_size
self._table = table
self._selected_fields = selected_fields
Expand Down Expand Up @@ -1422,9 +1426,7 @@ def _get_progress_bar(self, progress_bar_type):
desc=description, total=self.total_rows, unit=unit
)
elif progress_bar_type == "tqdm_gui":
return tqdm.tqdm_gui(
desc=description, total=self.total_rows, unit=unit
)
return tqdm.tqdm_gui(desc=description, total=self.total_rows, unit=unit)
except (KeyError, TypeError):
# Protect ourselves from any tqdm errors. In case of
# unexpected tqdm behavior, just fall back to showing
Expand Down
27 changes: 15 additions & 12 deletions bigquery/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4115,18 +4115,21 @@ def test_list_rows_empty_table(self):
client._connection = _make_connection(response, response)

# Table that has no schema because it's an empty table.
rows = tuple(
client.list_rows(
# Test with using a string for the table ID.
"{}.{}.{}".format(
self.TABLE_REF.project,
self.TABLE_REF.dataset_id,
self.TABLE_REF.table_id,
),
selected_fields=[],
)
rows = client.list_rows(
# Test with using a string for the table ID.
"{}.{}.{}".format(
self.TABLE_REF.project,
self.TABLE_REF.dataset_id,
self.TABLE_REF.table_id,
),
selected_fields=[],
)
self.assertEqual(rows, ())

# When a table reference / string and selected_fields is provided,
# total_rows can't be populated until iteration starts.
self.assertIsNone(rows.total_rows)
self.assertEqual(tuple(rows), ())
self.assertEqual(rows.total_rows, 0)

def test_list_rows_query_params(self):
from google.cloud.bigquery.table import Table, SchemaField
Expand Down Expand Up @@ -4329,7 +4332,7 @@ def test_list_rows_with_missing_schema(self):

conn.api_request.assert_called_once_with(method="GET", path=table_path)
conn.api_request.reset_mock()
self.assertIsNone(row_iter.total_rows, msg=repr(table))
self.assertEqual(row_iter.total_rows, 2, msg=repr(table))

rows = list(row_iter)
conn.api_request.assert_called_once_with(
Expand Down
21 changes: 19 additions & 2 deletions bigquery/tests/unit/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -4008,21 +4008,37 @@ def test_estimated_bytes_processed(self):
self.assertEqual(job.estimated_bytes_processed, est_bytes)

def test_result(self):
from google.cloud.bigquery.table import RowIterator

query_resource = {
"jobComplete": True,
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "2",
}
connection = _make_connection(query_resource, query_resource)
tabledata_resource = {
"totalRows": "1",
"pageToken": None,
"rows": [{"f": [{"v": "abc"}]}],
}
connection = _make_connection(query_resource, tabledata_resource)
client = _make_client(self.PROJECT, connection=connection)
resource = self._make_resource(ended=True)
job = self._get_target_class().from_api_repr(resource, client)

result = job.result()

self.assertEqual(list(result), [])
self.assertIsInstance(result, RowIterator)
self.assertEqual(result.total_rows, 2)

rows = list(result)
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0].col1, "abc")
self.assertEqual(result.total_rows, 1)

def test_result_w_empty_schema(self):
from google.cloud.bigquery.table import _EmptyRowIterator

# Destination table may have no schema for some DDL and DML queries.
query_resource = {
"jobComplete": True,
Expand All @@ -4036,6 +4052,7 @@ def test_result_w_empty_schema(self):

result = job.result()

self.assertIsInstance(result, _EmptyRowIterator)
self.assertEqual(list(result), [])

def test_result_invokes_begins(self):
Expand Down
Loading

0 comments on commit 34704c3

Please sign in to comment.