feat: Job and BatchPredictionJob classes #79

Merged · 9 commits · Nov 24, 2020
10 changes: 9 additions & 1 deletion google/cloud/aiplatform/__init__.py
@@ -22,6 +22,7 @@
from google.cloud.aiplatform.models import Endpoint
from google.cloud.aiplatform.models import Model
from google.cloud.aiplatform.training_jobs import CustomTrainingJob
from google.cloud.aiplatform.jobs import BatchPredictionJob

"""
Usage:
@@ -31,4 +32,11 @@
"""
init = initializer.global_config.init

__all__ = ("gapic", "CustomTrainingJob", "Model", "Dataset", "Endpoint")
__all__ = (
"gapic",
"BatchPredictionJob",
"CustomTrainingJob",
"Model",
"Dataset",
"Endpoint",
)
31 changes: 31 additions & 0 deletions google/cloud/aiplatform/constants.py
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

DEFAULT_REGION = "us-central1"
SUPPORTED_REGIONS = ("us-central1", "europe-west4", "asia-east1")
PROD_API_ENDPOINT = "aiplatform.googleapis.com"

# Batch Prediction
BATCH_PREDICTION_INPUT_STORAGE_FORMATS = (
"jsonl",
"csv",
"tf-record",
"tf-record-gzip",
"bigquery",
"file-list",
)
BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS = ("jsonl", "csv", "bigquery")
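
For context, tuples like these typically exist so a requested format can be validated before a job is submitted. A minimal sketch of such a check — the helper name and error message are illustrative, not part of this PR:

from google.cloud.aiplatform import constants

def validate_output_format(output_format: str) -> None:
    # Hypothetical helper: reject output formats the service does not accept.
    if output_format not in constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS:
        raise ValueError(
            f"{output_format} is not a supported output format. Supported: "
            f"{constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
        )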
5 changes: 3 additions & 2 deletions google/cloud/aiplatform/initializer.py
@@ -24,6 +24,7 @@
from google.auth import credentials as auth_credentials
from google.auth.exceptions import GoogleAuthError
from google.cloud.aiplatform import utils
from google.cloud.aiplatform import constants


class _Config:
@@ -97,7 +98,7 @@ def project(self) -> str:
@property
def location(self) -> str:
"""Default location."""
return self._location or utils.DEFAULT_REGION
return self._location or constants.DEFAULT_REGION

@property
def experiment(self) -> Optional[str]:
@@ -147,7 +148,7 @@ def get_client_options(
utils.validate_region(region)

return client_options.ClientOptions(
api_endpoint=f"{region}-{prediction}{utils.PROD_API_ENDPOINT}"
api_endpoint=f"{region}-{prediction}{constants.PROD_API_ENDPOINT}"
)

def common_location_path(
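
For reference, this hunk builds the regional endpoint by prefixing the region (and, for prediction clients, a "prediction-" segment) onto PROD_API_ENDPOINT, so us-central1 resolves to us-central1-aiplatform.googleapis.com. A standalone sketch of the same construction, assuming `prediction` expands to "prediction-" or the empty string as the tests at the bottom of this PR suggest:

from google.cloud.aiplatform import constants

def build_api_endpoint(region: str, prediction_client: bool = False) -> str:
    # Mirrors the f-string above: f"{region}-{prediction}{PROD_API_ENDPOINT}"
    prediction = "prediction-" if prediction_client else ""
    return f"{region}-{prediction}{constants.PROD_API_ENDPOINT}"

assert build_api_endpoint("us-central1") == "us-central1-aiplatform.googleapis.com"
assert build_api_endpoint("europe-west4", True) == "europe-west4-prediction-aiplatform.googleapis.com"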
208 changes: 206 additions & 2 deletions google/cloud/aiplatform/jobs.py
@@ -15,10 +15,214 @@
# limitations under the License.
#

from abc import abstractclassmethod
from typing import Iterable, Optional, Union

class Job:
from google.cloud import storage
from google.cloud import bigquery

from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base
from google.cloud.aiplatform.gapic import JobState
from google.cloud.aiplatform.gapic import JobServiceClient
from google.cloud.aiplatform.utils import full_resource_name


class _Job(base.AiPlatformResourceNoun):
"""
Class that represents a general Job resource in AI Platform (Unified).
Cannot be directly instantiated.

Serves as base class to specific Job types, i.e. BatchPredictionJob or
DataLabelingJob to re-use shared functionality.

Subclasses require one class attribute:

getter_method (str): The name of the JobServiceClient getter method for the specific
Job type, e.g. 'get_custom_job' for CustomJob
"""

client_class = JobServiceClient
_is_client_prediction_client = False

@property
@abstractclassmethod
def getter_method(cls) -> str:
"""Name of getter method of Job subclass, i.e. 'get_custom_job' for CustomJob"""
pass

def __init__(
self,
valid_job_name: str,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""
Retrieves a Job subclass resource by calling a subclass-specific getter method.

Args:
valid_job_name (str):
A validated, fully-qualified Job resource name. For example:
'projects/.../locations/.../batchPredictionJobs/456' or
'projects/.../locations/.../customJobs/789'
project: Optional[str] = None,
Optional project to retrieve Job subclass from. If not set,
project set in aiplatform.init will be used.
location: Optional[str] = None,
Optional location to retrieve Job subclass from. If not set,
location set in aiplatform.init will be used.
credentials: Optional[auth_credentials.Credentials] = None,
Custom credentials to use. If not set, credentials set in
aiplatform.init will be used.
"""
super().__init__(project=project, location=location, credentials=credentials)
self.job_subclass_getter_method = getattr(self.api_client, self.getter_method)
self._gca_resource = self.job_subclass_getter_method(name=valid_job_name)

def status(self) -> JobState:
"""Fetch Job again and return the current JobState.

Returns:
state (JobState):
Enum that describes the state of an AI Platform job.
"""

# Fetch the Job again for most up-to-date job state
self._gca_resource = self.job_subclass_getter_method(
name=self._gca_resource.name
)

return self._gca_resource.state


class BatchPredictionJob(_Job):

getter_method = "get_batch_prediction_job"

def __init__(
self,
batch_prediction_job_name: str,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""
Retrieves a BatchPredictionJob resource and instantiates its representation.

Args:
batch_prediction_job_name (str):
Required. A fully-qualified BatchPredictionJob resource name or ID.
Example: "projects/.../locations/.../batchPredictionJobs/456" or
"456" when project and location are initialized or passed.
project: Optional[str] = None,
Optional project to retrieve BatchPredictionJob from. If not set,
project set in aiplatform.init will be used.
location: Optional[str] = None,
Optional location to retrieve BatchPredictionJob from. If not set,
location set in aiplatform.init will be used.
credentials: Optional[auth_credentials.Credentials] = None,
Custom credentials to use. If not set, credentials set in
aiplatform.init will be used.
"""
valid_batch_prediction_job_name = full_resource_name(
resource_name=batch_prediction_job_name,
resource_noun="batchPredictionJobs",
project=project,
location=location,
)

super().__init__(
valid_job_name=valid_batch_prediction_job_name,
project=project,
location=location,
credentials=credentials,
)

def iter_outputs(
self, bq_query_limit: Optional[int] = 100
) -> Iterable[Union[storage.Blob, bigquery.job.QueryJob]]:
"""Returns an Iterable object to traverse the output files, either a list
of GCS Blobs or a BigQuery QueryJob depending on the output config set
when the BatchPredictionJob was created.

Args:
bq_query_limit: Optional[int] = 100
Limit on rows to select from prediction table in BigQuery dataset.
Only used when retrieving predictions from a bigquery_destination_prefix.
Default is 100.

Returns:
Iterable[Union[storage.Blob, bigquery.job.QueryJob]]:
Either a list of GCS Blob objects within the prediction output
directory or an iterable QueryJob with predictions.

Raises:
RuntimeError:
If BatchPredictionJob is in a JobState other than SUCCEEDED,
since outputs cannot be retrieved until the Job has finished.
NotImplementedError:
If BatchPredictionJob succeeded and output_info does not have a
GCS or BQ output provided.
"""

job_status = self.status()

if job_status != JobState.JOB_STATE_SUCCEEDED:
raise RuntimeError(
f"Cannot read outputs until BatchPredictionJob has succeeded, "
f"current status: {job_status}"
)

output_info = self._gca_resource.output_info

# GCS Destination, return Blobs
if output_info.gcs_output_directory not in ("", None):
storage_client = storage.Client()
blobs = storage_client.list_blobs(output_info.gcs_output_directory)
return blobs

# BigQuery Destination, return QueryJob
elif output_info.bigquery_output_dataset not in ("", None):
bq_client = bigquery.Client()
Reviewer:
I'm a bit concerned about this. Ideally you'd re-use the credentials from uCAIP client.

Also, there's a risk of leaking sockets when you create clients on-the-fly. Not as big a deal for REST clients, but definitely a concern for gRPC clients. googleapis/google-cloud-python#9790 googleapis/google-cloud-python#9457

Contributor Author:
Suggested change:
- bq_client = bigquery.Client()
+ bq_client = bigquery.Client(
+     credentials=self.api_client._transport._credentials
+ )

^ This change would build a BigQuery Client using the same credentials as uCAIP's JobServiceClient.

In regards to the leaking sockets, would the solution referenced in that issue work? See below

# Close sockets opened by BQ Client
bq_client._http._auth_request.session.close()
bq_client._http.close()

Reviewer:

This change for credentials LGTM. (Storage should get similar treatment).

It's a little trickier in our case, because we want the client to live for the lifetime of the RowIterator.

Unless you want to convert to full list of rows / pandas dataframe before returning? In which case all the API requests would be made here and we could close the client when done (FWIW, the client does have a close function in BQ. https://googleapis.dev/python/bigquery/latest/generated/google.cloud.bigquery.client.Client.html#google.cloud.bigquery.client.Client.close)

Contributor Author:

Made the credentials change on both BQ and Storage.

Re: closing connections - this is indeed tricky since the method is meant to return an iterator. However, your comment made me realize a larger issue of us instantiating a GAPIC client for every instance of a high-level SDK object. I'm capturing this in b/174111905.

Will merge this blocking PR for now, thanks for calling this issue out!

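A minimal sketch of the pattern the thread converges on — reusing the JobServiceClient credentials and closing the BigQuery client once rows are materialized. This assumes the method were changed to return a concrete list rather than an iterator; the PR as merged keeps the lazy QueryJob and tracks the client-lifecycle question in b/174111905:

from google.cloud import bigquery

def query_predictions_eagerly(self, query: str) -> list:
    # Reuse the credentials already held by the uCAIP JobServiceClient.
    bq_client = bigquery.Client(
        credentials=self.api_client._transport._credentials
    )
    try:
        # Materialize every row so the client can be closed before returning.
        rows = list(bq_client.query(query).result())
    finally:
        # Close sockets opened by the BQ client.
        bq_client.close()
    return rows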

# Format from service is `bq://projectId.bqDatasetId`
bq_dataset = output_info.bigquery_output_dataset

if bq_dataset.startswith("bq://"):
bq_dataset = bq_dataset[5:]
if bq_dataset.endswith(("/", ".")):
bq_dataset = bq_dataset[:-1]

# Split project ID and BQ dataset ID
_, bq_dataset_id = bq_dataset.split(".", 1)

query_limit = f"LIMIT {bq_query_limit}" if bq_query_limit else ""
query = f"SELECT * FROM {bq_dataset_id}.predictions {query_limit}"
query_job = bq_client.query(query)

return query_job

# Unknown Destination type
else:
raise NotImplementedError(
f"Unsupported batch prediction output location, here are details"
f"on your prediction output:\n{output_info}"
)


class CustomJob(_Job):
getter_method = "get_custom_job"
pass


class DataLabelingJob(_Job):
getter_method = "get_data_labeling_job"
pass


class BatchPredictionJob(Job):
class HyperparameterTuningJob(_Job):
getter_method = "get_hyperparameter_tuning_job"
pass
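
Taken together, the new surface can be exercised roughly as follows — a usage sketch, with the project, location, and job ID as placeholders; iter_outputs yields GCS Blobs or BigQuery result rows depending on the job's output config:

from google.cloud import aiplatform
from google.cloud.aiplatform.gapic import JobState

aiplatform.init(project="my-project", location="us-central1")  # placeholder project

# A bare ID is expanded to the full resource name via full_resource_name().
job = aiplatform.BatchPredictionJob("456")

if job.status() == JobState.JOB_STATE_SUCCEEDED:
    for output in job.iter_outputs(bq_query_limit=50):
        # GCS destination: storage.Blob objects; BQ destination: query rows.
        print(output)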
5 changes: 2 additions & 3 deletions google/cloud/aiplatform/utils.py
@@ -35,9 +35,8 @@
client as prediction_client,
)

DEFAULT_REGION = "us-central1"
SUPPORTED_REGIONS = ("us-central1", "europe-west4", "asia-east1")
PROD_API_ENDPOINT = "aiplatform.googleapis.com"
from google.cloud.aiplatform.constants import SUPPORTED_REGIONS


AiPlatformServiceClient = TypeVar(
"AiPlatformServiceClient",
1 change: 1 addition & 0 deletions setup.py
@@ -43,6 +43,7 @@
install_requires=(
"google-api-core[grpc] >= 1.22.2, < 2.0.0dev",
"google-cloud-storage >= 1.32.0, < 2.0.0dev",
"google-cloud-bigquery >= 2.3.1, < 3.0.0dev",
"libcst >= 0.2.5",
"proto-plus >= 1.4.0",
),
9 changes: 5 additions & 4 deletions tests/unit/aiplatform/test_initializer.py
@@ -22,7 +22,7 @@
import google.auth
from google.auth import credentials
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform import constants
from google.cloud.aiplatform_v1beta1.services.model_service.client import (
ModelServiceClient,
)
@@ -56,7 +56,7 @@ def test_init_location_sets_location(self):
assert initializer.global_config.location == _TEST_LOCATION

def test_not_init_location_gets_default_location(self):
assert initializer.global_config.location == utils.DEFAULT_REGION
assert initializer.global_config.location == constants.DEFAULT_REGION

def test_init_location_with_invalid_location_raises(self):
with pytest.raises(ValueError):
@@ -94,7 +94,8 @@ def test_create_client_returns_client(self):
client = initializer.global_config.create_client(ModelServiceClient)
assert isinstance(client, ModelServiceClient)
assert (
client._transport._host == f"{_TEST_LOCATION}-{utils.PROD_API_ENDPOINT}:443"
client._transport._host
== f"{_TEST_LOCATION}-{constants.PROD_API_ENDPOINT}:443"
)

def test_create_client_overrides(self):
@@ -109,7 +110,7 @@ def test_create_client_overrides(self):
assert isinstance(client, ModelServiceClient)
assert (
client._transport._host
== f"{_TEST_LOCATION_2}-prediction-{utils.PROD_API_ENDPOINT}:443"
== f"{_TEST_LOCATION_2}-prediction-{constants.PROD_API_ENDPOINT}:443"
)
assert client._transport._credentials == creds
