-
Notifications
You must be signed in to change notification settings - Fork 353
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs(samples): add AutoML image classification sample (#923)
* Create predict_image_classification_sample.py * feat: new sample and test * lint: fix wsp * lint: import order * lint: fix import * tags: fixed start tag * samples: change tabular to image in sample function name. * samples: replace TF version of reading in binary file with Python version * samples: delete tf import, move other imports within region tags * Update predict_image_classification_sample.py * samples: move imports for lint * Update predict_image_classification_sample.py * Update predict_image_classification_sample_test.py Co-authored-by: Karl Weinmeister <[email protected]> Co-authored-by: Rosie Zou <[email protected]> Co-authored-by: nayaknishant <[email protected]>
- Loading branch information
1 parent
406ed84
commit 677b311
Showing
2 changed files
with
87 additions
and
0 deletions.
There are no files selected for viewing
54 changes: 54 additions & 0 deletions
54
samples/model-builder/predict_image_classification_sample.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright 2022 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
# [START aiplatform_sdk_predict_image_classification_sample] | ||
import base64 | ||
|
||
from typing import List | ||
|
||
from google.cloud import aiplatform | ||
|
||
|
||
def predict_image_classification_sample( | ||
project: str, | ||
location: str, | ||
endpoint_name: str, | ||
images: List, | ||
): | ||
''' | ||
Args | ||
project: Your project ID or project number. | ||
location: Region where Endpoint is located. For example, 'us-central1'. | ||
endpoint_name: A fully qualified endpoint name or endpoint ID. Example: "projects/123/locations/us-central1/endpoints/456" or | ||
"456" when project and location are initialized or passed. | ||
images: A list of one or more images to return a prediction for. | ||
''' | ||
aiplatform.init(project=project, location=location) | ||
|
||
endpoint = aiplatform.Endpoint(endpoint_name) | ||
|
||
instances = [] | ||
for image in images: | ||
with open(image, "rb") as f: | ||
content = f.read() | ||
instances.append({"content": base64.b64encode(content).decode("utf-8")}) | ||
|
||
response = endpoint.predict(instances=instances) | ||
|
||
for prediction_ in response.predictions: | ||
print(prediction_) | ||
|
||
|
||
# [END aiplatform_sdk_predict_image_classification_sample] |
33 changes: 33 additions & 0 deletions
33
samples/model-builder/predict_image_classification_sample_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright 2022 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import predict_image_classification_sample | ||
import test_constants as constants | ||
|
||
|
||
def test_predict_image_classification_sample(mock_sdk_init, mock_get_endpoint): | ||
|
||
predict_image_classification_sample.predict_image_classification_sample( | ||
project=constants.PROJECT, | ||
location=constants.LOCATION, | ||
endpoint_name=constants.ENDPOINT_NAME, | ||
images=[] | ||
) | ||
|
||
mock_sdk_init.assert_called_once_with( | ||
project=constants.PROJECT, location=constants.LOCATION | ||
) | ||
|
||
mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,) |