forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: add support for Moonshot llm and chat model (langch…
- Loading branch information
Showing
5 changed files
with
380 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"---\n", | ||
"sidebar_label: Moonshot\n", | ||
"---" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"source": [ | ||
"# MoonshotChat\n", | ||
"\n", | ||
"[Moonshot](https://platform.moonshot.cn/) is a Chinese startup that provides LLM service for companies and individuals.\n", | ||
"\n", | ||
"This example goes over how to use LangChain to interact with Moonshot Inference for Chat." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"\n", | ||
"# Generate your api key from: https://platform.moonshot.cn/console/api-keys\n", | ||
"os.environ[\"MOONSHOT_API_KEY\"] = \"MOONSHOT_API_KEY\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.chat_models.moonshot import MoonshotChat\n", | ||
"from langchain_core.messages import HumanMessage, SystemMessage" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"chat = MoonshotChat()\n", | ||
"# or use a specific model\n", | ||
"# Available models: https://platform.moonshot.cn/docs\n", | ||
"# chat = MoonshotChat(model=\"moonshot-v1-128k\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"messages = [\n", | ||
" SystemMessage(\n", | ||
" content=\"You are a helpful assistant that translates English to French.\"\n", | ||
" ),\n", | ||
" HumanMessage(\n", | ||
" content=\"Translate this sentence from English to French. I love programming.\"\n", | ||
" ),\n", | ||
"]\n", | ||
"\n", | ||
"chat.invoke(messages)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# MoonshotChat\n", | ||
"\n", | ||
"[Moonshot](https://platform.moonshot.cn/) is a Chinese startup that provides LLM service for companies and individuals.\n", | ||
"\n", | ||
"This example goes over how to use LangChain to interact with Moonshot." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 33, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.llms.moonshot import Moonshot" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"\n", | ||
"# Generate your api key from: https://platform.moonshot.cn/console/api-keys\n", | ||
"os.environ[\"MOONSHOT_API_KEY\"] = \"MOONSHOT_API_KEY\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 34, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"llm = Moonshot()\n", | ||
"# or use a specific model\n", | ||
"# Available models: https://platform.moonshot.cn/docs\n", | ||
"# llm = Moonshot(model=\"moonshot-v1-128k\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": true | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Prompt the model\n", | ||
"llm.invoke(\"What is the difference between panda and bear?\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.4" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
58 changes: 58 additions & 0 deletions
58
libs/community/langchain_community/chat_models/moonshot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
"""Wrapper around Moonshot chat models.""" | ||
from typing import Dict | ||
|
||
from langchain_core.pydantic_v1 import root_validator | ||
from langchain_core.utils import get_from_dict_or_env | ||
|
||
from langchain_community.chat_models import ChatOpenAI | ||
from langchain_community.llms.moonshot import MOONSHOT_SERVICE_URL_BASE, MoonshotCommon | ||
|
||
|
||
class MoonshotChat(MoonshotCommon, ChatOpenAI): | ||
"""Wrapper around Moonshot large language models. | ||
To use, you should have the ``openai`` python package installed, and the | ||
environment variable ``MOONSHOT_API_KEY`` set with your API key. | ||
(Moonshot's chat API is compatible with OpenAI's SDK.) | ||
Referenced from https://platform.moonshot.cn/docs | ||
Example: | ||
.. code-block:: python | ||
from langchain_community.chat_models.moonshot import MoonshotChat | ||
moonshot = MoonshotChat(model="moonshot-v1-8k") | ||
""" | ||
|
||
@root_validator() | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validate that the environment is set up correctly.""" | ||
values["moonshot_api_key"] = get_from_dict_or_env( | ||
values, "moonshot_api_key", "MOONSHOT_API_KEY" | ||
) | ||
|
||
try: | ||
import openai | ||
|
||
except ImportError: | ||
raise ImportError( | ||
"Could not import openai python package. " | ||
"Please install it with `pip install openai`." | ||
) | ||
|
||
client_params = { | ||
"api_key": values["moonshot_api_key"], | ||
"base_url": values["base_url"] | ||
if "base_url" in values | ||
else MOONSHOT_SERVICE_URL_BASE, | ||
} | ||
|
||
if not values.get("client"): | ||
values["client"] = openai.OpenAI(**client_params).chat.completions | ||
if not values.get("async_client"): | ||
values["async_client"] = openai.AsyncOpenAI( | ||
**client_params | ||
).chat.completions | ||
|
||
return values |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from typing import Any, Dict, List, Optional | ||
|
||
import requests | ||
from langchain_core.callbacks import CallbackManagerForLLMRun | ||
from langchain_core.language_models import LLM | ||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator | ||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env | ||
|
||
from langchain_community.llms.utils import enforce_stop_tokens | ||
|
||
MOONSHOT_SERVICE_URL_BASE = "https://api.moonshot.cn/v1" | ||
|
||
|
||
class _MoonshotClient(BaseModel): | ||
"""An API client that talks to the Moonshot server.""" | ||
|
||
api_key: SecretStr | ||
"""The API key to use for authentication.""" | ||
base_url: str = MOONSHOT_SERVICE_URL_BASE | ||
|
||
def completion(self, request: Any) -> Any: | ||
headers = {"Authorization": f"Bearer {self.api_key.get_secret_value()}"} | ||
response = requests.post( | ||
f"{self.base_url}/chat/completions", | ||
headers=headers, | ||
json=request, | ||
) | ||
if not response.ok: | ||
raise ValueError(f"HTTP {response.status_code} error: {response.text}") | ||
return response.json()["choices"][0]["message"]["content"] | ||
|
||
|
||
class MoonshotCommon(BaseModel): | ||
_client: _MoonshotClient | ||
base_url: str = MOONSHOT_SERVICE_URL_BASE | ||
moonshot_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") | ||
"""Moonshot API key. Get it here: https://platform.moonshot.cn/console/api-keys""" | ||
model_name: str = Field(default="moonshot-v1-8k", alias="model") | ||
"""Model name. Available models listed here: https://platform.moonshot.cn/pricing""" | ||
max_tokens = 1024 | ||
"""Maximum number of tokens to generate.""" | ||
temperature = 0.3 | ||
"""Temperature parameter (higher values make the model more creative).""" | ||
|
||
class Config: | ||
"""Configuration for this pydantic object.""" | ||
|
||
allow_population_by_field_name = True | ||
|
||
@property | ||
def lc_secrets(self) -> dict: | ||
"""A map of constructor argument names to secret ids. | ||
For example, | ||
{"moonshot_api_key": "MOONSHOT_API_KEY"} | ||
""" | ||
return {"moonshot_api_key": "MOONSHOT_API_KEY"} | ||
|
||
@property | ||
def _default_params(self) -> Dict[str, Any]: | ||
"""Get the default parameters for calling OpenAI API.""" | ||
return { | ||
"model": self.model_name, | ||
"max_tokens": self.max_tokens, | ||
"temperature": self.temperature, | ||
} | ||
|
||
@property | ||
def _invocation_params(self) -> Dict[str, Any]: | ||
return {**{"model": self.model_name}, **self._default_params} | ||
|
||
@root_validator(pre=True) | ||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: | ||
"""Build extra parameters. | ||
Override the superclass method, prevent the model parameter from being | ||
overridden. | ||
""" | ||
return values | ||
|
||
@root_validator() | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validate that api key and python package exists in environment.""" | ||
values["moonshot_api_key"] = convert_to_secret_str( | ||
get_from_dict_or_env(values, "moonshot_api_key", "MOONSHOT_API_KEY") | ||
) | ||
|
||
values["_client"] = _MoonshotClient( | ||
api_key=values["moonshot_api_key"], | ||
base_url=values["base_url"] | ||
if "base_url" in values | ||
else MOONSHOT_SERVICE_URL_BASE, | ||
) | ||
return values | ||
|
||
@property | ||
def _llm_type(self) -> str: | ||
"""Return type of llm.""" | ||
return "moonshot" | ||
|
||
|
||
class Moonshot(MoonshotCommon, LLM): | ||
"""Moonshot large language models. | ||
To use, you should have the environment variable ``MOONSHOT_API_KEY`` set with your | ||
API key. Referenced from https://platform.moonshot.cn/docs | ||
Example: | ||
.. code-block:: python | ||
from langchain_community.llms.moonshot import Moonshot | ||
moonshot = Moonshot(model="moonshot-v1-8k") | ||
""" | ||
|
||
class Config: | ||
"""Configuration for this pydantic object.""" | ||
|
||
allow_population_by_field_name = True | ||
|
||
def _call( | ||
self, | ||
prompt: str, | ||
stop: Optional[List[str]] = None, | ||
run_manager: Optional[CallbackManagerForLLMRun] = None, | ||
**kwargs: Any, | ||
) -> str: | ||
request = self._invocation_params | ||
request["messages"] = [{"role": "user", "content": prompt}] | ||
request.update(kwargs) | ||
text = self._client.completion(request) | ||
if stop is not None: | ||
# This is required since the stop tokens | ||
# are not enforced by the model parameters | ||
text = enforce_stop_tokens(text, stop) | ||
|
||
return text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
from langchain_community.llms.moonshot import Moonshot | ||
|
||
os.environ["MOONSHOT_API_KEY"] = "key" | ||
|
||
|
||
@pytest.mark.requires("openai") | ||
def test_moonshot_model_param() -> None: | ||
llm = Moonshot(model="foo") | ||
assert llm.model_name == "foo" | ||
llm = Moonshot(model_name="bar") | ||
assert llm.model_name == "bar" |