diff --git a/libs/vertexai/langchain_google_vertexai/vision_models.py b/libs/vertexai/langchain_google_vertexai/vision_models.py
index ba0d43bd..8c4ff4f0 100644
--- a/libs/vertexai/langchain_google_vertexai/vision_models.py
+++ b/libs/vertexai/langchain_google_vertexai/vision_models.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 from google.cloud.aiplatform import telemetry
 from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -29,6 +29,7 @@ class _BaseImageTextModel(BaseModel):
     """Base class for all integrations that use ImageTextModel"""
 
+    cached_client: Any = Field(default=None)
     model_name: str = Field(default="imagetext@001")
     """ Name of the model to use"""
     number_of_results: int = Field(default=1)
@@ -38,9 +39,13 @@ class _BaseImageTextModel(BaseModel):
     project: Union[str, None] = Field(default=None)
     """Google cloud project"""
 
-    def _create_model(self) -> ImageTextModel:
-        """Builds the model object from the class attributes."""
-        return ImageTextModel.from_pretrained(model_name=self.model_name)
+    @property
+    def client(self) -> ImageTextModel:
+        if self.cached_client is None:
+            self.cached_client = ImageTextModel.from_pretrained(
+                model_name=self.model_name,
+            )
+        return self.cached_client
 
     def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None:
         """Given a message part obtain a image if the part represents it.
@@ -83,26 +88,42 @@ def _user_agent(self) -> str:
         _, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}")
         return user_agent
 
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        return {"number_of_results": self.number_of_results, "language": self.language}
+
+    def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
+        params = self._default_params
+        for key, value in kwargs.items():
+            if key in params and value is not None:
+                params[key] = value
+        return params
+
 
 class _BaseVertexAIImageCaptioning(_BaseImageTextModel):
     """Base class for Image Captioning models."""
 
-    def _get_captions(self, image: Image) -> List[str]:
+    def _get_captions(
+        self,
+        image: Image,
+        number_of_results: Optional[int] = None,
+        language: Optional[str] = None,
+    ) -> List[str]:
         """Uses the sdk methods to generate a list of captions.
 
         Args:
             image: Image to get the captions for.
+            number_of_results: Number of results to return from one query.
+            language: Language of the query.
 
         Returns:
             List of captions obtained from the image.
         """
         with telemetry.tool_context_manager(self._user_agent):
-            model = self._create_model()
-            captions = model.get_captions(
-                image=image,
-                number_of_results=self.number_of_results,
-                language=self.language,
+            params = self._prepare_params(
+                number_of_results=number_of_results, language=language
             )
+            captions = self.client.get_captions(image=image, **params)
 
         return captions
@@ -130,11 +151,13 @@ def _generate(
             Captions generated from every prompt.
         """
 
-        generations = [self._generate_one(prompt=prompt) for prompt in prompts]
+        generations = [
+            self._generate_one(prompt=prompt, **kwargs) for prompt in prompts
+        ]
 
         return LLMResult(generations=generations)
 
-    def _generate_one(self, prompt: str) -> List[Generation]:
+    def _generate_one(self, prompt: str, **kwargs) -> List[Generation]:
         """Generates the captions for a single prompt.
 
         Args:
@@ -146,7 +169,7 @@ def _generate_one(self, prompt: str) -> List[Generation]:
         image_loader = ImageBytesLoader(project=self.project)
         image_bytes = image_loader.load_bytes(prompt)
         image = Image(image_bytes=image_bytes)
-        caption_list = self._get_captions(image=image)
+        caption_list = self._get_captions(image=image, **kwargs)
 
         return [Generation(text=caption) for caption in caption_list]
@@ -199,7 +222,7 @@ def _generate(
                 "{'type': 'image_url', 'image_url': {'image': <image_str>}}"
             )
 
-        captions = self._get_captions(image)
+        captions = self._get_captions(image, **messages[0].additional_kwargs)
 
         generations = [
             ChatGeneration(message=AIMessage(content=caption)) for caption in captions
@@ -211,6 +234,10 @@ def _generate(
 class VertexAIVisualQnAChat(_BaseImageTextModel, BaseChatModel):
     """Chat implementation of a visual QnA model"""
 
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        return {"number_of_results": self.number_of_results}
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -257,7 +284,9 @@ def _generate(
                 "or a dictionary with format {'type': 'text', 'text': <message_str>}"
             )
 
-        answers = self._ask_questions(image=image, query=user_question)
+        answers = self._ask_questions(
+            image=image, query=user_question, **messages[0].additional_kwargs
+        )
 
         generations = [
             ChatGeneration(message=AIMessage(content=answer)) for answer in answers
@@ -265,7 +294,9 @@ def _generate(
 
         return ChatResult(generations=generations)
 
-    def _ask_questions(self, image: Image, query: str) -> List[str]:
+    def _ask_questions(
+        self, image: Image, query: str, number_of_results: Optional[int] = None
+    ) -> List[str]:
         """Interfaces with the sdk to get the question.
 
         Args:
