Skip to content

Commit

Permalink
community[patch]: sambanova llm integration improvement (#23137)
Browse files Browse the repository at this point in the history
- **Description:** sambanova sambaverse integration improvement: removed
the input parsing that was altering the raw user input and that made
setting the `process_prompt` parameter to true mandatory
  • Loading branch information
jhpiedrahitao authored Jun 19, 2024
1 parent e162893 commit b3e53ff
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 40 deletions.
42 changes: 38 additions & 4 deletions docs/docs/integrations/llms/sambanova.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
" \"do_sample\": True,\n",
" \"max_tokens_to_generate\": 1000,\n",
" \"temperature\": 0.01,\n",
" \"process_prompt\": True,\n",
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
" # \"repetition_penalty\": 1.0,\n",
Expand Down Expand Up @@ -116,7 +115,6 @@
" \"do_sample\": True,\n",
" \"max_tokens_to_generate\": 1000,\n",
" \"temperature\": 0.01,\n",
" \"process_prompt\": True,\n",
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
" # \"repetition_penalty\": 1.0,\n",
Expand Down Expand Up @@ -177,14 +175,16 @@
"import os\n",
"\n",
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
"# sambastudio_base_uri = \"<Your SambaStudio endpoint base URI>\" # optional, \"api/predict/nlp\" set as default\n",
"sambastudio_base_uri = (\n",
" \"<Your SambaStudio endpoint base URI>\" # optional, \"api/predict/nlp\" set as default\n",
")\n",
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
"\n",
"# Set the environment variables\n",
"os.environ[\"SAMBASTUDIO_BASE_URL\"] = sambastudio_base_url\n",
"# os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
"os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
"os.environ[\"SAMBASTUDIO_PROJECT_ID\"] = sambastudio_project_id\n",
"os.environ[\"SAMBASTUDIO_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
"os.environ[\"SAMBASTUDIO_API_KEY\"] = sambastudio_api_key"
Expand Down Expand Up @@ -247,6 +247,40 @@
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
" print(chunk, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also call a CoE endpoint expert model "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Using a CoE endpoint\n",
"\n",
"from langchain_community.llms.sambanova import SambaStudio\n",
"\n",
"llm = SambaStudio(\n",
" streaming=False,\n",
" model_kwargs={\n",
" \"do_sample\": True,\n",
" \"max_tokens_to_generate\": 1000,\n",
" \"temperature\": 0.01,\n",
" \"select_expert\": \"Meta-Llama-3-8B-Instruct\",\n",
" # \"repetition_penalty\": 1.0,\n",
" # \"top_k\": 50,\n",
" # \"top_logprobs\": 0,\n",
" # \"top_p\": 1.0\n",
" },\n",
")\n",
"\n",
"print(llm.invoke(\"Why should I use open source models?\"))"
]
}
],
"metadata": {
Expand Down
50 changes: 14 additions & 36 deletions libs/community/langchain_community/llms/sambanova.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _process_response(response: requests.Response) -> Dict:
:param requests.Response response: the response object to process
:return: the response dict
:rtype: dict
:type: dict
"""
result: Dict[str, Any] = {}
try:
Expand Down Expand Up @@ -87,7 +87,7 @@ def _get_full_url(self) -> str:
"""
Return the full API URL for a given path.
:returns: the full API URL for the sub-path
:rtype: str
:type: str
"""
return f"{self.host_url}{self.API_BASE_PATH}"

Expand All @@ -108,23 +108,12 @@ def nlp_predict(
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
"""
parsed_element = {
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": 0,
"role": "user",
"content": input,
}
],
}
parsed_input = json.dumps(parsed_element)
:type: dict
"""
if params:
data = {"instance": parsed_input, "params": json.loads(params)}
data = {"instance": input, "params": json.loads(params)}
else:
data = {"instance": parsed_input}
data = {"instance": input}
response = self.http_session.post(
self._get_full_url(),
headers={
Expand Down Expand Up @@ -152,23 +141,12 @@ def nlp_predict_stream(
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
"""
parsed_element = {
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": 0,
"role": "user",
"content": input,
}
],
}
parsed_input = json.dumps(parsed_element)
:type: dict
"""
if params:
data = {"instance": parsed_input, "params": json.loads(params)}
data = {"instance": input, "params": json.loads(params)}
else:
data = {"instance": parsed_input}
data = {"instance": input}
# Streaming output
response = self.http_session.post(
self._get_full_url(),
Expand Down Expand Up @@ -522,7 +500,7 @@ def _process_response(self, response: requests.Response) -> Dict:
:param requests.Response response: the response object to process
:return: the response dict
:rtype: dict
:type: dict
"""
result: Dict[str, Any] = {}
try:
Expand Down Expand Up @@ -581,7 +559,7 @@ def _get_full_url(self, path: str) -> str:
:param str path: the sub-path
:returns: the full API URL for the sub-path
:rtype: str
:type: str
"""
return f"{self.host_url}/{self.api_base_uri}/{path}"

Expand All @@ -603,7 +581,7 @@ def nlp_predict(
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
:type: dict
"""
if isinstance(input, str):
input = [input]
Expand Down Expand Up @@ -645,7 +623,7 @@ def nlp_predict_stream(
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
:type: dict
"""
if "nlp" in self.api_base_uri:
if isinstance(input, str):
Expand Down

0 comments on commit b3e53ff

Please sign in to comment.