Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: fix yuan2 errors in LLMs #19004

Merged
merged 9 commits into from
Mar 28, 2024
19 changes: 16 additions & 3 deletions libs/community/langchain_community/llms/yuan2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Yuan2(LLM):
top_p: Optional[float] = 0.9
"""The top-p value to use for sampling."""

top_k: Optional[int] = 40
top_k: Optional[int] = 0
"""The top-k value to use for sampling."""

do_sample: bool = False
Expand Down Expand Up @@ -70,6 +70,17 @@ class Yuan2(LLM):
use_history: bool = False
"""Whether to use history or not"""

def __init__(self, **kwargs: Any) -> None:
"""Initialize the Yuan2 class."""
super().__init__(**kwargs)

if (self.top_p or 0) > 0 and (self.top_k or 0) > 0:
logger.warning(
"top_p and top_k cannot be set simultaneously. "
"set top_k to 0 instead..."
)
self.top_k = 0

@property
def _llm_type(self) -> str:
return "Yuan2.0"
Expand All @@ -86,12 +97,13 @@ def _model_param_names() -> Set[str]:

def _default_params(self) -> Dict[str, Any]:
return {
"do_sample": self.do_sample,
"infer_api": self.infer_api,
"max_tokens": self.max_tokens,
"repeat_penalty": self.repeat_penalty,
"temp": self.temp,
"top_k": self.top_k,
"top_p": self.top_p,
"do_sample": self.do_sample,
"use_history": self.use_history,
}

Expand Down Expand Up @@ -135,6 +147,7 @@ def _call(
input = prompt

headers = {"Content-Type": "application/json"}

data = json.dumps(
{
"ques_list": [{"id": "000", "ques": input}],
Expand Down Expand Up @@ -164,7 +177,7 @@ def _call(
if resp["errCode"] != "0":
raise ValueError(
f"Failed with error code [{resp['errCode']}], "
f"error message: [{resp['errMessage']}]"
f"error message: [{resp['exceptionMsg']}]"
)

if "resData" in resp:
Expand Down
2 changes: 0 additions & 2 deletions libs/community/tests/integration_tests/llms/test_yuan2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def test_yuan2_call_method() -> None:
max_tokens=1024,
temp=1.0,
top_p=0.9,
top_k=40,
use_history=False,
)
output = llm("写一段快速排序算法。")
Expand All @@ -25,7 +24,6 @@ def test_yuan2_generate_method() -> None:
max_tokens=1024,
temp=1.0,
top_p=0.9,
top_k=40,
use_history=False,
)
output = llm.generate(["who are you?"])
Expand Down
Loading