diff --git a/tests/system/aiplatform/test_vision_models.py b/tests/system/aiplatform/test_vision_models.py new file mode 100644 index 0000000000..ddf7cf7168 --- /dev/null +++ b/tests/system/aiplatform/test_vision_models.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pylint: disable=protected-access + +import os +import tempfile + +from google.cloud import aiplatform +from tests.system.aiplatform import e2e_base +from vertexai.preview import vision_models +from PIL import Image as PIL_Image + + +def _create_blank_image( + width: int = 100, + height: int = 100, +) -> vision_models.Image: + with tempfile.TemporaryDirectory() as temp_dir: + image_path = os.path.join(temp_dir, "image.png") + pil_image = PIL_Image.new(mode="RGB", size=(width, height)) + pil_image.save(image_path, format="PNG") + return vision_models.Image.load_from_file(image_path) + + +class VisionModelTestSuite(e2e_base.TestEndToEnd): + """System tests for vision models.""" + + _temp_prefix = "temp_vision_models_test_" + + def test_image_captioning_model_get_captions(self): + aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) + + model = vision_models.ImageCaptioningModel.from_pretrained("imagetext") + image = _create_blank_image() + captions = model.get_captions( + image=image, + # Optional: + number_of_results=2, + language="en", + ) + assert len(captions) == 2 + + def test_image_q_and_a_model_ask_question(self): + aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) + + model = vision_models.ImageQnAModel.from_pretrained("imagetext") + image = _create_blank_image() + answers = model.ask_question( + image=image, + question="What color is the car in this image?", + # Optional: + number_of_results=2, + ) + assert len(answers) == 2 + + def test_multi_modal_embedding_model(self): + aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) + + model = vision_models.MultiModalEmbeddingModel.from_pretrained( + "multimodalembedding@001" + ) + image = _create_blank_image() + embeddings = model.get_embeddings( + image=image, + # Optional: + contextual_text="this is a car", + ) + # The service is expected to return the embeddings of size 1408 + assert len(embeddings.image_embedding) == 1408 + assert len(embeddings.text_embedding) == 1408 diff --git a/tests/unit/aiplatform/test_vision_models.py b/tests/unit/aiplatform/test_vision_models.py new file mode 100644 index 0000000000..f42228e7e8 --- /dev/null +++ b/tests/unit/aiplatform/test_vision_models.py @@ -0,0 +1,266 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for the vision models."""
+
+# pylint: disable=protected-access,bad-continuation
+
+import importlib
+import os
+import tempfile
+from unittest import mock
+
+from google.cloud import aiplatform
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform.compat.services import (
+    model_garden_service_client,
+)
+from google.cloud.aiplatform.compat.services import prediction_service_client
+from google.cloud.aiplatform.compat.types import (
+    prediction_service as gca_prediction_service,
+)
+from google.cloud.aiplatform.compat.types import (
+    publisher_model as gca_publisher_model,
+)
+from vertexai.preview import vision_models
+
+from PIL import Image as PIL_Image
+import pytest
+
+_TEST_PROJECT = "test-project"
+_TEST_LOCATION = "us-central1"
+
+_IMAGE_TEXT_PUBLISHER_MODEL_DICT = {
+    "name": "publishers/google/models/imagetext",
+    "version_id": "001",
+    "open_source_category": "PROPRIETARY",
+    "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
+    "publisher_model_template": "projects/{project}/locations/{location}/publishers/google/models/imagetext@001",
+    "predict_schemata": {
+        "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/vision_reasoning_model_1.0.0.yaml",
+        "parameters_schema_uri": "gs://google-cloud-aiplatform/schema/predict/params/vision_reasoning_model_1.0.0.yaml",
+        "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/vision_reasoning_model_1.0.0.yaml",
+    },
+}
+
+_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT = {
+    "name": "publishers/google/models/multimodalembedding",
+    "version_id": "001",
+    "open_source_category": "PROPRIETARY",
+    "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
+    "publisher_model_template": "projects/{project}/locations/{location}/publishers/google/models/multimodalembedding@001",
+    "predict_schemata": {
+        "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/vision_embedding_model_1.0.0.yaml",
+        "parameters_schema_uri": "gs://google-cloud-aiplatform/schema/predict/params/vision_embedding_model_1.0.0.yaml",
+        "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/vision_embedding_model_1.0.0.yaml",
+    },
+}
+
+
+def generate_image_from_file(
+    width: int = 100, height: int = 100
+) -> vision_models.Image:
+    with tempfile.TemporaryDirectory() as temp_dir:
+        image_path = os.path.join(temp_dir, "image.png")
+        pil_image = PIL_Image.new(mode="RGB", size=(width, height))
+        pil_image.save(image_path, format="PNG")
+        return vision_models.Image.load_from_file(image_path)
+
+
+@pytest.mark.usefixtures("google_auth_mock")
+class ImageCaptioningModelTests:
+    """Unit tests for the image captioning models."""
+
+    def setup_method(self):
+        importlib.reload(initializer)
+        importlib.reload(aiplatform)
+
+    def teardown_method(self):
+        initializer.global_pool.shutdown(wait=True)
+
+    def test_get_captions(self):
+        """Tests the image captioning model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _IMAGE_TEXT_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = vision_models.ImageCaptioningModel.from_pretrained("imagetext@001")
+
+        image_captions = [
+            "Caption 1",
+            "Caption 2",
+        ]
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        gca_predict_response.predictions.extend(image_captions)
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            image_path = os.path.join(temp_dir, "image.png")
+            pil_image = PIL_Image.new(mode="RGB", size=(100, 100))
+            pil_image.save(image_path, format="PNG")
+            image = vision_models.Image.load_from_file(image_path)
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ):
+            actual_captions = model.get_captions(image=image, number_of_results=2)
+            assert actual_captions == image_captions
+
+
+@pytest.mark.usefixtures("google_auth_mock")
+class ImageQnAModelTests:
+    """Unit tests for the image Q&A models."""
+
+    def setup_method(self):
+        importlib.reload(initializer)
+        importlib.reload(aiplatform)
+
+    def teardown_method(self):
+        initializer.global_pool.shutdown(wait=True)
+
+    def test_ask_question(self):
+        """Tests the image Q&A model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _IMAGE_TEXT_PUBLISHER_MODEL_DICT
+            ),
+        ) as mock_get_publisher_model:
+            model = vision_models.ImageQnAModel.from_pretrained("imagetext@001")
+
+        mock_get_publisher_model.assert_called_once_with(
+            name="publishers/google/models/imagetext@001",
+            retry=base._DEFAULT_RETRY,
+        )
+
+        image_answers = [
+            "Black square",
+            "Black Square by Malevich",
+        ]
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        gca_predict_response.predictions.extend(image_answers)
+
+        image = generate_image_from_file()
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ):
+            actual_answers = model.ask_question(
+                image=image,
+                question="What is this painting?",
+                number_of_results=2,
+            )
+            assert actual_answers == image_answers
+
+
+@pytest.mark.usefixtures("google_auth_mock")
+class TestMultiModalEmbeddingModels:
+    """Unit tests for the multimodal embedding models."""
+
+    def setup_method(self):
+        importlib.reload(initializer)
+        importlib.reload(aiplatform)
+
+    def teardown_method(self):
+        initializer.global_pool.shutdown(wait=True)
+
+    def test_image_embedding_model_with_only_image(self):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
+            ),
+        ) as mock_get_publisher_model:
+            model = vision_models.MultiModalEmbeddingModel.from_pretrained(
+                "multimodalembedding@001"
+            )
+
+        mock_get_publisher_model.assert_called_once_with(
+            name="publishers/google/models/multimodalembedding@001",
+            retry=base._DEFAULT_RETRY,
+        )
+
+        test_image_embeddings = [0, 0]
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        gca_predict_response.predictions.append(
+            {"imageEmbedding": test_image_embeddings}
+        )
+
+        image = 
generate_image_from_file() + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response, + ): + embedding_response = model.get_embeddings(image=image) + + assert embedding_response.image_embedding == test_image_embeddings + assert not embedding_response.text_embedding + + def test_image_embedding_model_with_image_and_text(self): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT + ), + ): + model = vision_models.MultiModalEmbeddingModel.from_pretrained( + "multimodalembedding@001" + ) + + test_embeddings = [0, 0] + gca_predict_response = gca_prediction_service.PredictResponse() + gca_predict_response.predictions.append( + {"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings} + ) + + image = generate_image_from_file() + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response, + ): + embedding_response = model.get_embeddings( + image=image, contextual_text="hello world" + ) + + assert embedding_response.image_embedding == test_embeddings + assert embedding_response.text_embedding == test_embeddings diff --git a/vertexai/preview/vision_models.py b/vertexai/preview/vision_models.py new file mode 100644 index 0000000000..fb3a32fd5f --- /dev/null +++ b/vertexai/preview/vision_models.py @@ -0,0 +1,31 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Classes for working with vision models.""" + +from vertexai.vision_models._vision_models import ( + Image, + ImageCaptioningModel, + ImageQnAModel, + MultiModalEmbeddingModel, + MultiModalEmbeddingResponse, +) + +__all__ = [ + "Image", + "ImageCaptioningModel", + "ImageQnAModel", + "MultiModalEmbeddingModel", + "MultiModalEmbeddingResponse", +] diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py new file mode 100644 index 0000000000..a9a7150076 --- /dev/null +++ b/vertexai/vision_models/_vision_models.py @@ -0,0 +1,291 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+"""Classes for working with vision models."""
+
+import base64
+import dataclasses
+import io
+import pathlib
+from typing import Any, List, Optional
+
+from vertexai._model_garden import _model_garden_models
+
+# pylint: disable=g-import-not-at-top
+try:
+    from IPython import display as IPython_display
+except ImportError:
+    IPython_display = None
+
+try:
+    from PIL import Image as PIL_Image
+except ImportError:
+    PIL_Image = None
+
+
+class Image:
+    """Image."""
+
+    _image_bytes: bytes
+    _loaded_image: Optional["PIL_Image.Image"] = None
+
+    def __init__(self, image_bytes: bytes):
+        """Creates an `Image` object.
+
+        Args:
+            image_bytes: Image file bytes. Image can be in PNG or JPEG format.
+        """
+        self._image_bytes = image_bytes
+
+    @staticmethod
+    def load_from_file(location: str) -> "Image":
+        """Loads image from file.
+
+        Args:
+            location: Local path from where to load the image.
+
+        Returns:
+            Loaded image as an `Image` object.
+        """
+        image_bytes = pathlib.Path(location).read_bytes()
+        image = Image(image_bytes=image_bytes)
+        return image
+
+    @property
+    def _pil_image(self) -> "PIL_Image.Image":
+        if self._loaded_image is None:
+            self._loaded_image = PIL_Image.open(io.BytesIO(self._image_bytes))
+        return self._loaded_image
+
+    @property
+    def _size(self):
+        return self._pil_image.size
+
+    def show(self):
+        """Shows the image.
+
+        This method only works in a notebook environment.
+        """
+        if PIL_Image and IPython_display:
+            IPython_display.display(self._pil_image)
+
+    def save(self, location: str):
+        """Saves image to a file.
+
+        Args:
+            location: Local path where to save the image.
+        """
+        pathlib.Path(location).write_bytes(self._image_bytes)
+
+    def _as_base64_string(self) -> str:
+        """Encodes image using the base64 encoding.
+
+        Returns:
+            Base64 encoding of the image as a string.
+        """
+        # Note: b64encode returns a `bytes` object, not `str`.
+        # We need to convert `bytes` to `str`, otherwise we get a service error:
+        # "received initial metadata size exceeds limit"
+        return base64.b64encode(self._image_bytes).decode("ascii")
+
+
+class ImageCaptioningModel(
+    _model_garden_models._ModelGardenModel  # pylint: disable=protected-access
+):
+    """Generates captions from an image.
+
+    Examples::
+
+        model = ImageCaptioningModel.from_pretrained("imagetext@001")
+        image = Image.load_from_file("image.png")
+        captions = model.get_captions(
+            image=image,
+            # Optional:
+            number_of_results=1,
+            language="en",
+        )
+    """
+
+    _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_reasoning_model_1.0.0.yaml"
+    _LAUNCH_STAGE = (
+        _model_garden_models._SDK_GA_LAUNCH_STAGE  # pylint: disable=protected-access
+    )
+
+    def get_captions(
+        self,
+        image: Image,
+        *,
+        number_of_results: int = 1,
+        language: str = "en",
+    ) -> List[str]:
+        """Generates captions for a given image.
+
+        Args:
+            image: The image to get captions for. Size limit: 10 MB.
+            number_of_results: Number of captions to produce. Range: 1-3.
+            language: Language to use for captions.
+                Supported languages: "en", "fr", "de", "it", "es"
+
+        Returns:
+            A list of image caption strings.
+        """
+        instance = {
+            "image": {
+                "bytesBase64Encoded": image._as_base64_string()  # pylint: disable=protected-access
+            }
+        }
+        parameters = {
+            "sampleCount": number_of_results,
+            "language": language,
+        }
+        response = self._endpoint.predict(
+            instances=[instance],
+            parameters=parameters,
+        )
+        return response.predictions
+
+
+class ImageQnAModel(
+    _model_garden_models._ModelGardenModel  # pylint: disable=protected-access
+):
+    """Answers questions about an image.
+
+    Examples::
+
+        model = ImageQnAModel.from_pretrained("imagetext@001")
+        image = Image.load_from_file("image.png")
+        answers = model.ask_question(
+            image=image,
+            question="What color is the car in this image?",
+            # Optional:
+            number_of_results=1,
+        )
+    """
+
+    _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_reasoning_model_1.0.0.yaml"
+    _LAUNCH_STAGE = (
+        _model_garden_models._SDK_GA_LAUNCH_STAGE  # pylint: disable=protected-access
+    )
+
+    def ask_question(
+        self,
+        image: Image,
+        question: str,
+        *,
+        number_of_results: int = 1,
+    ) -> List[str]:
+        """Answers questions about an image.
+
+        Args:
+            image: The image to ask the question about. Size limit: 10 MB.
+            question: Question to ask about the image.
+            number_of_results: Number of answers to produce. Range: 1-3.
+
+        Returns:
+            A list of answers.
+        """
+        instance = {
+            "prompt": question,
+            "image": {
+                "bytesBase64Encoded": image._as_base64_string()  # pylint: disable=protected-access
+            },
+        }
+        parameters = {
+            "sampleCount": number_of_results,
+        }
+        response = self._endpoint.predict(
+            instances=[instance],
+            parameters=parameters,
+        )
+        return response.predictions
+
+
+class MultiModalEmbeddingModel(_model_garden_models._ModelGardenModel):
+    """Generates embedding vectors from images.
+
+    Examples::
+
+        model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
+        image = Image.load_from_file("image.png")
+
+        embeddings = model.get_embeddings(
+            image=image,
+            contextual_text="Hello world",
+        )
+        image_embedding = embeddings.image_embedding
+        text_embedding = embeddings.text_embedding
+    """
+
+    _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_embedding_model_1.0.0.yaml"
+
+    _LAUNCH_STAGE = (
+        _model_garden_models._SDK_GA_LAUNCH_STAGE  # pylint: disable=protected-access
+    )
+
+    def get_embeddings(
+        self, image: Image, contextual_text: Optional[str] = None
+    ) -> "MultiModalEmbeddingResponse":
+        """Gets embedding vectors from the provided image.
+
+        Args:
+            image (Image):
+                The image to generate embeddings for.
+            contextual_text (str):
+                Optional. Contextual text for your input image. If provided, the model will also
+                generate an embedding vector for the provided contextual text. The returned image
+                and text embedding vectors are in the same semantic space with the same dimensionality,
+                and the vectors can be used interchangeably for use cases like searching images by text
+                or searching text by image.
+
+        Returns:
+            MultiModalEmbeddingResponse:
+                The image and text embedding vectors.
+        """
+
+        instance = {
+            "image": {"bytesBase64Encoded": image._as_base64_string()},
+            "features": [{"type": "IMAGE_EMBEDDING"}],
+        }
+
+        if contextual_text:
+            instance["text"] = contextual_text
+
+        response = self._endpoint.predict(instances=[instance])
+        image_embedding = response.predictions[0].get("imageEmbedding")
+        text_embedding = (
+            response.predictions[0].get("textEmbedding")
+            if "textEmbedding" in response.predictions[0]
+            else None
+        )
+        return MultiModalEmbeddingResponse(
+            image_embedding=image_embedding,
+            _prediction_response=response,
+            text_embedding=text_embedding,
+        )
+
+
+@dataclasses.dataclass
+class MultiModalEmbeddingResponse:
+    """The image embedding response.
+
+    Attributes:
+        image_embedding (List[float]):
+            The embedding vector generated from your image.
+        text_embedding (List[float]):
+            Optional. The embedding vector generated from the contextual text provided for your image.
+    """
+
+    image_embedding: List[float]
+    _prediction_response: Any
+    text_embedding: Optional[List[float]] = None