From 0280a945778f987e2f5b069dbbb91f8780ee4e8e Mon Sep 17 00:00:00 2001 From: Gurov Ilya Date: Thu, 23 Jan 2020 19:54:09 +0300 Subject: [PATCH] feat(bigquery): check `rows` argument type in `insert_rows()` (#10174) * feat(bigquery): check rows arg type in insert_rows() * add class marking * add Iterator into if statement to pass islices * add Iterator into if statement to pass islices * black reformat --- bigquery/google/cloud/bigquery/client.py | 16 ++++++++++++---- bigquery/tests/unit/test_client.py | 15 +++++++++++++-- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/bigquery/google/cloud/bigquery/client.py b/bigquery/google/cloud/bigquery/client.py index d37d8ac19e21..83e6bf8045ed 100644 --- a/bigquery/google/cloud/bigquery/client.py +++ b/bigquery/google/cloud/bigquery/client.py @@ -1220,7 +1220,7 @@ def delete_table( raise def _get_query_results( - self, job_id, retry, project=None, timeout_ms=None, location=None, timeout=None, + self, job_id, retry, project=None, timeout_ms=None, location=None, timeout=None ): """Get the query results object for a query job. @@ -2355,7 +2355,7 @@ def insert_rows(self, table, rows, selected_fields=None, **kwargs): str, \ ]): The destination table for the row data, or a reference to it. - rows (Union[Sequence[Tuple], Sequence[dict]]): + rows (Union[Sequence[Tuple], Sequence[Dict]]): Row data to be inserted. If a list of tuples is given, each tuple should contain data for each schema field on the current table and in the same order as the schema fields. If @@ -2376,8 +2376,11 @@ def insert_rows(self, table, rows, selected_fields=None, **kwargs): the mappings describing one or more problems with the row. Raises: - ValueError: if table's schema is not set + ValueError: if table's schema is not set or `rows` is not a `Sequence`. """ + if not isinstance(rows, (collections_abc.Sequence, collections_abc.Iterator)): + raise TypeError("rows argument should be a sequence of dicts or tuples") + table = _table_arg_to_table(table, default_project=self.project) if not isinstance(table, Table): @@ -2505,8 +2508,13 @@ def insert_rows_json( One mapping per row with insert errors: the "index" key identifies the row, and the "errors" key contains a list of the mappings describing one or more problems with the row. + + Raises: + TypeError: if `json_rows` is not a `Sequence`. """ - if not isinstance(json_rows, collections_abc.Sequence): + if not isinstance( + json_rows, (collections_abc.Sequence, collections_abc.Iterator) + ): raise TypeError("json_rows argument should be a sequence of dicts") # Convert table to just a reference because unlike insert_rows, # insert_rows_json doesn't need the table schema. It's not doing any diff --git a/bigquery/tests/unit/test_client.py b/bigquery/tests/unit/test_client.py index cce4bc532074..6b40d8a020a4 100644 --- a/bigquery/tests/unit/test_client.py +++ b/bigquery/tests/unit/test_client.py @@ -5048,6 +5048,7 @@ def _row_data(row): ) def test_insert_rows_errors(self): + from google.cloud.bigquery.schema import SchemaField from google.cloud.bigquery.table import Table ROWS = [ @@ -5058,6 +5059,7 @@ def test_insert_rows_errors(self): ] creds = _make_credentials() http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) # table ref with no selected fields @@ -5068,10 +5070,19 @@ def test_insert_rows_errors(self): with self.assertRaises(ValueError): client.insert_rows(Table(self.TABLE_REF), ROWS) - # neither Table nor tableReference + # neither Table nor TableReference with self.assertRaises(TypeError): client.insert_rows(1, ROWS) + schema = [ + SchemaField("full_name", "STRING", mode="REQUIRED"), + ] + table = Table(self.TABLE_REF, schema=schema) + + # rows is just a dict + with self.assertRaises(TypeError): + client.insert_rows(table, {"full_name": "value"}) + def test_insert_rows_w_numeric(self): from google.cloud.bigquery import table from google.cloud.bigquery.schema import SchemaField @@ -5853,7 +5864,7 @@ def test_list_rows_error(self): http = object() client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) - # neither Table nor tableReference + # neither Table nor TableReference with self.assertRaises(TypeError): client.list_rows(1)