@@ -276,22 +307,21 @@ def _ask_questions(self, image: Image, query: str) -> List[str]:
             List of responses to the query.
""" with telemetry.tool_context_manager(self._user_agent): - model = self._create_model() - answers = model.ask_question( - image=image, question=query, number_of_results=self.number_of_results - ) + params = self._prepare_params(number_of_results=number_of_results) + answers = self.client.ask_question(image=image, question=query, **params) return answers class _BaseVertexAIImageGenerator(BaseModel): """Base class form generation and edition of images.""" + cached_client: Any = Field(default=None) model_name: str = Field(default="imagegeneration@002") """Name of the base model""" negative_prompt: Union[str, None] = Field(default=None) """A description of what you want to omit in the generated images""" - number_of_images: int = Field(default=1) + number_of_results: int = Field(default=1) """Number of images to generate""" guidance_scale: Union[float, None] = Field(default=None) """Controls the strength of the prompt""" @@ -304,7 +334,34 @@ class _BaseVertexAIImageGenerator(BaseModel): project: Union[str, None] = Field(default=None) """Google cloud project id""" - def _generate_images(self, prompt: str) -> List[str]: + @property + def client(self) -> ImageGenerationModel: + if not self.cached_client: + self.cached_client = ImageGenerationModel.from_pretrained( + model_name=self.model_name, + ) + return self.cached_client + + @property + def _default_params(self) -> Dict[str, Any]: + return { + "number_of_images": self.number_of_results, + "language": self.language, + "negative_prompt": self.negative_prompt, + "guidance_scale": self.guidance_scale, + "seed": self.seed, + } + + def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]: + params = self._default_params + mapping = {"number_of_results": "number_of_images"} + for key, value in kwargs.items(): + key = mapping.get(key, key) + if key in params and value is not None: + params[key] = value + return {k: v for k, v in params.items() if v is not None} + + def _generate_images(self, prompt: str, **kwargs: Any) -> List[str]: """Generates images given a prompt. Args: @@ -314,15 +371,8 @@ def _generate_images(self, prompt: str) -> List[str]: List of b64 encoded strings. """ with telemetry.tool_context_manager(self._user_agent): - model = ImageGenerationModel.from_pretrained(self.model_name) - - generation_result = model.generate_images( - prompt=prompt, - negative_prompt=self.negative_prompt, - number_of_images=self.number_of_images, - language=self.language, - guidance_scale=self.guidance_scale, - seed=self.seed, + generation_result = self.client.generate_images( + prompt=prompt, **self._prepare_params(**kwargs) ) image_str_list = [ @@ -331,7 +381,7 @@ def _generate_images(self, prompt: str) -> List[str]: return image_str_list - def _edit_images(self, image_str: str, prompt: str) -> List[str]: + def _edit_images(self, image_str: str, prompt: str, **kwargs: Any) -> List[str]: """Edit an image given a image and a prompt. Args: @@ -342,20 +392,11 @@ def _edit_images(self, image_str: str, prompt: str) -> List[str]: List of b64 encoded strings. 
""" with telemetry.tool_context_manager(self._user_agent): - model = ImageGenerationModel.from_pretrained(self.model_name) - image_loader = ImageBytesLoader(project=self.project) image_bytes = image_loader.load_bytes(image_str) image = Image(image_bytes=image_bytes) - - generation_result = model.edit_image( - prompt=prompt, - base_image=image, - negative_prompt=self.negative_prompt, - number_of_images=self.number_of_images, - language=self.language, - guidance_scale=self.guidance_scale, - seed=self.seed, + generation_result = self.client.edit_image( + prompt=prompt, base_image=image, **self._prepare_params(**kwargs) ) image_str_list = [ @@ -427,7 +468,9 @@ def _generate( " Must The prompt of the image" ) - image_str_list = self._generate_images(prompt=user_query) + image_str_list = self._generate_images( + prompt=user_query, **messages[0].additional_kwargs + ) image_content_part_list = [ create_image_content_part(image_str=image_str) for image_str in image_str_list @@ -474,7 +517,9 @@ def _generate( "two parts: First the image and then the user prompt." ) - image_str_list = self._edit_images(image_str=image_str, prompt=user_query) + image_str_list = self._edit_images( + image_str=image_str, prompt=user_query, **messages[0].additional_kwargs + ) image_content_part_list = [ create_image_content_part(image_str=image_str) for image_str in image_str_list diff --git a/libs/vertexai/tests/integration_tests/test_vision_models.py b/libs/vertexai/tests/integration_tests/test_vision_models.py index 1cbdfdda..e303de60 100644 --- a/libs/vertexai/tests/integration_tests/test_vision_models.py +++ b/libs/vertexai/tests/integration_tests/test_vision_models.py @@ -50,6 +50,9 @@ def test_vertex_ai_image_captioning(base64_image: str): response = model.invoke(base64_image) assert isinstance(response, str) + response = model.invoke(base64_image, language="de") + assert isinstance(response, str) + @pytest.mark.release def test_vertex_ai_visual_qna_chat(base64_image: str):