Skip to content

Commit

Permalink
added params to invocation methods too (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin authored Apr 3, 2024
1 parent fe2968c commit fef05cd
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 45 deletions.
135 changes: 90 additions & 45 deletions libs/vertexai/langchain_google_vertexai/vision_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

from google.cloud.aiplatform import telemetry
from langchain_core.callbacks import CallbackManagerForLLMRun
Expand Down Expand Up @@ -29,6 +29,7 @@
class _BaseImageTextModel(BaseModel):
"""Base class for all integrations that use ImageTextModel"""

cached_client: Any = Field(default=None)
model_name: str = Field(default="imagetext@001")
""" Name of the model to use"""
number_of_results: int = Field(default=1)
Expand All @@ -38,9 +39,13 @@ class _BaseImageTextModel(BaseModel):
project: Union[str, None] = Field(default=None)
"""Google cloud project"""

def _create_model(self) -> ImageTextModel:
"""Builds the model object from the class attributes."""
return ImageTextModel.from_pretrained(model_name=self.model_name)
@property
def client(self) -> ImageTextModel:
if self.cached_client is None:
self.cached_client = ImageTextModel.from_pretrained(
model_name=self.model_name,
)
return self.cached_client

def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None:
"""Given a message part obtain a image if the part represents it.
Expand Down Expand Up @@ -83,26 +88,42 @@ def _user_agent(self) -> str:
_, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}")
return user_agent

@property
def _default_params(self) -> Dict[str, Any]:
return {"number_of_results": self.number_of_results, "language": self.language}

def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
params = self._default_params
for key, value in kwargs.items():
if key in params and value is not None:
params[key] = value
return params


class _BaseVertexAIImageCaptioning(_BaseImageTextModel):
"""Base class for Image Captioning models."""

def _get_captions(self, image: Image) -> List[str]:
def _get_captions(
self,
image: Image,
number_of_results: Optional[int] = None,
language: Optional[str] = None,
) -> List[str]:
"""Uses the sdk methods to generate a list of captions.
Args:
image: Image to get the captions for.
number_of_results: Number of results to return from one query.
language: Language of the query.
Returns:
List of captions obtained from the image.
"""
with telemetry.tool_context_manager(self._user_agent):
model = self._create_model()
captions = model.get_captions(
image=image,
number_of_results=self.number_of_results,
language=self.language,
params = self._prepare_params(
number_of_results=number_of_results, language=language
)
captions = self.client.get_captions(image=image, **params)
return captions


Expand Down Expand Up @@ -130,11 +151,13 @@ def _generate(
Captions generated from every prompt.
"""

generations = [self._generate_one(prompt=prompt) for prompt in prompts]
generations = [
self._generate_one(prompt=prompt, **kwargs) for prompt in prompts
]

return LLMResult(generations=generations)

def _generate_one(self, prompt: str) -> List[Generation]:
def _generate_one(self, prompt: str, **kwargs) -> List[Generation]:
"""Generates the captions for a single prompt.
Args:
Expand All @@ -146,7 +169,7 @@ def _generate_one(self, prompt: str) -> List[Generation]:
image_loader = ImageBytesLoader(project=self.project)
image_bytes = image_loader.load_bytes(prompt)
image = Image(image_bytes=image_bytes)
caption_list = self._get_captions(image=image)
caption_list = self._get_captions(image=image, **kwargs)
return [Generation(text=caption) for caption in caption_list]


Expand Down Expand Up @@ -199,7 +222,7 @@ def _generate(
"{'type': 'image_url', 'image_url': {'image': <image_str>}}"
)

captions = self._get_captions(image)
captions = self._get_captions(image, **messages[0].additional_kwargs)

generations = [
ChatGeneration(message=AIMessage(content=caption)) for caption in captions
Expand All @@ -211,6 +234,10 @@ def _generate(
class VertexAIVisualQnAChat(_BaseImageTextModel, BaseChatModel):
"""Chat implementation of a visual QnA model"""

@property
def _default_params(self) -> Dict[str, Any]:
return {"number_of_results": self.number_of_results}

def _generate(
self,
messages: List[BaseMessage],
Expand Down Expand Up @@ -257,15 +284,19 @@ def _generate(
"or a dictionary with format {'type': 'text', 'text': <message>}"
)

answers = self._ask_questions(image=image, query=user_question)
answers = self._ask_questions(
image=image, query=user_question, **messages[0].additional_kwargs
)

generations = [
ChatGeneration(message=AIMessage(content=answer)) for answer in answers
]

return ChatResult(generations=generations)

def _ask_questions(self, image: Image, query: str) -> List[str]:
def _ask_questions(
self, image: Image, query: str, number_of_results: Optional[int] = None
) -> List[str]:
"""Interfaces with the sdk to get the question.
Args:
Expand All @@ -276,22 +307,21 @@ def _ask_questions(self, image: Image, query: str) -> List[str]:
List of responses to the query.
"""
with telemetry.tool_context_manager(self._user_agent):
model = self._create_model()
answers = model.ask_question(
image=image, question=query, number_of_results=self.number_of_results
)
params = self._prepare_params(number_of_results=number_of_results)
answers = self.client.ask_question(image=image, question=query, **params)
return answers


class _BaseVertexAIImageGenerator(BaseModel):
"""Base class form generation and edition of images."""

cached_client: Any = Field(default=None)
model_name: str = Field(default="imagegeneration@002")
"""Name of the base model"""
negative_prompt: Union[str, None] = Field(default=None)
"""A description of what you want to omit in
the generated images"""
number_of_images: int = Field(default=1)
number_of_results: int = Field(default=1)
"""Number of images to generate"""
guidance_scale: Union[float, None] = Field(default=None)
"""Controls the strength of the prompt"""
Expand All @@ -304,7 +334,34 @@ class _BaseVertexAIImageGenerator(BaseModel):
project: Union[str, None] = Field(default=None)
"""Google cloud project id"""

def _generate_images(self, prompt: str) -> List[str]:
@property
def client(self) -> ImageGenerationModel:
if not self.cached_client:
self.cached_client = ImageGenerationModel.from_pretrained(
model_name=self.model_name,
)
return self.cached_client

@property
def _default_params(self) -> Dict[str, Any]:
return {
"number_of_images": self.number_of_results,
"language": self.language,
"negative_prompt": self.negative_prompt,
"guidance_scale": self.guidance_scale,
"seed": self.seed,
}

def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
params = self._default_params
mapping = {"number_of_results": "number_of_images"}
for key, value in kwargs.items():
key = mapping.get(key, key)
if key in params and value is not None:
params[key] = value
return {k: v for k, v in params.items() if v is not None}

def _generate_images(self, prompt: str, **kwargs: Any) -> List[str]:
"""Generates images given a prompt.
Args:
Expand All @@ -314,15 +371,8 @@ def _generate_images(self, prompt: str) -> List[str]:
List of b64 encoded strings.
"""
with telemetry.tool_context_manager(self._user_agent):
model = ImageGenerationModel.from_pretrained(self.model_name)

generation_result = model.generate_images(
prompt=prompt,
negative_prompt=self.negative_prompt,
number_of_images=self.number_of_images,
language=self.language,
guidance_scale=self.guidance_scale,
seed=self.seed,
generation_result = self.client.generate_images(
prompt=prompt, **self._prepare_params(**kwargs)
)

image_str_list = [
Expand All @@ -331,7 +381,7 @@ def _generate_images(self, prompt: str) -> List[str]:

return image_str_list

def _edit_images(self, image_str: str, prompt: str) -> List[str]:
def _edit_images(self, image_str: str, prompt: str, **kwargs: Any) -> List[str]:
"""Edit an image given a image and a prompt.
Args:
Expand All @@ -342,20 +392,11 @@ def _edit_images(self, image_str: str, prompt: str) -> List[str]:
List of b64 encoded strings.
"""
with telemetry.tool_context_manager(self._user_agent):
model = ImageGenerationModel.from_pretrained(self.model_name)

image_loader = ImageBytesLoader(project=self.project)
image_bytes = image_loader.load_bytes(image_str)
image = Image(image_bytes=image_bytes)

generation_result = model.edit_image(
prompt=prompt,
base_image=image,
negative_prompt=self.negative_prompt,
number_of_images=self.number_of_images,
language=self.language,
guidance_scale=self.guidance_scale,
seed=self.seed,
generation_result = self.client.edit_image(
prompt=prompt, base_image=image, **self._prepare_params(**kwargs)
)

image_str_list = [
Expand Down Expand Up @@ -427,7 +468,9 @@ def _generate(
" Must The prompt of the image"
)

image_str_list = self._generate_images(prompt=user_query)
image_str_list = self._generate_images(
prompt=user_query, **messages[0].additional_kwargs
)
image_content_part_list = [
create_image_content_part(image_str=image_str)
for image_str in image_str_list
Expand Down Expand Up @@ -474,7 +517,9 @@ def _generate(
"two parts: First the image and then the user prompt."
)

image_str_list = self._edit_images(image_str=image_str, prompt=user_query)
image_str_list = self._edit_images(
image_str=image_str, prompt=user_query, **messages[0].additional_kwargs
)
image_content_part_list = [
create_image_content_part(image_str=image_str)
for image_str in image_str_list
Expand Down
3 changes: 3 additions & 0 deletions libs/vertexai/tests/integration_tests/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def test_vertex_ai_image_captioning(base64_image: str):
response = model.invoke(base64_image)
assert isinstance(response, str)

response = model.invoke(base64_image, language="de")
assert isinstance(response, str)


@pytest.mark.release
def test_vertex_ai_visual_qna_chat(base64_image: str):
Expand Down

0 comments on commit fef05cd

Please sign in to comment.