Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support raw_predict for Endpoint #1620

Merged
merged 33 commits into from
Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
b42a47f
feat: support raw_predict for Endpoints
rosiezou Aug 30, 2022
9f850b2
formatting
rosiezou Aug 30, 2022
5e32858
Merge branch 'main' into raw-predict
rosiezou Aug 30, 2022
f6e7241
fixed broken unit test
rosiezou Aug 31, 2022
07796ae
Merge branch 'main' into raw-predict
rosiezou Aug 31, 2022
5df7af8
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Aug 31, 2022
3ed3e06
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Aug 31, 2022
516771e
Merge branch 'raw-predict' of https://github.com/googleapis/python-ai…
gcf-owl-bot[bot] Aug 31, 2022
4586da2
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Aug 31, 2022
f5c7dea
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Aug 31, 2022
007595f
Merge branch 'raw-predict' of https://github.com/googleapis/python-ai…
gcf-owl-bot[bot] Aug 31, 2022
ffd5a75
Merge branch 'main' into raw-predict
rosiezou Aug 31, 2022
4dd2472
remove commented out code blocks
rosiezou Aug 31, 2022
3818bf6
removing debug print statements
rosiezou Aug 31, 2022
ca30449
removing extra prints
rosiezou Aug 31, 2022
07cc2f0
update copyright header date
rosiezou Aug 31, 2022
6d6e173
Merge branch 'main' into raw-predict
rosiezou Aug 31, 2022
f53ee8a
removing automatically added python 3.6 support for kokoro
rosiezou Aug 31, 2022
267e803
Merge branch 'main' into raw-predict
nayaknishant Aug 31, 2022
474654c
Merge branch 'main' into raw-predict
rosiezou Sep 1, 2022
f2386c3
Merge branch 'main' into raw-predict
rosiezou Sep 1, 2022
60ea5f9
addressed PR comments
rosiezou Sep 1, 2022
ee319fd
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 1, 2022
eef9e6f
adding unit test
rosiezou Sep 2, 2022
19c853e
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 2, 2022
bfac52c
removed unused import
rosiezou Sep 2, 2022
d6e8ef7
Merge branch 'main' into raw-predict
rosiezou Sep 2, 2022
2175a2c
renamed raw predict constants
rosiezou Sep 2, 2022
748b0ed
modified error messages
rosiezou Sep 2, 2022
bc80e81
added doc strings
rosiezou Sep 2, 2022
f0f1e7a
fixed typo in doc strings
rosiezou Sep 2, 2022
be8d4de
removed extra space
rosiezou Sep 2, 2022
1264f33
Merge branch 'main' into raw-predict
rosiezou Sep 2, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _are_futures_done(self) -> bool:
return self.__latest_future is None

def wait(self):
"""Helper method to that blocks until all futures are complete."""
"""Helper method that blocks until all futures are complete."""
future = self.__latest_future
if future:
futures.wait([future], return_when=futures.FIRST_EXCEPTION)
Expand Down Expand Up @@ -974,7 +974,11 @@ def _sync_object_with_future_result(
"_gca_resource",
"credentials",
]
optional_sync_attributes = ["_prediction_client"]
optional_sync_attributes = [
"_prediction_client",
"_authorized_session",
"_raw_predict_request_url",
Comment on lines +979 to +980
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These attribute names start with underscores, but the attributes in models.py do not have leading underscores. Is this correct?

]

for attribute in sync_attributes:
setattr(self, attribute, getattr(result, attribute))
Expand Down
3 changes: 3 additions & 0 deletions google/cloud/aiplatform/constants/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@
# that is being used for usage metrics tracking purposes.
# For more details on go/oneplatform-api-analytics
USER_AGENT_SDK_COMMAND = ""

