Skip to content

Commit

Permalink
feat: Job and BatchPredictionJob classes (#79)
Browse files Browse the repository at this point in the history
* Create and move all constants to constants.py

* Fix tests after constants.py, drop unused vars

* Init Job and BatchPredictionJob class, unit tests

* Address all reviewer comments

* Update docstring to bigquery.table.RowIterator

* Get GCS/BQ clients to use same creds as uCAIP
  • Loading branch information
vinnysenthil authored Nov 24, 2020
1 parent 52de070 commit f2ccd1e
Show file tree
Hide file tree
Showing 8 changed files with 533 additions and 15 deletions.
10 changes: 9 additions & 1 deletion google/cloud/aiplatform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from google.cloud.aiplatform.models import Endpoint
from google.cloud.aiplatform.models import Model
from google.cloud.aiplatform.training_jobs import CustomTrainingJob
from google.cloud.aiplatform.jobs import BatchPredictionJob

"""
Usage:
Expand All @@ -31,4 +32,11 @@
"""
init = initializer.global_config.init

__all__ = ("gapic", "CustomTrainingJob", "Model", "Dataset", "Endpoint")
# Public names exported by the google.cloud.aiplatform package.
__all__ = (
    "gapic",
    "BatchPredictionJob",
    "CustomTrainingJob",
    "Model",
    "Dataset",
    "Endpoint",
)
31 changes: 31 additions & 0 deletions google/cloud/aiplatform/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Region used when no explicit location is set via aiplatform.init.
DEFAULT_REGION = "us-central1"
# Regions accepted by utils.validate_region.
SUPPORTED_REGIONS = ("us-central1", "europe-west4", "asia-east1")
# Base domain of the AI Platform API endpoint; prefixed with "{region}-"
# (and optionally "prediction-") when building client options.
API_BASE_PATH = "aiplatform.googleapis.com"

# Batch Prediction
# Storage formats accepted for batch prediction input.
BATCH_PREDICTION_INPUT_STORAGE_FORMATS = (
    "jsonl",
    "csv",
    "tf-record",
    "tf-record-gzip",
    "bigquery",
    "file-list",
)
# Storage formats accepted for batch prediction output.
BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS = ("jsonl", "csv", "bigquery")
5 changes: 3 additions & 2 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.auth import credentials as auth_credentials
from google.auth.exceptions import GoogleAuthError
from google.cloud.aiplatform import utils
from google.cloud.aiplatform import constants


class _Config:
Expand Down Expand Up @@ -97,7 +98,7 @@ def project(self) -> str:
@property
def location(self) -> str:
"""Default location."""
return self._location or utils.DEFAULT_REGION
return self._location or constants.DEFAULT_REGION

@property
def experiment(self) -> Optional[str]:
Expand Down Expand Up @@ -147,7 +148,7 @@ def get_client_options(
utils.validate_region(region)

return client_options.ClientOptions(
api_endpoint=f"{region}-{prediction}{utils.PROD_API_ENDPOINT}"
api_endpoint=f"{region}-{prediction}{constants.API_BASE_PATH}"
)

def common_location_path(
Expand Down
215 changes: 213 additions & 2 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,221 @@
# limitations under the License.
#

import abc
from typing import Iterable, Optional, Union

class Job:
from google.cloud import storage
from google.cloud import bigquery

from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
from google.cloud.aiplatform_v1beta1.types import job_state
from google.cloud.aiplatform_v1beta1.services import job_service


class _Job(base.AiPlatformResourceNoun):
    """Base class representing a general Job resource in AI Platform (Unified).

    Cannot be directly instantiated. Serves as base class to specific Job
    types (i.e. BatchPredictionJob or DataLabelingJob) to re-use shared
    functionality.

    Subclasses require one class attribute:

        _getter_method (str): The name of the JobServiceClient getter method
            for the specific Job type, i.e. 'get_custom_job' for CustomJob.
    """

    client_class = job_service.JobServiceClient
    _is_client_prediction_client = False

    # NOTE: `abc.abstractclassmethod` is deprecated and stacking `property`
    # on top of a classmethod does not produce a working descriptor; the
    # supported idiom for an abstract attribute is @property + @abstractmethod.
    # Subclasses satisfy it with a plain class attribute.
    @property
    @abc.abstractmethod
    def _getter_method(self) -> str:
        """Name of getter method of Job subclass, i.e. 'get_custom_job' for CustomJob"""

    def _get_job(self, job_name: str):
        """Returns GAPIC service representation of Job subclass resource.

        Args:
            job_name (str): Fully-qualified resource name of the Job to fetch.
        """
        return getattr(self.api_client, self._getter_method)(name=job_name)

    def __init__(
        self,
        valid_job_name: str,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """
        Retrieves Job subclass resource by calling a subclass-specific getter method.

        Args:
            valid_job_name (str):
                A validated, fully-qualified Job resource name. For example:
                'projects/.../locations/.../batchPredictionJobs/456' or
                'projects/.../locations/.../customJobs/789'
            project (Optional[str]):
                Optional project to retrieve Job subclass from. If not set,
                project set in aiplatform.init will be used.
            location (Optional[str]):
                Optional location to retrieve Job subclass from. If not set,
                location set in aiplatform.init will be used.
            credentials (Optional[auth_credentials.Credentials]):
                Custom credentials to use. If not set, credentials set in
                aiplatform.init will be used.
        """
        super().__init__(project=project, location=location, credentials=credentials)
        self._gca_resource = self._get_job(job_name=valid_job_name)

    def status(self) -> job_state.JobState:
        """Fetch Job again and return the current JobState.

        Returns:
            state (job_state.JobState):
                Enum that describes the state of an AI Platform job.
        """

        # Fetch the Job again for most up-to-date job state
        self._gca_resource = self._get_job(job_name=self._gca_resource.name)

        return self._gca_resource.state


class BatchPredictionJob(_Job):
    """AI Platform (Unified) BatchPredictionJob resource."""

    _getter_method = "get_batch_prediction_job"

    def __init__(
        self,
        batch_prediction_job_name: str,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """
        Retrieves a BatchPredictionJob resource and instantiates its representation.

        Args:
            batch_prediction_job_name (str):
                Required. A fully-qualified BatchPredictionJob resource name or ID.
                Example: "projects/.../locations/.../batchPredictionJobs/456" or
                "456" when project and location are initialized or passed.
            project (Optional[str]):
                Optional project to retrieve BatchPredictionJob from. If not set,
                project set in aiplatform.init will be used.
            location (Optional[str]):
                Optional location to retrieve BatchPredictionJob from. If not set,
                location set in aiplatform.init will be used.
            credentials (Optional[auth_credentials.Credentials]):
                Custom credentials to use. If not set, credentials set in
                aiplatform.init will be used.
        """
        valid_batch_prediction_job_name = utils.full_resource_name(
            resource_name=batch_prediction_job_name,
            resource_noun="batchPredictionJobs",
            project=project,
            location=location,
        )

        super().__init__(
            valid_job_name=valid_batch_prediction_job_name,
            project=project,
            location=location,
            credentials=credentials,
        )

    def iter_outputs(
        self, bq_max_results: Optional[int] = 100
    ) -> Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]:
        """Returns an Iterable object to traverse the output files, either a list
        of GCS Blobs or a BigQuery RowIterator depending on the output config set
        when the BatchPredictionJob was created.

        Args:
            bq_max_results (Optional[int]):
                Limit on rows to retrieve from prediction table in BigQuery dataset.
                Only used when retrieving predictions from a bigquery_destination_prefix.
                Default is 100.

        Returns:
            Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]:
                Either a list of GCS Blob objects within the prediction output
                directory or an iterable BigQuery RowIterator with predictions.

        Raises:
            RuntimeError:
                If BatchPredictionJob is in a JobState other than SUCCEEDED,
                since outputs cannot be retrieved until the Job has finished.
            NotImplementedError:
                If BatchPredictionJob succeeded and output_info does not have a
                GCS or BQ output provided.
        """

        job_status = self.status()

        if job_status != job_state.JobState.JOB_STATE_SUCCEEDED:
            raise RuntimeError(
                f"Cannot read outputs until BatchPredictionJob has succeeded, "
                f"current status: {job_status}"
            )

        output_info = self._gca_resource.output_info

        # GCS Destination, return Blobs
        if output_info.gcs_output_directory:

            # Build a Storage Client using the same credentials as JobServiceClient
            storage_client = storage.Client(
                credentials=self.api_client._transport._credentials
            )

            blobs = storage_client.list_blobs(output_info.gcs_output_directory)
            return blobs

        # BigQuery Destination, return RowIterator
        elif output_info.bigquery_output_dataset:

            # Build a BigQuery Client using the same credentials as JobServiceClient
            bq_client = bigquery.Client(
                credentials=self.api_client._transport._credentials
            )

            # Format from service is `bq://projectId.bqDatasetId`
            bq_dataset = output_info.bigquery_output_dataset

            if bq_dataset.startswith("bq://"):
                bq_dataset = bq_dataset[5:]

            # Split project ID and BQ dataset ID.
            # NOTE(review): the project ID is discarded here, so list_rows
            # resolves the table against the BigQuery client's default
            # project — confirm that always matches the output dataset's owner.
            _, bq_dataset_id = bq_dataset.split(".", 1)

            row_iterator = bq_client.list_rows(
                table=f"{bq_dataset_id}.predictions", max_results=bq_max_results
            )

            return row_iterator

        # Unknown Destination type
        else:
            # Fixed: original message lacked a space between "details" and "on"
            raise NotImplementedError(
                f"Unsupported batch prediction output location, here are details "
                f"on your prediction output:\n{output_info}"
            )


class CustomJob(_Job):
    """AI Platform (Unified) CustomJob resource."""

    # Redundant `pass` removed: the class body already has a statement.
    _getter_method = "get_custom_job"


class DataLabelingJob(_Job):
    """AI Platform (Unified) DataLabelingJob resource."""

    # Redundant `pass` removed: the class body already has a statement.
    _getter_method = "get_data_labeling_job"


class BatchPredictionJob(Job):
class HyperparameterTuningJob(_Job):
    """AI Platform (Unified) HyperparameterTuningJob resource."""

    # Redundant `pass` removed: the class body already has a statement.
    _getter_method = "get_hyperparameter_tuning_job"
13 changes: 7 additions & 6 deletions google/cloud/aiplatform/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@
client as prediction_client,
)

DEFAULT_REGION = "us-central1"
SUPPORTED_REGIONS = ("us-central1", "europe-west4", "asia-east1")
PROD_API_ENDPOINT = "aiplatform.googleapis.com"
from google.cloud.aiplatform import constants


AiPlatformServiceClient = TypeVar(
"AiPlatformServiceClient",
Expand Down Expand Up @@ -218,12 +217,14 @@ def validate_region(region: str) -> bool:
ValueError: If region is not in supported regions.
"""
if not region:
raise ValueError(f"Please provide a region, select from {SUPPORTED_REGIONS}")
raise ValueError(
f"Please provide a region, select from {constants.SUPPORTED_REGIONS}"
)

region = region.lower()
if region not in SUPPORTED_REGIONS:
if region not in constants.SUPPORTED_REGIONS:
raise ValueError(
f"Unsupported region for AI Platform, select from {SUPPORTED_REGIONS}"
f"Unsupported region for AI Platform, select from {constants.SUPPORTED_REGIONS}"
)

return True
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
install_requires=(
"google-api-core[grpc] >= 1.22.2, < 2.0.0dev",
"google-cloud-storage >= 1.32.0, < 2.0.0dev",
"google-cloud-bigquery >= 1.15.0, < 3.0.0dev",
"libcst >= 0.2.5",
"proto-plus >= 1.4.0",
),
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/aiplatform/test_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import google.auth
from google.auth import credentials
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform import constants
from google.cloud.aiplatform_v1beta1.services.model_service.client import (
ModelServiceClient,
)
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_init_location_sets_location(self):
assert initializer.global_config.location == _TEST_LOCATION

def test_not_init_location_gets_default_location(self):
assert initializer.global_config.location == utils.DEFAULT_REGION
assert initializer.global_config.location == constants.DEFAULT_REGION

def test_init_location_with_invalid_location_raises(self):
with pytest.raises(ValueError):
Expand Down Expand Up @@ -94,7 +94,7 @@ def test_create_client_returns_client(self):
client = initializer.global_config.create_client(ModelServiceClient)
assert isinstance(client, ModelServiceClient)
assert (
client._transport._host == f"{_TEST_LOCATION}-{utils.PROD_API_ENDPOINT}:443"
client._transport._host == f"{_TEST_LOCATION}-{constants.API_BASE_PATH}:443"
)

def test_create_client_overrides(self):
Expand All @@ -109,7 +109,7 @@ def test_create_client_overrides(self):
assert isinstance(client, ModelServiceClient)
assert (
client._transport._host
== f"{_TEST_LOCATION_2}-prediction-{utils.PROD_API_ENDPOINT}:443"
== f"{_TEST_LOCATION_2}-prediction-{constants.API_BASE_PATH}:443"
)
assert client._transport._credentials == creds

Expand Down
Loading

0 comments on commit f2ccd1e

Please sign in to comment.