-
Notifications
You must be signed in to change notification settings - Fork 348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support raw_predict for Endpoint #1620
Changes from 19 commits
b42a47f
9f850b2
5e32858
f6e7241
07796ae
5df7af8
3ed3e06
516771e
4586da2
f5c7dea
007595f
ffd5a75
4dd2472
3818bf6
ca30449
07cc2f0
6d6e173
f53ee8a
267e803
474654c
f2386c3
60ea5f9
ee319fd
eef9e6f
19c853e
bfac52c
d6e8ef7
2175a2c
748b0ed
bc80e81
f0f1e7a
be8d4de
1264f33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
import re | ||
import shutil | ||
import tempfile | ||
import requests | ||
from typing import ( | ||
Any, | ||
Dict, | ||
|
@@ -35,9 +36,11 @@ | |
from google.api_core import operation | ||
from google.api_core import exceptions as api_exceptions | ||
from google.auth import credentials as auth_credentials | ||
from google.auth.transport.requests import AuthorizedSession | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Import module instead of class. Because of the import name conflict maybe something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated import statement |
||
|
||
from google.cloud import aiplatform | ||
from google.cloud.aiplatform import base | ||
from google.cloud.aiplatform import constants | ||
from google.cloud.aiplatform import explain | ||
from google.cloud.aiplatform import initializer | ||
from google.cloud.aiplatform import jobs | ||
|
@@ -200,6 +203,8 @@ def __init__( | |
location=self.location, | ||
credentials=credentials, | ||
) | ||
self.authorized_session = None | ||
self.raw_predict_request_url = None | ||
Comment on lines
+208
to
+209
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should these attributes be public? |
||
|
||
def _skipped_getter_call(self) -> bool: | ||
"""Check if GAPIC resource was populated by call to get/list API methods | ||
|
@@ -1481,6 +1486,7 @@ def predict( | |
instances: List, | ||
parameters: Optional[Dict] = None, | ||
timeout: Optional[float] = None, | ||
use_raw_predict: Optional[bool] = False, | ||
) -> Prediction: | ||
"""Make a prediction against this Endpoint. | ||
|
||
|
@@ -1505,29 +1511,67 @@ def predict( | |
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] | ||
``parameters_schema_uri``. | ||
timeout (float): Optional. The timeout for this request in seconds. | ||
use_raw_predict (bool): | ||
Optional. If set to True, the underlying prediction call will be made | ||
against Endpoint.raw_predict(). Currently, model version information will | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will remove "currently". I have checked with the service team and they mentioned that it's not on the roadmap to make model version info available for raw_predict due to performance reasons. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated doc strings |
||
not be available in the prediction response using raw_predict. | ||
|
||
Returns: | ||
prediction (aiplatform.Prediction): | ||
Prediction with returned predictions and Model ID. | ||
""" | ||
self.wait() | ||
if use_raw_predict: | ||
raw_predict_response = self.raw_predict( | ||
body=json.dumps({"instances": instances, "parameters": parameters}), | ||
headers={"Content-Type": "application/json"}, | ||
) | ||
response_text = json.loads(raw_predict_response.text) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated variable name |
||
return Prediction( | ||
predictions=response_text["predictions"], | ||
deployed_model_id=raw_predict_response.headers[ | ||
"X-Vertex-AI-Deployed-Model-Id" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting, if you can clarify @rosiezou, I'm guessing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also added constants ( |
||
], | ||
model_resource_name=raw_predict_response.headers["X-Vertex-AI-Model"], | ||
) | ||
else: | ||
prediction_response = self._prediction_client.predict( | ||
endpoint=self._gca_resource.name, | ||
instances=instances, | ||
parameters=parameters, | ||
timeout=timeout, | ||
) | ||
|
||
prediction_response = self._prediction_client.predict( | ||
endpoint=self._gca_resource.name, | ||
instances=instances, | ||
parameters=parameters, | ||
timeout=timeout, | ||
) | ||
return Prediction( | ||
predictions=[ | ||
json_format.MessageToDict(item) | ||
for item in prediction_response.predictions.pb | ||
], | ||
deployed_model_id=prediction_response.deployed_model_id, | ||
model_version_id=prediction_response.model_version_id, | ||
model_resource_name=prediction_response.model, | ||
) | ||
|
||
return Prediction( | ||
predictions=[ | ||
json_format.MessageToDict(item) | ||
for item in prediction_response.predictions.pb | ||
], | ||
deployed_model_id=prediction_response.deployed_model_id, | ||
model_version_id=prediction_response.model_version_id, | ||
model_resource_name=prediction_response.model, | ||
) | ||
def raw_predict( | ||
self, body: bytes = None, headers: Dict[str, str] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we are going to allow |
||
) -> requests.models.Response: | ||
"""Makes a prediction request using arbitrary headers. | ||
|
||
Args: | ||
body (bytes): | ||
Required. The body of the prediction request in bytes. This must not exceed 1.5 MB per request. | ||
headers (Dict[str, str]): | ||
Required. The header of the request as a dictionary. There are no restrictions on the header. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a default value so this is not actually required, contrary to what the docstring says. On the other hand the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right, neither of these args should have any defaults. I've updated the function signatures and doc strings. |
||
|
||
Returns: | ||
A requests.models.Response object containing the status code and prediction results. | ||
""" | ||
if not self.authorized_session: | ||
self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES | ||
self.authorized_session = AuthorizedSession(self.credentials) | ||
self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict" | ||
|
||
return self.authorized_session.post(self.raw_predict_request_url, body, headers) | ||
|
||
def explain( | ||
self, | ||
|
@@ -2062,6 +2106,30 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict | |
deployed_model_id=self._gca_resource.deployed_models[0].id, | ||
) | ||
|
||
def raw_predict( | ||
self, body: bytes = None, headers: Dict[str, str] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comments about the arguments' default values here. |
||
) -> requests.models.Response: | ||
"""Make a prediction request using arbitrary headers. | ||
This method must be called within the network the PrivateEndpoint is peered to. | ||
The function call will fail otherwise. To check, use `PrivateEndpoint.network`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we know the specific error that would be raised? If so perhaps add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll test out a few scenarios, but most common one will be auth errors raised from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant specifically the part “the function call will fail” in the docstring. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it. I updated the doc strings in |
||
|
||
Args: | ||
body (bytes): | ||
Required. The body of the prediction request in bytes. This must not exceed 1.5 MB per request. | ||
headers (Dict[str, str]): | ||
Required. The header of the request as a dictionary. There are no restrictions on the header. | ||
|
||
Returns: | ||
A requests.models.Response object containing the status code and prediction results. | ||
""" | ||
self.wait() | ||
return self._http_request( | ||
method="POST", | ||
url=self.predict_http_uri, | ||
body=body, | ||
headers=headers, | ||
) | ||
|
||
def explain(self): | ||
raise NotImplementedError( | ||
f"{self.__class__.__name__} class does not support 'explain' as of now." | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# Copyright 2022 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import json | ||
|
||
from google.cloud import aiplatform | ||
|
||
from tests.system.aiplatform import e2e_base | ||
|
||
_PERMANENT_IRIS_ENDPOINT_ID = "4966625964059525120" | ||
_PREDICTION_INSTANCE = { | ||
"petal_length": "3.0", | ||
"petal_width": "3.0", | ||
"sepal_length": "3.0", | ||
"sepal_width": "3.0", | ||
} | ||
|
||
|
||
class TestModelInteractions(e2e_base.TestEndToEnd): | ||
_temp_prefix = "" | ||
endpoint = aiplatform.Endpoint(_PERMANENT_IRIS_ENDPOINT_ID) | ||
|
||
def test_prediction(self): | ||
# test basic predict | ||
prediction_response = self.endpoint.predict(instances=[_PREDICTION_INSTANCE]) | ||
assert len(prediction_response.predictions) == 1 | ||
|
||
# test predict(use_raw_predict = True) | ||
prediction_with_raw_predict = self.endpoint.predict( | ||
instances=[_PREDICTION_INSTANCE], use_raw_predict=True | ||
) | ||
assert ( | ||
prediction_with_raw_predict.deployed_model_id | ||
== prediction_response.deployed_model_id | ||
) | ||
assert ( | ||
prediction_with_raw_predict.model_resource_name | ||
== prediction_response.model_resource_name | ||
) | ||
|
||
# test raw_predict | ||
raw_prediction_response = self.endpoint.raw_predict( | ||
json.dumps({"instances": [_PREDICTION_INSTANCE]}), | ||
{"Content-Type": "application/json"}, | ||
) | ||
assert raw_prediction_response.status_code == 200 | ||
assert len(json.loads(raw_prediction_response.text).items()) == 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These attribute names start with underscores, but the attributes in
models.py
do not have leading underscores. Is this correct?