Skip to content

Commit

Permalink
community[patch]: Update root_validators to use pre=True or pre=False (
Browse files Browse the repository at this point in the history
…#23731)

Update root_validators in preparation for pydantic 2 migration.
  • Loading branch information
eyurtsev authored Jul 1, 2024
1 parent 6019147 commit 5d2262a
Show file tree
Hide file tree
Showing 9 changed files with 13 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def execute_agent(agent, tools, input):
as_agent: bool = False
"""Use as a LangChain agent, compatible with the AgentExecutor."""

@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_async_client(cls, values: dict) -> dict:
if values["async_client"] is None:
import openai
Expand Down
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chains/llm_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def output_keys(self) -> List[str]:
"""
return [self.output_key]

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
Expand Down
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_available_models(

return {model["id"] for model in models_response.json()["data"]}

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
values["anyscale_api_key"] = convert_to_secret_str(
Expand Down
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/coze.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class Config:

allow_population_by_field_name = True

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["coze_api_base"] = get_from_dict_or_env(
values,
Expand Down
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/dappier.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class Config:

extra = Extra.forbid

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["dappier_api_key"] = convert_to_secret_str(
Expand Down
7 changes: 5 additions & 2 deletions libs/community/langchain_community/chat_models/deepinfra.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,8 @@ async def _completion_with_retry(**kwargs: Any) -> Any:

return await _completion_with_retry(**kwargs)

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@root_validator(pre=True)
def init_defaults(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k."""
# For compatibility with LiteLLM
api_key = get_from_dict_or_env(
Expand All @@ -294,7 +294,10 @@ def validate_environment(cls, values: Dict) -> Dict:
"DEEPINFRA_API_TOKEN",
default=api_key,
)
return values

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")

Expand Down
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/ernie.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class ErnieBotChat(BaseChatModel):

_lock = threading.Lock()

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["ernie_api_base"] = get_from_dict_or_env(
values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "fireworks"]

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key in environment."""
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, **kwargs: Any):
else self.tokenizer
)

@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_llm(cls, values: dict) -> dict:
if not isinstance(
values["llm"],
Expand Down

0 comments on commit 5d2262a

Please sign in to comment.