feat: support raw_predict for Endpoint #1620
Is `test_model_interactions.py` testing permanent models? If it is, I think we should keep it as two separate files (as you have it), but if it's using temp models created in `test_model_upload.py`, wouldn't it make more sense to have a `predict()` call and then a `raw_predict()` call after it in the same file (maybe changing the name to the broader `test_model.py`)?
google/cloud/aiplatform/models.py
Outdated
return Prediction(
    predictions=response_text["predictions"],
    deployed_model_id=raw_predict_response.headers[
        "X-Vertex-AI-Deployed-Model-Id"
Interesting, if you can clarify @rosiezou, I'm guessing `X-Vertex-AI-Deployed-Model-Id` is populated when the response is returned?
I also added constants (`_RAW_PREDICT_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"` and `_RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model"`) to replace the hardcoded strings.
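For readers following along, here is a minimal sketch of how the two constants might be used to read model information from the raw response headers. `FakeResponse` and `extract_model_info` are hypothetical names made up for this illustration and are not part of the SDK. A real `requests` response uses case-insensitive header lookup, which this plain-dict stand-in does not model.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple

# Header keys quoted in the comment above.
_RAW_PREDICT_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"
_RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model"


@dataclass
class FakeResponse:
    """Hypothetical stand-in for an HTTP response object with headers."""

    headers: Dict[str, str] = field(default_factory=dict)


def extract_model_info(response) -> Tuple[Optional[str], Optional[str]]:
    """Return (deployed_model_id, model_resource_name) read from the headers.

    Either value is None when the corresponding header is absent.
    """
    return (
        response.headers.get(_RAW_PREDICT_MODEL_ID_KEY),
        response.headers.get(_RAW_PREDICT_MODEL_RESOURCE_KEY),
    )
```

Keeping the header names in module-level constants means a typo surfaces as a `NameError` at import or call time instead of a silent missing-key lookup.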
I see that system tests have been added, but can we also add unit tests for `raw_predict()` to https://github.com/googleapis/python-aiplatform/blob/main/tests/unit/aiplatform/test_models.py to ensure the function keeps working as intended during future development?
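A rough sketch of what such a unit test could look like, using `unittest.mock` so that no network call is made. The `raw_predict` function below is a hypothetical stand-in that simply posts the body through a session object; the real method's internals and the URL are assumptions for illustration only.

```python
import json
from unittest import mock


def raw_predict(session, url, body, headers):
    """Hypothetical stand-in: POST the raw request body via the session."""
    return session.post(url, data=body, headers=headers)


def test_raw_predict_posts_body_and_headers():
    # Mock session whose post() returns a canned successful response.
    session = mock.Mock()
    session.post.return_value = mock.Mock(
        status_code=200,
        text=json.dumps({"predictions": [[1.0]]}),
    )
    url = "https://example.invalid/rawPredict"  # placeholder URL
    body = json.dumps({"instances": [[1, 2]]}).encode("utf-8")
    headers = {"Content-Type": "application/json"}

    response = raw_predict(session, url, body, headers)

    # The body and headers must be forwarded unchanged.
    session.post.assert_called_once_with(url, data=body, headers=headers)
    assert response.status_code == 200
    assert json.loads(response.text) == {"predictions": [[1.0]]}
```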
google/cloud/aiplatform/models.py
Outdated
@@ -35,9 +36,11 @@
from google.api_core import operation
from google.api_core import exceptions as api_exceptions
from google.auth import credentials as auth_credentials
from google.auth.transport.requests import AuthorizedSession
Import the module instead of the class. Because of the import name conflict, maybe use something like `from google.auth.transport import requests as google_auth_requests`.
updated import statement
google/cloud/aiplatform/models.py
Outdated
@@ -1505,29 +1511,67 @@ def predict(
    [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
    ``parameters_schema_uri``.
timeout (float): Optional. The timeout for this request in seconds.
use_raw_predict (bool):
    Optional. If set to True, the underlying prediction call will be made
    against Endpoint.raw_predict(). Currently, model version information will
- Maybe mention in the docstring that the default is `False`.
- "Currently..." <- does this mean that it will become available later?
I will remove "currently". I have checked with the service team and they mentioned that it's not on the roadmap to make model version info available for raw_predict due to performance reasons.
updated doc strings
google/cloud/aiplatform/models.py
Outdated
    body=json.dumps({"instances": instances, "parameters": parameters}),
    headers={"Content-Type": "application/json"},
)
response_text = json.loads(raw_predict_response.text)
After calling `json.loads`, this is no longer text. Please rename the variable `response_text` to something else.
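To illustrate the point, a tiny sketch showing that the parsed value is a dictionary and not a string. The name `prediction_response` is only a hypothetical replacement; the thread does not say which name was chosen.

```python
import json

# The raw HTTP body is text...
response_text = '{"predictions": [[0.1, 0.9]]}'

# ...but the parsed result is a dict, so a "_text" suffix would be misleading.
prediction_response = json.loads(response_text)

assert isinstance(response_text, str)
assert isinstance(prediction_response, dict)
```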
updated variable name
google/cloud/aiplatform/models.py
Outdated
body (bytes):
    Required. The body of the prediction request in bytes. This must not exceed 1.5 mb per request.
headers (Dict[str, str]):
    Required. The header of the request as a dictionary. There are no restrictions on the header.
There is a default value, so this is not actually required, contrary to what the docstring says.
On the other hand, the `body` argument probably should not have a default?
You're right, neither of these args should have any defaults. I've updated the function signatures and doc strings.
google/cloud/aiplatform/models.py
Outdated
    model_resource_name=prediction_response.model,
)
def raw_predict(
    self, body: bytes = None, headers: Dict[str, str] = None
If we are going to allow `None`, the typing annotation should be `Optional[Dict[str, str]]`.
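A small sketch of the two options under discussion, written as standalone functions with hypothetical names (`raw_predict_optional`, `raw_predict_required`). Neither is the SDK's actual signature; they only contrast annotating a `None` default with `Optional[...]` against making the arguments required.

```python
from typing import Dict, Optional, Tuple


def raw_predict_optional(
    body: Optional[bytes] = None,
    headers: Optional[Dict[str, str]] = None,
) -> Tuple[Optional[bytes], Optional[Dict[str, str]]]:
    """Option A: None is allowed, so the annotations say Optional[...]."""
    return body, headers


def raw_predict_required(
    body: bytes,
    headers: Dict[str, str],
) -> Tuple[bytes, Dict[str, str]]:
    """Option B: no defaults, so both arguments really are required."""
    return body, headers
```

With option B, calling `raw_predict_required()` without arguments raises a `TypeError`, which matches a docstring that marks both parameters as "Required".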
google/cloud/aiplatform/models.py
Outdated
@@ -2062,6 +2106,30 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
    deployed_model_id=self._gca_resource.deployed_models[0].id,
)

def raw_predict(
    self, body: bytes = None, headers: Dict[str, str] = None
Same comments about the arguments' default values here.
google/cloud/aiplatform/models.py
Outdated
) -> requests.models.Response:
    """Make a prediction request using arbitrary headers.
    This method must be called within the network the PrivateEndpoint is peered to.
    The function call will fail otherwise. To check, use `PrivateEndpoint.network`.
Do we know the specific error that would be raised? If so, perhaps add a `Raises:` section to the docstring: https://google.github.io/styleguide/pyguide.html#doc-function-raises
I'll test out a few scenarios, but the most common one will be auth errors raised from `google.api_core.exceptions` with error code 401.
I meant specifically the part “the function call will fail” in the docstring.
Got it. I updated the docstrings in `PrivateEndpoint.predict` and `PrivateEndpoint.raw_predict` to contain information about the error code. The most common error code will be 404, with a message saying "request not found".
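As a sketch of the Google-style `Raises:` section suggested in this thread, here is a hypothetical helper (`check_reachable`) that documents the 404 case. The thread does not name the exception class the SDK actually raises, so `RuntimeError` is a placeholder assumption.

```python
def check_reachable(status_code: int) -> int:
    """Validate the status code of a private endpoint prediction response.

    Args:
        status_code (int):
            Required. HTTP status code returned by the prediction request.

    Returns:
        int: The unchanged status code when the endpoint was reachable.

    Raises:
        RuntimeError: If the status code is 404, which is the most common
            result of calling from outside the peered network.
            (Placeholder exception type, assumed for this sketch.)
    """
    if status_code == 404:
        raise RuntimeError("request not found (404)")
    return status_code
```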
    {"Content-Type": "application/json"},
)
assert raw_prediction_response.status_code == 200
assert len(json.loads(raw_prediction_response.text).items()) == 1
`.items()` is not needed if we are only checking the dictionary's length.
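A quick check of the equivalence, using a made-up response body for illustration: `len()` on a dict already counts its keys, so it gives the same number as `len()` on its items view.

```python
import json

# Made-up response body, standing in for raw_prediction_response.text.
payload = json.loads('{"predictions": [[0.1, 0.9]]}')

# Both expressions count the top-level keys of the dict.
assert len(payload) == len(payload.items())
assert len(payload) == 1
```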
updated
google/cloud/aiplatform/models.py
Outdated
@@ -69,6 +72,8 @@
_DEFAULT_MACHINE_TYPE = "n1-standard-2"
_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
_SUCCESSFUL_HTTP_RESPONSE = 300
_RAW_PREDICT_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"
Probably call this “deployed model id” instead of “model id” (which sometimes refers to the last part of the full resource name).
updated variable name
self.authorized_session = None
self.raw_predict_request_url = None
Should these attributes be public?
"_authorized_session",
"_raw_predict_request_url",
These attribute names start with underscores, but the attributes in `models.py` do not have leading underscores. Is this correct?
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
Fixes #<issue_number_goes_here> 🦕