Skip to content

Commit

Permalink
Update automl_tables_predict.py with batch_predict_bq sample (#4142)
Browse files Browse the repository at this point in the history
Added a new method `batch_predict_bq` demonstrating running batch_prediction using BigQuery.
Added notes in comments about asynchronicity for `batch_predict` method.

The region `automl_tables_batch_predict_bq` will be used on cloud.google.com (currently both sections for GCS and BigQuery use the same sample code which is incorrect).

Fixes #4141

Note: It's a good idea to open an issue first for discussion.

- [x] Please **merge** this PR for me once it is approved.
  • Loading branch information
evil-shrike authored Jul 17, 2020
1 parent 0697602 commit ee7bb13
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
51 changes: 51 additions & 0 deletions tables/automl/automl_tables_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,49 @@ def predict(
# [END automl_tables_predict]


def batch_predict_bq(
project_id,
compute_region,
model_display_name,
bq_input_uri,
bq_output_uri,
):
"""Make a batch of predictions."""
# [START automl_tables_batch_predict_bq]
# TODO(developer): Uncomment and set the following variables
# project_id = 'PROJECT_ID_HERE'
# compute_region = 'COMPUTE_REGION_HERE'
# model_display_name = 'MODEL_DISPLAY_NAME_HERE'
# bq_input_uri = 'bq://my-project.my-dataset.my-table'
# bq_output_uri = 'bq://my-project'

from google.cloud import automl_v1beta1 as automl

client = automl.TablesClient(project=project_id, region=compute_region)

# Query model
response = client.batch_predict(bigquery_input_uri=bq_input_uri,
bigquery_output_uri=bq_output_uri,
model_display_name=model_display_name)
print("Making batch prediction... ")
# `response` is a async operation descriptor,
# you can register a callback for the operation to complete via `add_done_callback`:
# def callback(operation_future):
# result = operation_future.result()
# response.add_done_callback(callback)
#
# or block the thread polling for the operation's results:
response.result()
# AutoML puts predictions in a newly generated dataset with a name by a mask "prediction_" + model_id + "_" + timestamp
# here's how to get the dataset name:
dataset_name = response.metadata.batch_predict_details.output_info.bigquery_output_dataset

print("Batch prediction complete.\nResults are in '{}' dataset.\n{}".format(
dataset_name, response.metadata))

# [END automl_tables_batch_predict_bq]


def batch_predict(
project_id,
compute_region,
Expand Down Expand Up @@ -108,7 +151,15 @@ def batch_predict(
model_display_name=model_display_name,
)
print("Making batch prediction... ")
# `response` is a async operation descriptor,
# you can register a callback for the operation to complete via `add_done_callback`:
# def callback(operation_future):
# result = operation_future.result()
# response.add_done_callback(callback)
#
# or block the thread polling for the operation's results:
response.result()

print("Batch prediction complete.\n{}".format(response.metadata))

# [END automl_tables_batch_predict]
Expand Down
12 changes: 12 additions & 0 deletions tables/automl/batch_predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
STATIC_MODEL = model_test.STATIC_MODEL
GCS_INPUT = "gs://{}-automl-tables-test/bank-marketing.csv".format(PROJECT)
GCS_OUTPUT = "gs://{}-automl-tables-test/TABLE_TEST_OUTPUT/".format(PROJECT)
BQ_INPUT = "bq://{}.automl_test.bank_marketing".format(PROJECT)
BQ_OUTPUT = "bq://{}".format(PROJECT)


@pytest.mark.slow
Expand All @@ -42,6 +44,16 @@ def test_batch_predict(capsys):
assert "Batch prediction complete" in out


@pytest.mark.slow
def test_batch_predict_bq(capsys):
ensure_model_online()
automl_tables_predict.batch_predict_bq(
PROJECT, REGION, STATIC_MODEL, BQ_INPUT, BQ_OUTPUT
)
out, _ = capsys.readouterr()
assert "Batch prediction complete" in out


def ensure_model_online():
model = model_test.ensure_model_ready()
if model.deployment_state != enums.Model.DeploymentState.DEPLOYED:
Expand Down

0 comments on commit ee7bb13

Please sign in to comment.