-
Notifications
You must be signed in to change notification settings - Fork 16.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
core[minor]: support pydantic v2 models in PydanticOutputParser (#18811)
As mentioned in #18322, the current PydanticOutputParser won't work for anyone trying to parse to pydantic v2 models. This PR adds a separate `PydanticV2OutputParser`, as well as a `langchain_core.pydantic_v2` namespace that will fail on import to any projects using pydantic<2. Happy to update the docs for output parsers if this is something we're interesting in adding. On a separate note, I also updated `check_pydantic.sh` to detect pydantic imports with leading whitespace and excluded the internal namespaces. That change can be separated into its own PR if needed. --------- Co-authored-by: Jan Nissen <[email protected]>
- Loading branch information
Showing
3 changed files
with
122 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from typing import Literal | ||
|
||
import pydantic # pydantic: ignore | ||
import pytest | ||
|
||
from langchain_core.exceptions import OutputParserException | ||
from langchain_core.language_models import ParrotFakeChatModel | ||
from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel | ||
from langchain_core.prompts.prompt import PromptTemplate | ||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION | ||
|
||
V1BaseModel = pydantic.BaseModel | ||
if PYDANTIC_MAJOR_VERSION == 2: | ||
from pydantic.v1 import BaseModel # pydantic: ignore | ||
|
||
V1BaseModel = BaseModel # type: ignore | ||
|
||
|
||
class ForecastV2(pydantic.BaseModel): | ||
temperature: int | ||
f_or_c: Literal["F", "C"] | ||
forecast: str | ||
|
||
|
||
class ForecastV1(V1BaseModel): | ||
temperature: int | ||
f_or_c: Literal["F", "C"] | ||
forecast: str | ||
|
||
|
||
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1]) | ||
def test_pydantic_parser_chaining( | ||
pydantic_object: TBaseModel, | ||
) -> None: | ||
prompt = PromptTemplate( | ||
template="""{{ | ||
"temperature": 20, | ||
"f_or_c": "C", | ||
"forecast": "Sunny" | ||
}}""", | ||
input_variables=[], | ||
) | ||
|
||
model = ParrotFakeChatModel() | ||
|
||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore | ||
chain = prompt | model | parser | ||
|
||
res = chain.invoke({}) | ||
assert type(res) == pydantic_object | ||
assert res.f_or_c == "C" | ||
assert res.temperature == 20 | ||
assert res.forecast == "Sunny" | ||
|
||
|
||
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1]) | ||
def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None: | ||
bad_prompt = PromptTemplate( | ||
template="""{{ | ||
"temperature": "oof", | ||
"f_or_c": 1, | ||
"forecast": "Sunny" | ||
}}""", | ||
input_variables=[], | ||
) | ||
|
||
model = ParrotFakeChatModel() | ||
|
||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore | ||
chain = bad_prompt | model | parser | ||
with pytest.raises(OutputParserException): | ||
chain.invoke({}) |