feat: support raw_predict for Endpoint #1620

Merged: 33 commits from raw-predict into main on Sep 2, 2022
Conversation

rosiezou (Contributor) opened this pull request:

Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:

  • Make sure to open an issue as a bug/issue before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
  • Ensure the tests and linter pass
  • Code coverage does not decrease (if any source code was changed)
  • Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕

@rosiezou requested review from a team as code owners on August 30, 2022 01:42
@product-auto-label bot added labels size: m (Pull request size is medium.) and api: vertex-ai (Issues related to the googleapis/python-aiplatform API.) on Aug 30, 2022
@rosiezou self-assigned this on Aug 30, 2022
@gcf-owl-bot bot requested review from a team as code owners on August 31, 2022 05:05
@gcf-owl-bot bot requested a review from dandhlee on August 31, 2022 05:05
@product-auto-label bot added label size: l (Pull request size is large.) and removed size: m (Pull request size is medium.) on Aug 31, 2022
@product-auto-label bot added label size: m (Pull request size is medium.) and removed size: l (Pull request size is large.) on Aug 31, 2022

@nayaknishant (Contributor) left a comment:

Is test_model_interactions.py testing permanent models? If so, I think we should keep it as two separate files (as you have it). But if it's using temp models created in test_model_upload.py, wouldn't it make more sense to have a predict() call followed by a raw_predict() call in the same file (and maybe rename it to something broader like test_model.py)?

return Prediction(
predictions=response_text["predictions"],
deployed_model_id=raw_predict_response.headers[
"X-Vertex-AI-Deployed-Model-Id"

Contributor:
Interesting. @rosiezou, can you clarify? I'm guessing X-Vertex-AI-Deployed-Model-Id is populated when the response is returned?

rosiezou (Author) replied:
I also added constants (_RAW_PREDICT_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id" and _RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model") to replace the hardcoded strings.
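
For illustration, a minimal sketch of how those constants might be used in place of the hardcoded header strings (the helper function and surrounding lines are assumptions, not quotes from the merged diff):

    # Constant names as given in this thread (one of them is renamed later; see the
    # constants/base.py discussion further down).
    _RAW_PREDICT_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"
    _RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model"

    def _deployed_model_info(raw_predict_response):
        # Assumption: both values are read from the HTTP response headers returned by
        # the raw predict call, rather than from the response body.
        return (
            raw_predict_response.headers[_RAW_PREDICT_MODEL_ID_KEY],
            raw_predict_response.headers[_RAW_PREDICT_MODEL_RESOURCE_KEY],
        )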

@nayaknishant (Contributor) left a comment:
I see that system tests have been added, but can we also add unit tests to https://github.com/googleapis/python-aiplatform/blob/main/tests/unit/aiplatform/test_models.py for raw_predict(), to ensure the function works as intended during future development?
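
A heavily hedged sketch of what such a unit test might look like. It assumes an endpoint fixture like the ones already used in test_models.py, and that raw_predict() sends its request through the authorized_session attribute discussed later in this review; none of these internals are confirmed here:

    from unittest import mock

    def test_raw_predict_returns_session_response(endpoint):  # `endpoint` is an assumed fixture
        fake_response = mock.Mock(status_code=200, text='{"predictions": [[0.1]]}')
        # Assumption: raw_predict() posts through the endpoint's authorized session.
        endpoint.authorized_session = mock.Mock()
        endpoint.authorized_session.post.return_value = fake_response

        response = endpoint.raw_predict(
            body=b'{"instances": [[1.0, 2.0, 3.0, 4.0]]}',
            headers={"Content-Type": "application/json"},
        )

        assert response is fake_response
        endpoint.authorized_session.post.assert_called_once()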

google/cloud/aiplatform/constants/base.py (thread resolved)
@@ -35,9 +36,11 @@
from google.api_core import operation
from google.api_core import exceptions as api_exceptions
from google.auth import credentials as auth_credentials
from google.auth.transport.requests import AuthorizedSession

Contributor:
Import the module instead of the class. Because of the import name conflict, maybe something like from google.auth.transport import requests as google_auth_requests.

rosiezou (Author) replied:
Updated the import statement.
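
A small sketch of the suggested change (the commented usage line is an assumption about how the session would then be constructed, not a quote from the diff):

    # Before: importing the class directly (this is the line in the diff above).
    # from google.auth.transport.requests import AuthorizedSession

    # After: import the module under an alias to avoid the name conflict with `requests`.
    from google.auth.transport import requests as google_auth_requests

    # Assumed usage later in the module:
    # authorized_session = google_auth_requests.AuthorizedSession(credentials)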

@@ -1505,29 +1511,67 @@ def predict(
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
``parameters_schema_uri``.
timeout (float): Optional. The timeout for this request in seconds.
use_raw_predict (bool):
Optional. If set to True, the underlying prediction call will be made
against Endpoint.raw_predict(). Currently, model version information will

Contributor:
  1. Maybe mention that the default is False in the docstring.
  2. "Currently..." <- does this mean that later it will become available?

rosiezou (Author) replied:
I will remove "currently". I have checked with the service team and they mentioned that it's not on the roadmap to make model version info available for raw_predict due to performance reasons.

rosiezou (Author) replied:
Updated the docstrings.
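
For context, a hedged usage sketch of the new flag (the endpoint resource name and instance values are placeholders):

    from google.cloud import aiplatform

    endpoint = aiplatform.Endpoint(
        "projects/my-project/locations/us-central1/endpoints/1234567890"  # placeholder
    )
    # use_raw_predict defaults to False; True routes the call through Endpoint.raw_predict(),
    # which returns predictions and the deployed model ID but no model version information.
    prediction = endpoint.predict(instances=[[1.0, 2.0, 3.0, 4.0]], use_raw_predict=True)
    print(prediction.deployed_model_id)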

body=json.dumps({"instances": instances, "parameters": parameters}),
headers={"Content-Type": "application/json"},
)
response_text = json.loads(raw_predict_response.text)

Contributor:
After calling json.loads, this is no longer text. Please rename the variable response_text to something else.

rosiezou (Author) replied:
Updated the variable name.
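
A minimal sketch of the rename (the new name is an assumption; any descriptive name would do):

    import json

    raw_predict_response_text = '{"predictions": [[0.1]]}'  # stand-in for raw_predict_response.text

    # Before: response_text = json.loads(raw_predict_response_text)
    # After: the parsed value is a dict, not text, so name it accordingly.
    prediction_response = json.loads(raw_predict_response_text)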

body (bytes):
Required. The body of the prediction request in bytes. This must not exceed 1.5 mb per request.
headers (Dict[str, str]):
Required. The header of the request as a dictionary. There are no restrictions on the header.

Contributor:
There is a default value so this is not actually required, contrary to what the docstring says.

On the other hand, the body argument probably should not have a default?

rosiezou (Author) replied:
You're right, neither of these args should have any defaults. I've updated the function signatures and docstrings.
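
A hedged sketch of the corrected signature, with the return annotation taken from the diff below (the exact final form is not quoted on this page):

    from typing import Dict

    import requests

    class _EndpointSignatureSketch:
        def raw_predict(self, body: bytes, headers: Dict[str, str]) -> requests.models.Response:
            """Both arguments are required; neither has a default value."""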

model_resource_name=prediction_response.model,
)
def raw_predict(
self, body: bytes = None, headers: Dict[str, str] = None

Contributor:
If we are going to allow None, the typing annotation should be Optional[Dict[str, str]].

@@ -2062,6 +2106,30 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
deployed_model_id=self._gca_resource.deployed_models[0].id,
)

def raw_predict(
self, body: bytes = None, headers: Dict[str, str] = None

Contributor:
Same comments about the arguments' default values here.

) -> requests.models.Response:
"""Make a prediction request using arbitrary headers.
This method must be called within the network the PrivateEndpoint is peered to.
The function call will fail otherwise. To check, use `PrivateEndpoint.network`.

Contributor:
Do we know the specific error that would be raised? If so, perhaps add a Raises: section to the docstring: https://google.github.io/styleguide/pyguide.html#doc-function-raises

rosiezou (Author) replied:
I'll test out a few scenarios, but the most common one will be auth errors raised from google.api_core.exceptions with error code 401.

Contributor:
I meant specifically the part “the function call will fail” in the docstring.

rosiezou (Author) replied:
Got it. I updated the docstrings in PrivateEndpoint.predict and PrivateEndpoint.raw_predict to include information about the error code. The most common error code will be 404 with a message saying "request not found".
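
A sketch of what such a Raises: section could look like in the Google docstring style linked above. The exception class and exact wording are assumptions based on this thread, not the merged docstring:

    from typing import Dict

    class _PrivateEndpointSketch:
        def raw_predict(self, body: bytes, headers: Dict[str, str]):
            """Makes a prediction request using arbitrary headers.

            This method must be called within the network the PrivateEndpoint is
            peered to; the call will fail otherwise. To check, use
            `PrivateEndpoint.network`.

            Raises:
                google.api_core.exceptions.NotFound: If the request is made from
                    outside the peered network, the service typically responds
                    with a 404 "request not found" error.
            """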

{"Content-Type": "application/json"},
)
assert raw_prediction_response.status_code == 200
assert len(json.loads(raw_prediction_response.text).items()) == 1

Contributor:
.items() is not needed if we are only checking the dictionary's length.

rosiezou (Author) replied:
Updated.
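
The simplified assertion would look something like this (a sketch; the response text here is a stand-in for the value produced in the system test):

    import json

    response_text = '{"predictions": [[0.1]]}'  # stand-in for raw_prediction_response.text

    # len() works on the parsed dict directly, so .items() is unnecessary.
    assert len(json.loads(response_text)) == 1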

@@ -69,6 +72,8 @@
_DEFAULT_MACHINE_TYPE = "n1-standard-2"
_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
_SUCCESSFUL_HTTP_RESPONSE = 300
_RAW_PREDICT_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"

Contributor:
Probably call this "deployed model id" instead of "model id" (which sometimes refers to the last part of the full resource name).

rosiezou (Author) replied:
Updated the variable name.
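
The rename would look something like this (the new constant name is an assumption based on the suggestion above, not a quote from the merged code):

    # Before (as added in the diff above):
    # _RAW_PREDICT_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"

    # After (assumed final name): "deployed model id" avoids confusion with the model resource's own ID.
    _RAW_PREDICT_DEPLOYED_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"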

@product-auto-label bot added label size: l (Pull request size is large.) and removed size: m (Pull request size is medium.) on Sep 2, 2022
@rosiezou merged commit cc7c968 into main on Sep 2, 2022
@rosiezou deleted the raw-predict branch on September 2, 2022 22:34
Comment on lines +208 to +209
self.authorized_session = None
self.raw_predict_request_url = None

Contributor:
Should these attributes be public?

Comment on lines +979 to +980
"_authorized_session",
"_raw_predict_request_url",

Contributor:
These attribute names start with underscores, but the attributes in models.py do not have leading underscores. Is this correct?
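
If the intent is for these attributes to be private, a consistent sketch would use the underscore-prefixed names in models.py as well (an assumption about the eventual fix, not something resolved in this thread):

    class _EndpointAttrSketch:
        def __init__(self):
            # Lazily-initialized session and request URL used by raw_predict(),
            # underscore-prefixed so both files use the same private naming.
            self._authorized_session = None
            self._raw_predict_request_url = None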

Labels: api: vertex-ai (Issues related to the googleapis/python-aiplatform API.), size: l (Pull request size is large.)

4 participants