Skip to content

Commit

Permalink
added support for JSON mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vishah02 committed Nov 28, 2024
1 parent 9e90ad2 commit e26010b
Showing 1 changed file with 83 additions and 1 deletion.
84 changes: 83 additions & 1 deletion libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from google.generativeai.caching import CachedContent # type: ignore[import]
from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
from google.generativeai.types import caching_types, content_types
from google.generativeai.types import generation_types # type: ignore[import]
from google.generativeai.types.content_types import ( # type: ignore[import]
FunctionDeclarationType,
ToolDict,
Expand Down Expand Up @@ -114,7 +115,7 @@
from langchain_google_genai._image_utils import ImageBytesLoader
from langchain_google_genai.llms import _BaseGoogleGenerativeAI

from . import _genai_extension as genaix
from langchain_google_genai import _genai_extension as genaix

IMAGE_TYPES: Tuple = ()
try:
Expand Down Expand Up @@ -833,6 +834,52 @@ class Joke(BaseModel):
'finish_reason': 'STOP',
'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]
}
JSON Output and Schema Support:
.. code-block:: python
import json

import typing_extensions as typing
class Recipe(typing.TypedDict):
recipe_name: str
ingredients: list[str]
rating: float
llm = ChatGoogleGenerativeAI(
model="gemini-1.5-pro",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
response_mime_type="application/json",
response_schema=list[Recipe]
)
# Define the messages
messages = [
("system", "You are a helpful assistant"),
("human", "List 2 recipes for a healthy breakfast"),
]
response = llm.invoke(messages)
print(json.dumps(json.loads(response.content), indent=4))
.. code-block:: python
[
{
"ingredients": [
"1/2 cup rolled oats",
"1 cup unsweetened almond milk",
"1/4 cup berries",
"1 tablespoon chia seeds",
"1/2 teaspoon cinnamon"
],
"rating": 4.5,
"recipe_name": "Overnight Oats"
},
]
""" # noqa: E501

Expand All @@ -853,6 +900,22 @@ class Joke(BaseModel):
Gemini does not support system messages; any unsupported messages will
raise an error."""

response_mime_type: Optional[str] = None
"""Optional. Output response mimetype of the generated candidate text. Only
supported in Gemini 1.5 and later models. Supported mimetype:
* "text/plain": (default) Text output.
* "application/json": JSON response in the candidates.
* "text/x.enum": Enum in plain text.
"""

response_schema: Optional[Any] = None
""" Optional. Enforce an schema to the output.
The value of response_schema must be a either:
* A type hint annotation, as defined in the Python typing module module.
* An instance of genai.protos.Schema.
* An enum class
"""

cached_content: Optional[str] = None
"""The name of the cached content used as context to serve the prediction.
Expand Down Expand Up @@ -892,6 +955,20 @@ def validate_environment(self) -> Self:
if not self.model.startswith("models/"):
self.model = f"models/{self.model}"

if self.response_mime_type is not None and self.response_mime_type not in [
"text/plain", "application/json", "text/x.enum"]:
raise ValueError(
    "response_mime_type must be either 'text/plain', "
    "'application/json', or 'text/x.enum'"
)

if self.response_schema is not None:
if self.response_mime_type not in ["application/json", "text/x.enum"]:
raise ValueError(
    "response_schema is only supported when response_mime_type is "
    "'application/json' or 'text/x.enum'"
)

additional_headers = self.additional_headers or {}
self.default_metadata = tuple(additional_headers.items())
client_info = get_client_info("ChatGoogleGenerativeAI")
Expand Down Expand Up @@ -977,9 +1054,14 @@ def _prepare_params(
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"response_mime_type": self.response_mime_type,
"response_schema": self.response_schema,
}.items()
if v is not None
}
if gen_config.get("response_schema"):
generation_types._normalize_schema(gen_config)

if generation_config:
gen_config = {**gen_config, **generation_config}
return GenerationConfig(**gen_config)
Expand Down

0 comments on commit e26010b

Please sign in to comment.