Skip to content

Commit

Permalink
Merge branch 'main' into job-rate-limit
Browse files Browse the repository at this point in the history
  • Loading branch information
kiraksi authored Jan 24, 2024
2 parents ffd6e38 + 6559dde commit 46f43a5
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 20 deletions.
24 changes: 15 additions & 9 deletions google/cloud/bigquery/_job_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ def do_query():
return future


def _validate_job_config(request_body: Dict[str, Any], invalid_key: str):
"""Catch common mistakes, such as passing in a *JobConfig object of the
wrong type.
"""
if invalid_key in request_body:
raise ValueError(f"got unexpected key {repr(invalid_key)} in job_config")


def _to_query_request(
job_config: Optional[job.QueryJobConfig] = None,
*,
Expand All @@ -179,17 +187,15 @@ def _to_query_request(
QueryRequest. If any configuration property is set that is not available in
jobs.query, it will result in a server-side error.
"""
request_body = {}
job_config_resource = job_config.to_api_repr() if job_config else {}
query_config_resource = job_config_resource.get("query", {})
request_body = copy.copy(job_config.to_api_repr()) if job_config else {}

request_body.update(query_config_resource)
_validate_job_config(request_body, job.CopyJob._JOB_TYPE)
_validate_job_config(request_body, job.ExtractJob._JOB_TYPE)
_validate_job_config(request_body, job.LoadJob._JOB_TYPE)

# These keys are top level in job resource and query resource.
if "labels" in job_config_resource:
request_body["labels"] = job_config_resource["labels"]
if "dryRun" in job_config_resource:
request_body["dryRun"] = job_config_resource["dryRun"]
# Move query.* properties to top-level.
query_config_resource = request_body.pop("query", {})
request_body.update(query_config_resource)

# Default to standard SQL.
request_body.setdefault("useLegacySql", False)
Expand Down
11 changes: 11 additions & 0 deletions google/cloud/bigquery/magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,15 @@ def _create_dataset_if_necessary(client, dataset_id):
"Defaults to use tqdm_notebook. Install the ``tqdm`` package to use this feature."
),
)
@magic_arguments.argument(
"--location",
type=str,
default=None,
help=(
"Set the location to execute query."
"Defaults to location set in query setting in console."
),
)
def _cell_magic(line, query):
"""Underlying function for bigquery cell magic
Expand Down Expand Up @@ -551,6 +560,7 @@ def _cell_magic(line, query):
category=DeprecationWarning,
)
use_bqstorage_api = not args.use_rest_api
location = args.location

params = []
if params_option_value:
Expand Down Expand Up @@ -579,6 +589,7 @@ def _cell_magic(line, query):
default_query_job_config=context.default_query_job_config,
client_info=client_info.ClientInfo(user_agent=IPYTHON_USER_AGENT),
client_options=bigquery_client_options,
location=location,
)
if context._connection:
client._connection = context._connection
Expand Down
75 changes: 64 additions & 11 deletions tests/unit/test__job_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

from google.cloud.bigquery.client import Client
from google.cloud.bigquery import _job_helpers
from google.cloud.bigquery.job import copy_ as job_copy
from google.cloud.bigquery.job import extract as job_extract
from google.cloud.bigquery.job import load as job_load
from google.cloud.bigquery.job import query as job_query
from google.cloud.bigquery.query import ConnectionProperty, ScalarQueryParameter