# Needed for Endpoint.raw_predict
DEFAULT_AUTHED_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
dizcology marked this conversation as resolved.
Show resolved Hide resolved
106 changes: 90 additions & 16 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import re
import shutil
import tempfile
import requests
from typing import (
Any,
Dict,
Expand All @@ -35,9 +36,11 @@
from google.api_core import operation
from google.api_core import exceptions as api_exceptions
from google.auth import credentials as auth_credentials
from google.auth.transport import requests as google_auth_requests

from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import constants
from google.cloud.aiplatform import explain
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import jobs
Expand Down Expand Up @@ -69,6 +72,8 @@
_DEFAULT_MACHINE_TYPE = "n1-standard-2"
_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
_SUCCESSFUL_HTTP_RESPONSE = 300
_RAW_PREDICT_DEPLOYED_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"
_RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model"

_LOGGER = base.Logger(__name__)

Expand Down Expand Up @@ -200,6 +205,8 @@ def __init__(
location=self.location,
credentials=credentials,
)
self.authorized_session = None
self.raw_predict_request_url = None
Comment on lines +208 to +209
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these attributes be public?


def _skipped_getter_call(self) -> bool:
"""Check if GAPIC resource was populated by call to get/list API methods
Expand Down Expand Up @@ -1481,6 +1488,7 @@ def predict(
instances: List,
parameters: Optional[Dict] = None,
timeout: Optional[float] = None,
use_raw_predict: Optional[bool] = False,
) -> Prediction:
"""Make a prediction against this Endpoint.

Expand All @@ -1505,29 +1513,71 @@ def predict(
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
``parameters_schema_uri``.
timeout (float): Optional. The timeout for this request in seconds.
use_raw_predict (bool):
Optional. Default value is False. If set to True, the underlying prediction call will be made
against Endpoint.raw_predict(). Note that model version information will
not be available in the prediction response using raw_predict.

Returns:
prediction (aiplatform.Prediction):
Prediction with returned predictions and Model ID.
"""
self.wait()
if use_raw_predict:
raw_predict_response = self.raw_predict(
body=json.dumps({"instances": instances, "parameters": parameters}),
headers={"Content-Type": "application/json"},
)
json_response = json.loads(raw_predict_response.text)
return Prediction(
predictions=json_response["predictions"],
deployed_model_id=raw_predict_response.headers[
_RAW_PREDICT_DEPLOYED_MODEL_ID_KEY
],
model_resource_name=raw_predict_response.headers[
_RAW_PREDICT_MODEL_RESOURCE_KEY
],
)
else:
prediction_response = self._prediction_client.predict(
endpoint=self._gca_resource.name,
instances=instances,
parameters=parameters,
timeout=timeout,
)

prediction_response = self._prediction_client.predict(
endpoint=self._gca_resource.name,
instances=instances,
parameters=parameters,
timeout=timeout,
)
return Prediction(
predictions=[
json_format.MessageToDict(item)
for item in prediction_response.predictions.pb
],
deployed_model_id=prediction_response.deployed_model_id,
model_version_id=prediction_response.model_version_id,
model_resource_name=prediction_response.model,
)

return Prediction(
predictions=[
json_format.MessageToDict(item)
for item in prediction_response.predictions.pb
],
deployed_model_id=prediction_response.deployed_model_id,
model_version_id=prediction_response.model_version_id,
model_resource_name=prediction_response.model,
)
def raw_predict(
    self, body: bytes, headers: Dict[str, str]
) -> requests.models.Response:
    """Makes a prediction request using arbitrary headers.

    On the first call an authorized session and the rawPredict URL for this
    Endpoint are created and stored on the instance; later calls reuse them.

    Args:
        body (bytes):
            The body of the prediction request in bytes. This must not exceed 1.5 MB per request.
        headers (Dict[str, str]):
            The headers of the request as a dictionary. There are no restrictions on the headers.

    Returns:
        A requests.models.Response object containing the status code and prediction results.
    """
    if not self.authorized_session:
        # NOTE(review): this overwrites a private attribute on the credentials
        # object held by this Endpoint; confirm that mutating it in place is
        # acceptable for credentials shared with other clients.
        self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
        self.authorized_session = google_auth_requests.AuthorizedSession(
            self.credentials
        )
        self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"

    # Pass everything by keyword: Session.post() is declared as
    # post(url, data=None, json=None, **kwargs), so a third positional argument
    # would be bound to ``json`` and the caller's headers would never be sent.
    return self.authorized_session.post(
        url=self.raw_predict_request_url,
        data=body,
        headers=headers,
    )

def explain(
self,
Expand Down Expand Up @@ -2004,7 +2054,7 @@ def _http_request(
def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
"""Make a prediction against this PrivateEndpoint using a HTTP request.
This method must be called within the network the PrivateEndpoint is peered to.
The predict() call will fail otherwise. To check, use `PrivateEndpoint.network`.
Otherwise, the predict() call will fail with error code 404. To check, use `PrivateEndpoint.network`.

Example usage:
response = my_private_endpoint.predict(instances=[...])
Expand Down Expand Up @@ -2062,6 +2112,30 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
deployed_model_id=self._gca_resource.deployed_models[0].id,
)

def raw_predict(
    self, body: bytes, headers: Dict[str, str]
) -> requests.models.Response:
    """Send a raw prediction request, with caller-supplied headers, to this PrivateEndpoint.

    This method must be called within the network the PrivateEndpoint is peered to.
    Otherwise, the raw_predict() call will fail with error code 404. To check, use
    `PrivateEndpoint.network`.

    Args:
        body (bytes):
            The body of the prediction request in bytes. This must not exceed 1.5 MB per request.
        headers (Dict[str, str]):
            The headers of the request as a dictionary. There are no restrictions on the headers.

    Returns:
        The response produced by the underlying HTTP request, containing the
        status code and prediction results.
    """
    # Block until any pending asynchronous work on this resource has finished.
    self.wait()

    response = self._http_request(
        method="POST",
        url=self.predict_http_uri,
        body=body,
        headers=headers,
    )
    return response

def explain(self):
raise NotImplementedError(
f"{self.__class__.__name__} class does not support 'explain' as of now."
Expand Down
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@
"uvicorn >= 0.16.0",
]

private_endpoints_extra_require = [
"urllib3 >=1.21.1, <1.27",
]
endpoint_extra_require = ["requests >= 2.28.1"]

private_endpoints_extra_require = ["urllib3 >=1.21.1, <1.27", "requests >= 2.28.1"]
full_extra_require = list(
set(
tensorboard_extra_require
Expand All @@ -92,6 +92,7 @@
+ featurestore_extra_require
+ pipelines_extra_require
+ datasets_extra_require
+ endpoint_extra_require
+ vizier_extra_require
+ prediction_extra_require
+ private_endpoints_extra_require
Expand Down Expand Up @@ -136,6 +137,7 @@
"google-cloud-resource-manager >= 1.3.3, < 3.0.0dev",
),
extras_require={
"endpoint": endpoint_extra_require,
"full": full_extra_require,
"metadata": metadata_extra_require,
"tensorboard": tensorboard_extra_require,
Expand Down
61 changes: 61 additions & 0 deletions tests/system/aiplatform/test_model_interactions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json

from google.cloud import aiplatform

from tests.system.aiplatform import e2e_base

# ID of the already-deployed Endpoint that TestModelInteractions predicts against.
_PERMANENT_IRIS_ENDPOINT_ID = "4966625964059525120"
# A single prediction instance; every feature value is passed as a string.
_PREDICTION_INSTANCE = {
    "petal_length": "3.0",
    "petal_width": "3.0",
    "sepal_length": "3.0",
    "sepal_width": "3.0",
}


class TestModelInteractions(e2e_base.TestEndToEnd):
    """System tests for online prediction against the endpoint identified by
    _PERMANENT_IRIS_ENDPOINT_ID.

    Covers predict(), predict(use_raw_predict=True) and raw_predict().
    """

    _temp_prefix = ""
    # NOTE(review): this Endpoint is constructed when the class body executes,
    # i.e. at import/collection time -- confirm the SDK is initialized by then.
    endpoint = aiplatform.Endpoint(_PERMANENT_IRIS_ENDPOINT_ID)

    def test_prediction(self):
        """Check that the three prediction entry points agree for one instance."""
        # test basic predict
        prediction_response = self.endpoint.predict(instances=[_PREDICTION_INSTANCE])
        assert len(prediction_response.predictions) == 1

        # test predict(use_raw_predict = True)
        prediction_with_raw_predict = self.endpoint.predict(
            instances=[_PREDICTION_INSTANCE], use_raw_predict=True
        )
        assert (
            prediction_with_raw_predict.deployed_model_id
            == prediction_response.deployed_model_id
        )
        assert (
            prediction_with_raw_predict.model_resource_name
            == prediction_response.model_resource_name
        )

        # test raw_predict
        raw_prediction_response = self.endpoint.raw_predict(
            json.dumps({"instances": [_PREDICTION_INSTANCE]}),
            {"Content-Type": "application/json"},
        )
        assert raw_prediction_response.status_code == 200
        # Count the returned predictions rather than the number of top-level
        # keys in the decoded JSON object.
        raw_predictions = json.loads(raw_prediction_response.text)["predictions"]
        assert len(raw_predictions) == 1
5 changes: 2 additions & 3 deletions tests/system/aiplatform/test_model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


@pytest.mark.usefixtures("delete_staging_bucket", "tear_down_resources")
class TestModel(e2e_base.TestEndToEnd):
class TestModelUploadAndUpdate(e2e_base.TestEndToEnd):

_temp_prefix = "temp_vertex_sdk_e2e_model_upload_test"

Expand Down Expand Up @@ -65,9 +65,8 @@ def test_upload_and_deploy_xgboost_model(self, shared_state):
# See https://github.com/googleapis/python-aiplatform/issues/773
endpoint = model.deploy(machine_type="n1-standard-2")
shared_state["resources"].append(endpoint)
predict_response = endpoint.predict(instances=[[0, 0, 0]])
assert len(predict_response.predictions) == 1

# test model update
model = model.update(
display_name="new_name",
description="new_description",
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from concurrent import futures
import pathlib
import pytest
import requests
from unittest import mock
from unittest.mock import patch

Expand All @@ -31,6 +32,7 @@
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils
from google.cloud.aiplatform import constants

from google.cloud.aiplatform.compat.services import (
endpoint_service_client,
Expand Down Expand Up @@ -309,6 +311,10 @@

_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}"

# URL that test_raw_predict expects Endpoint.raw_predict to POST to for _TEST_ID.
_TEST_RAW_PREDICT_URL = f"https://{_TEST_LOCATION}-{constants.base.API_BASE_PATH}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:rawPredict"
# Request body and headers handed to Endpoint.raw_predict in test_raw_predict.
_TEST_RAW_PREDICT_DATA = b""
_TEST_RAW_PREDICT_HEADER = {"Content-Type": "application/json"}


@pytest.fixture
def mock_model():
Expand All @@ -329,6 +335,22 @@ def update_model_mock(mock_model):
yield mock


@pytest.fixture
def authorized_session_mock():
    """Patch google-auth's AuthorizedSession and yield the mocked session instance."""
    session_patcher = patch("google.auth.transport.requests.AuthorizedSession")
    with session_patcher as session_cls_mock:
        # Instantiating the patched class returns its shared ``return_value``,
        # which is the same object code under test gets when it builds a session.
        yield session_cls_mock(_TEST_CREDENTIALS)


@pytest.fixture
def raw_predict_mock(authorized_session_mock):
    """Patch ``post`` on the mocked session so it returns an empty Response; yield the mock."""
    with patch.object(
        authorized_session_mock,
        "post",
        return_value=requests.models.Response(),
    ) as post_mock:
        yield post_mock


@pytest.fixture
def get_endpoint_mock():
with mock.patch.object(
Expand Down Expand Up @@ -2707,3 +2729,16 @@ def test_list(self, list_models_mock):

assert listed_model.versioning_registry
assert listed_model._revisioned_resource_id_validator

@pytest.mark.usefixtures(
    "get_endpoint_mock",
    "get_model_mock",
    "create_endpoint_mock",
    "raw_predict_mock",
)
def test_raw_predict(self, raw_predict_mock):
    """raw_predict forwards the URL, body and headers to the session's post() once."""
    test_endpoint = models.Endpoint(_TEST_ID)
    test_endpoint.raw_predict(_TEST_RAW_PREDICT_DATA, _TEST_RAW_PREDICT_HEADER)

    raw_predict_mock.assert_called_once()
    # Compare the forwarded values in order without pinning whether they were
    # passed positionally or by keyword.
    args, kwargs = raw_predict_mock.call_args
    assert [*args, *kwargs.values()] == [
        _TEST_RAW_PREDICT_URL,
        _TEST_RAW_PREDICT_DATA,
        _TEST_RAW_PREDICT_HEADER,
    ]