-
Notifications
You must be signed in to change notification settings - Fork 16.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
core[minor]: support pydantic v2 models in PydanticOutputParser #18811
Changes from 4 commits
b1af8b5
20df86e
c2b3211
90e25ea
d870dc2
c7ea1b6
4fdc8dc
de98790
748392a
6bff30a
de798dc
7121c93
7981d34
2174c54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import json | ||
from typing import Generic, List, Type, TypeVar | ||
|
||
from langchain_core.exceptions import OutputParserException | ||
from langchain_core.output_parsers.json import JsonOutputParser | ||
from langchain_core.output_parsers.pydantic import _PYDANTIC_FORMAT_INSTRUCTIONS | ||
from langchain_core.outputs.generation import Generation | ||
from langchain_core.pydantic_v2 import BaseModel, ValidationError | ||
|
||
TBaseModel = TypeVar("TBaseModel", bound=BaseModel) | ||
|
||
|
||
class PydanticV2OutputParser(JsonOutputParser, Generic[TBaseModel]): | ||
"""Parse an output using a pydantic model.""" | ||
|
||
pydantic_v2_object: Type[TBaseModel] | ||
"""The pydantic model to parse.""" | ||
|
||
def parse_result( | ||
self, result: List[Generation], *, partial: bool = False | ||
) -> TBaseModel: | ||
json_object = super().parse_result(result) | ||
try: | ||
return self.pydantic_v2_object.model_validate(json_object) | ||
except ValidationError as e: | ||
name = self.pydantic_v2_object.__name__ | ||
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}" | ||
raise OutputParserException(msg, llm_output=json_object) | ||
|
||
def parse(self, text: str) -> TBaseModel: | ||
return super().parse(text) | ||
|
||
def get_format_instructions(self) -> str: | ||
# Copy schema to avoid altering original Pydantic schema. | ||
schema = {k: v for k, v in self.pydantic_v2_object.model_json_schema().items()} | ||
|
||
# Remove extraneous fields. | ||
reduced_schema = schema | ||
if "title" in reduced_schema: | ||
del reduced_schema["title"] | ||
if "type" in reduced_schema: | ||
del reduced_schema["type"] | ||
# Ensure json in context is well-formed with double quotes. | ||
schema_str = json.dumps(reduced_schema) | ||
|
||
return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str) | ||
|
||
@property | ||
def _type(self) -> str: | ||
return "pydantic" | ||
|
||
@property | ||
def OutputType(self) -> Type[TBaseModel]: | ||
"""Return the pydantic model.""" | ||
return self.pydantic_v2_object |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import warnings | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not convinced that the namespace is needed yet -- could you review the code snippets -- i think some form of either those versions or else local scope imports will work |
||
|
||
from langchain_core.utils.pydantic import get_pydantic_major_version | ||
|
||
## Create namespaces for pydantic v1 and v2. | ||
# This code must stay at the top of the file before other modules may | ||
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules. | ||
# | ||
# This hack is done for the following reasons: | ||
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since | ||
# both dependencies and dependents may be stuck on either version of v1 or v2. | ||
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that | ||
# unambiguously uses either v1 or v2 API. | ||
# * This change is easier to roll out and roll back. | ||
|
||
|
||
if get_pydantic_major_version() < 2: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should be no side effects on import since the import is triggered by default rather than user action |
||
warnings.warn( | ||
"The pydantic_v2 namespace only supports Pydantic v2 and later. \ | ||
Please use pydantic_v1 namespace.", | ||
ImportWarning, | ||
) | ||
|
||
from pydantic import * # noqa: F403 # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bit worried about getting this into the code base. Definitely into
core
as it will create duplicated code incore
when it upgrades to pydantic 2Not sure, but this could potentially create issues when using in a chain or a runnable map together since it'll cause mixing of pydantic v1 and v2 code.