Expand Down Expand Up @@ -57,9 +60,34 @@ def make_query_response(
@pytest.mark.parametrize(
("job_config", "expected"),
(
(None, make_query_request()),
(job_query.QueryJobConfig(), make_query_request()),
(
pytest.param(
None,
make_query_request(),
id="job_config=None-default-request",
),
pytest.param(
job_query.QueryJobConfig(),
make_query_request(),
id="job_config=QueryJobConfig()-default-request",
),
pytest.param(
job_query.QueryJobConfig.from_api_repr(
{
"unknownTopLevelProperty": "some-test-value",
"query": {
"unknownQueryProperty": "some-other-value",
},
},
),
make_query_request(
{
"unknownTopLevelProperty": "some-test-value",
"unknownQueryProperty": "some-other-value",
}
),
id="job_config-with-unknown-properties-includes-all-properties-in-request",
),
pytest.param(
job_query.QueryJobConfig(default_dataset="my-project.my_dataset"),
make_query_request(
{
Expand All @@ -69,17 +97,24 @@ def make_query_response(
}
}
),
id="job_config-with-default_dataset",
),
(job_query.QueryJobConfig(dry_run=True), make_query_request({"dryRun": True})),
(
pytest.param(
job_query.QueryJobConfig(dry_run=True),
make_query_request({"dryRun": True}),
id="job_config-with-dry_run",
),
pytest.param(
job_query.QueryJobConfig(use_query_cache=False),
make_query_request({"useQueryCache": False}),
id="job_config-with-use_query_cache",
),
(
pytest.param(
job_query.QueryJobConfig(use_legacy_sql=True),
make_query_request({"useLegacySql": True}),
id="job_config-with-use_legacy_sql",
),
(
pytest.param(
job_query.QueryJobConfig(
query_parameters=[
ScalarQueryParameter("named_param1", "STRING", "param-value"),
Expand All @@ -103,8 +138,9 @@ def make_query_response(
],
}
),
id="job_config-with-query_parameters-named",
),
(
pytest.param(
job_query.QueryJobConfig(
query_parameters=[
ScalarQueryParameter(None, "STRING", "param-value"),
Expand All @@ -126,8 +162,9 @@ def make_query_response(
],
}
),
id="job_config-with-query_parameters-positional",
),
(
pytest.param(
job_query.QueryJobConfig(
connection_properties=[
ConnectionProperty(key="time_zone", value="America/Chicago"),
Expand All @@ -142,14 +179,17 @@ def make_query_response(
]
}
),
id="job_config-with-connection_properties",
),
(
pytest.param(
job_query.QueryJobConfig(labels={"abc": "def"}),
make_query_request({"labels": {"abc": "def"}}),
id="job_config-with-labels",
),
(
pytest.param(
job_query.QueryJobConfig(maximum_bytes_billed=987654),
make_query_request({"maximumBytesBilled": "987654"}),
id="job_config-with-maximum_bytes_billed",
),
),
)
Expand All @@ -159,6 +199,19 @@ def test__to_query_request(job_config, expected):
assert result == expected


@pytest.mark.parametrize(
("job_config", "invalid_key"),
(
pytest.param(job_copy.CopyJobConfig(), "copy", id="copy"),
pytest.param(job_extract.ExtractJobConfig(), "extract", id="extract"),
pytest.param(job_load.LoadJobConfig(), "load", id="load"),
),
)
def test__to_query_request_raises_for_invalid_config(job_config, invalid_key):
with pytest.raises(ValueError, match=f"{repr(invalid_key)} in job_config"):
_job_helpers._to_query_request(job_config, query="SELECT 1")


def test__to_query_job_defaults():
mock_client = mock.create_autospec(Client)
response = make_query_response(
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2053,3 +2053,21 @@ def test_bigquery_magic_create_dataset_fails():
)

assert close_transports.called


@pytest.mark.usefixtures("ipython_interactive")
def test_bigquery_magic_with_location():
ip = IPython.get_ipython()
ip.extension_manager.load_extension("google.cloud.bigquery")
magics.context.credentials = mock.create_autospec(
google.auth.credentials.Credentials, instance=True
)

run_query_patch = mock.patch(
"google.cloud.bigquery.magics.magics._run_query", autospec=True
)
with run_query_patch as run_query_mock:
ip.run_cell_magic("bigquery", "--location=us-east1", "SELECT 17 AS num")

client_options_used = run_query_mock.call_args_list[0][0][0]
assert client_options_used.location == "us-east1"

0 comments on commit 46f43a5

Please sign in to comment.