added params to invocation methods too #106

Merged · 1 commit · Apr 3, 2024
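For context, a minimal usage sketch of what this PR enables (base64_image is assumed to be available, and the class is assumed to be exported from the package root as in the integration tests):

from langchain_google_vertexai import VertexAIImageCaptioning

model = VertexAIImageCaptioning(number_of_results=1, language="en")

# Constructor-level settings keep working as before.
captions = model.invoke(base64_image)

# New in this PR: the same parameters can be overridden per invocation.
captions_de = model.invoke(base64_image, language="de", number_of_results=2)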
135 changes: 90 additions & 45 deletions libs/vertexai/langchain_google_vertexai/vision_models.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

from google.cloud.aiplatform import telemetry
from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -29,6 +29,7 @@
class _BaseImageTextModel(BaseModel):
"""Base class for all integrations that use ImageTextModel"""

cached_client: Any = Field(default=None)
model_name: str = Field(default="imagetext@001")
""" Name of the model to use"""
number_of_results: int = Field(default=1)
@@ -38,9 +39,13 @@ class _BaseImageTextModel(BaseModel):
project: Union[str, None] = Field(default=None)
"""Google cloud project"""

def _create_model(self) -> ImageTextModel:
"""Builds the model object from the class attributes."""
return ImageTextModel.from_pretrained(model_name=self.model_name)
@property
Review comment (Contributor): Can't we use functools.cached_property here and get rid of the field?

def client(self) -> ImageTextModel:
if self.cached_client is None:
self.cached_client = ImageTextModel.from_pretrained(
model_name=self.model_name,
)
return self.cached_client
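As a side note on the review comment above, a rough sketch of the functools.cached_property alternative (not what this PR merges; whether it works as-is depends on the Pydantic version in use, since v1 models typically need keep_untouched = (cached_property,) in Config, while v2 supports cached properties natively):

from functools import cached_property

class _BaseImageTextModel(BaseModel):  # BaseModel, Field and ImageTextModel as imported in this module
    model_name: str = Field(default="imagetext@001")

    @cached_property
    def client(self) -> ImageTextModel:
        # Built once on first access and memoised on the instance,
        # which would remove the need for the cached_client field.
        return ImageTextModel.from_pretrained(model_name=self.model_name)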

def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None:
"""Given a message part obtain a image if the part represents it.
@@ -83,26 +88,42 @@ def _user_agent(self) -> str:
_, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}")
return user_agent

@property
def _default_params(self) -> Dict[str, Any]:
return {"number_of_results": self.number_of_results, "language": self.language}

def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
params = self._default_params
for key, value in kwargs.items():
if key in params and value is not None:
params[key] = value
return params
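An illustrative trace of the helper above, assuming the field defaults number_of_results=1 and language="en":

model._prepare_params(language="de")
# -> {"number_of_results": 1, "language": "de"}
model._prepare_params(language=None, number_of_results=3)
# -> {"number_of_results": 3, "language": "en"}   # None never overrides a default
model._prepare_params(foo="bar")
# -> {"number_of_results": 1, "language": "en"}   # keys outside the defaults are ignored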


class _BaseVertexAIImageCaptioning(_BaseImageTextModel):
"""Base class for Image Captioning models."""

def _get_captions(self, image: Image) -> List[str]:
def _get_captions(
self,
image: Image,
number_of_results: Optional[int] = None,
language: Optional[str] = None,
) -> List[str]:
"""Uses the sdk methods to generate a list of captions.

Args:
image: Image to get the captions for.
number_of_results: Number of results to return from one query.
language: Language of the query.

Returns:
List of captions obtained from the image.
"""
with telemetry.tool_context_manager(self._user_agent):
model = self._create_model()
captions = model.get_captions(
image=image,
number_of_results=self.number_of_results,
language=self.language,
params = self._prepare_params(
number_of_results=number_of_results, language=language
)
captions = self.client.get_captions(image=image, **params)
return captions


@@ -130,11 +151,13 @@ def _generate(
Captions generated from every prompt.
"""

generations = [self._generate_one(prompt=prompt) for prompt in prompts]
generations = [
self._generate_one(prompt=prompt, **kwargs) for prompt in prompts
]

return LLMResult(generations=generations)

def _generate_one(self, prompt: str) -> List[Generation]:
def _generate_one(self, prompt: str, **kwargs) -> List[Generation]:
"""Generates the captions for a single prompt.

Args:
@@ -146,7 +169,7 @@ def _generate_one(self, prompt: str) -> List[Generation]:
image_loader = ImageBytesLoader(project=self.project)
image_bytes = image_loader.load_bytes(prompt)
image = Image(image_bytes=image_bytes)
caption_list = self._get_captions(image=image)
caption_list = self._get_captions(image=image, **kwargs)
return [Generation(text=caption) for caption in caption_list]


@@ -199,7 +222,7 @@ def _generate(
"{'type': 'image_url', 'image_url': {'image': <image_str>}}"
)

captions = self._get_captions(image)
captions = self._get_captions(image, **messages[0].additional_kwargs)

generations = [
ChatGeneration(message=AIMessage(content=caption)) for caption in captions
@@ -211,6 +234,10 @@ def _generate(
class VertexAIVisualQnAChat(_BaseImageTextModel, BaseChatModel):
"""Chat implementation of a visual QnA model"""

@property
def _default_params(self) -> Dict[str, Any]:
return {"number_of_results": self.number_of_results}

def _generate(
self,
messages: List[BaseMessage],
@@ -257,15 +284,19 @@ def _generate(
"or a dictionary with format {'type': 'text', 'text': <message>}"
)

answers = self._ask_questions(image=image, query=user_question)
answers = self._ask_questions(
image=image, query=user_question, **messages[0].additional_kwargs
)

generations = [
ChatGeneration(message=AIMessage(content=answer)) for answer in answers
]

return ChatResult(generations=generations)

def _ask_questions(self, image: Image, query: str) -> List[str]:
def _ask_questions(
self, image: Image, query: str, number_of_results: Optional[int] = None
) -> List[str]:
"""Interfaces with the sdk to get the question.

Args:
@@ -276,22 +307,21 @@ def _ask_questions(self, image: Image, query: str) -> List[str]:
List of responses to the query.
"""
with telemetry.tool_context_manager(self._user_agent):
model = self._create_model()
answers = model.ask_question(
image=image, question=query, number_of_results=self.number_of_results
)
params = self._prepare_params(number_of_results=number_of_results)
answers = self.client.ask_question(image=image, question=query, **params)
return answers
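A hedged usage sketch of the per-call override path for the QnA chat model: as the _generate change above shows, overrides travel on the first message's additional_kwargs. The content shape follows the formats quoted in this file's error messages, and base64_image is assumed to exist.

from langchain_core.messages import HumanMessage
from langchain_google_vertexai import VertexAIVisualQnAChat

model = VertexAIVisualQnAChat()
message = HumanMessage(
    content=[
        {"type": "image_url", "image_url": {"image": base64_image}},
        {"type": "text", "text": "What is shown in this image?"},
    ],
    additional_kwargs={"number_of_results": 2},
)
response = model.invoke([message])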


class _BaseVertexAIImageGenerator(BaseModel):
"""Base class form generation and edition of images."""

cached_client: Any = Field(default=None)
model_name: str = Field(default="imagegeneration@002")
"""Name of the base model"""
negative_prompt: Union[str, None] = Field(default=None)
"""A description of what you want to omit in
the generated images"""
number_of_images: int = Field(default=1)
number_of_results: int = Field(default=1)
"""Number of images to generate"""
guidance_scale: Union[float, None] = Field(default=None)
"""Controls the strength of the prompt"""
@@ -304,7 +334,34 @@ class _BaseVertexAIImageGenerator(BaseModel):
project: Union[str, None] = Field(default=None)
"""Google cloud project id"""

def _generate_images(self, prompt: str) -> List[str]:
@property
def client(self) -> ImageGenerationModel:
if not self.cached_client:
self.cached_client = ImageGenerationModel.from_pretrained(
model_name=self.model_name,
)
return self.cached_client

@property
def _default_params(self) -> Dict[str, Any]:
return {
"number_of_images": self.number_of_results,
"language": self.language,
"negative_prompt": self.negative_prompt,
"guidance_scale": self.guidance_scale,
"seed": self.seed,
}

def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
params = self._default_params
mapping = {"number_of_results": "number_of_images"}
for key, value in kwargs.items():
key = mapping.get(key, key)
if key in params and value is not None:
params[key] = value
return {k: v for k, v in params.items() if v is not None}
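An illustrative trace of the generator-side helper above, assuming number_of_results=1 and the other optional fields left at None. The mapping keeps the caller-facing name number_of_results consistent with the other vision models, while the SDK still receives number_of_images; None-valued entries are then stripped so only explicitly set parameters reach the SDK.

generator._prepare_params(number_of_results=3, seed=42)
# -> {"number_of_images": 3, "seed": 42}
generator._prepare_params()
# -> {"number_of_images": 1}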

def _generate_images(self, prompt: str, **kwargs: Any) -> List[str]:
"""Generates images given a prompt.

Args:
@@ -314,15 +371,8 @@ def _generate_images(self, prompt: str) -> List[str]:
List of b64 encoded strings.
"""
with telemetry.tool_context_manager(self._user_agent):
model = ImageGenerationModel.from_pretrained(self.model_name)

generation_result = model.generate_images(
prompt=prompt,
negative_prompt=self.negative_prompt,
number_of_images=self.number_of_images,
language=self.language,
guidance_scale=self.guidance_scale,
seed=self.seed,
generation_result = self.client.generate_images(
prompt=prompt, **self._prepare_params(**kwargs)
)

image_str_list = [
@@ -331,7 +381,7 @@ def _generate_images(self, prompt: str) -> List[str]:

return image_str_list

def _edit_images(self, image_str: str, prompt: str) -> List[str]:
def _edit_images(self, image_str: str, prompt: str, **kwargs: Any) -> List[str]:
"""Edit an image given a image and a prompt.

Args:
@@ -342,20 +392,11 @@ def _edit_images(self, image_str: str, prompt: str) -> List[str]:
List of b64 encoded strings.
"""
with telemetry.tool_context_manager(self._user_agent):
model = ImageGenerationModel.from_pretrained(self.model_name)

image_loader = ImageBytesLoader(project=self.project)
image_bytes = image_loader.load_bytes(image_str)
image = Image(image_bytes=image_bytes)

generation_result = model.edit_image(
prompt=prompt,
base_image=image,
negative_prompt=self.negative_prompt,
number_of_images=self.number_of_images,
language=self.language,
guidance_scale=self.guidance_scale,
seed=self.seed,
generation_result = self.client.edit_image(
prompt=prompt, base_image=image, **self._prepare_params(**kwargs)
)

image_str_list = [
@@ -427,7 +468,9 @@ def _generate(
" Must The prompt of the image"
)

image_str_list = self._generate_images(prompt=user_query)
image_str_list = self._generate_images(
prompt=user_query, **messages[0].additional_kwargs
)
image_content_part_list = [
create_image_content_part(image_str=image_str)
for image_str in image_str_list
@@ -474,7 +517,9 @@ def _generate(
"two parts: First the image and then the user prompt."
)

image_str_list = self._edit_images(image_str=image_str, prompt=user_query)
image_str_list = self._edit_images(
image_str=image_str, prompt=user_query, **messages[0].additional_kwargs
)
image_content_part_list = [
create_image_content_part(image_str=image_str)
for image_str in image_str_list
3 changes: 3 additions & 0 deletions libs/vertexai/tests/integration_tests/test_vision_models.py
@@ -50,6 +50,9 @@ def test_vertex_ai_image_captioning(base64_image: str):
response = model.invoke(base64_image)
assert isinstance(response, str)

response = model.invoke(base64_image, language="de")
assert isinstance(response, str)


@pytest.mark.release
def test_vertex_ai_visual_qna_chat(base64_image: str):