-
Notifications
You must be signed in to change notification settings - Fork 16.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: Dappier chat model integration (#19370)
**Description:** This PR adds [Dappier](https://dappier.com/) for the chat model. It supports generate, async generate, and batch functionalities. We added unit and integration tests as well as a notebook with more details about our chat model. **Dependencies:** No extra dependencies are needed.
- Loading branch information
1 parent
64e1df3
commit 743f888
Showing
4 changed files
with
408 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Dappier AI" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**Dappier: Powering AI with Dynamic, Real-Time Data Models**\n", | ||
"\n", | ||
"Dappier offers a cutting-edge platform that grants developers immediate access to a wide array of real-time data models spanning news, entertainment, finance, market data, weather, and beyond. With our pre-trained data models, you can supercharge your AI applications, ensuring they deliver precise, up-to-date responses and minimize inaccuracies.\n", | ||
"\n", | ||
"Dappier data models help you build next-gen LLM apps with trusted, up-to-date content from the world's leading brands. Unleash your creativity and enhance any GPT App or AI workflow with actionable, proprietary, data through a simple API. Augment your AI with proprietary data from trusted sources is the best way to ensure factual, up-to-date, responses with fewer hallucinations no matter the question.\n", | ||
"\n", | ||
"For Developers, By Developers\n", | ||
"Designed with developers in mind, Dappier simplifies the journey from data integration to monetization, providing clear, straightforward paths to deploy and earn from your AI models. Experience the future of monetization infrastructure for the new internet at **https://dappier.com/**." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"This example goes over how to use LangChain to interact with Dappier AI models\n", | ||
"\n", | ||
"-----------------------------------------------------------------------------------" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To use one of our Dappier AI Data Models, you will need an API key. Please visit Dappier Platform (https://platform.dappier.com/) to log in and create an API key in your profile.\n", | ||
"\n", | ||
"\n", | ||
"You can find more details on the API reference : https://docs.dappier.com/introduction" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To work with our Dappier Chat Model you can pass the key directly through the parameter named dappier_api_key when initiating the class\n", | ||
"or set as an environment variable.\n", | ||
"\n", | ||
"```bash\n", | ||
"export DAPPIER_API_KEY=\"...\"\n", | ||
"```\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.chat_models.dappier import ChatDappierAI\n", | ||
"from langchain_core.messages import HumanMessage" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"chat = ChatDappierAI(\n", | ||
" dappier_endpoint=\"https://api.dappier.com/app/datamodelconversation\",\n", | ||
" dappier_model=\"dm_01hpsxyfm2fwdt2zet9cg6fdxt\",\n", | ||
" dappier_api_key=\"...\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"AIMessage(content='Hey there! The Kansas City Chiefs won Super Bowl LVIII in 2024. They beat the San Francisco 49ers in overtime with a final score of 25-22. It was quite the game! 🏈')" | ||
] | ||
}, | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"messages = [HumanMessage(content=\"Who won the super bowl in 2024?\")]\n", | ||
"chat.invoke(messages)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"AIMessage(content='The Kansas City Chiefs won Super Bowl LVIII in 2024! 🏈')" | ||
] | ||
}, | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"await chat.ainvoke(messages)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.15" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
161 changes: 161 additions & 0 deletions
161
libs/community/langchain_community/chat_models/dappier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from aiohttp import ClientSession | ||
from langchain_core.callbacks import ( | ||
AsyncCallbackManagerForLLMRun, | ||
CallbackManagerForLLMRun, | ||
) | ||
from langchain_core.language_models.chat_models import ( | ||
BaseChatModel, | ||
) | ||
from langchain_core.messages import ( | ||
AIMessage, | ||
BaseMessage, | ||
) | ||
from langchain_core.outputs import ChatGeneration, ChatResult | ||
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator | ||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env | ||
|
||
from langchain_community.utilities.requests import Requests | ||
|
||
|
||
def _format_dappier_messages( | ||
messages: List[BaseMessage], | ||
) -> List[Dict[str, Union[str, List[Union[str, Dict[Any, Any]]]]]]: | ||
formatted_messages = [] | ||
|
||
for message in messages: | ||
if message.type == "human": | ||
formatted_messages.append({"role": "user", "content": message.content}) | ||
elif message.type == "system": | ||
formatted_messages.append({"role": "system", "content": message.content}) | ||
|
||
return formatted_messages | ||
|
||
|
||
class ChatDappierAI(BaseChatModel): | ||
"""`Dappier` chat large language models. | ||
`Dappier` is a platform enabling access to diverse, real-time data models. | ||
Enhance your AI applications with Dappier's pre-trained, LLM-ready data models | ||
and ensure accurate, current responses with reduced inaccuracies. | ||
To use one of our Dappier AI Data Models, you will need an API key. | ||
Please visit Dappier Platform (https://platform.dappier.com/) to log in | ||
and create an API key in your profile. | ||
Example: | ||
.. code-block:: python | ||
from langchain_community.chat_models import ChatDappierAI | ||
from langchain_core.messages import HumanMessage | ||
# Initialize `ChatDappierAI` with the desired configuration | ||
chat = ChatDappierAI( | ||
dappier_endpoint="https://api.dappier.com/app/datamodel/dm_01hpsxyfm2fwdt2zet9cg6fdxt", | ||
dappier_api_key="<YOUR_KEY>") | ||
# Create a list of messages to interact with the model | ||
messages = [HumanMessage(content="hello")] | ||
# Invoke the model with the provided messages | ||
chat.invoke(messages) | ||
you can find more details here : https://docs.dappier.com/introduction""" | ||
|
||
dappier_endpoint: str = "https://api.dappier.com/app/datamodelconversation" | ||
|
||
dappier_model: str = "dm_01hpsxyfm2fwdt2zet9cg6fdxt" | ||
|
||
dappier_api_key: Optional[SecretStr] = Field(None, description="Dappier API Token") | ||
|
||
class Config: | ||
"""Configuration for this pydantic object.""" | ||
|
||
extra = Extra.forbid | ||
|
||
@root_validator() | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validate that api key exists in environment.""" | ||
values["dappier_api_key"] = convert_to_secret_str( | ||
get_from_dict_or_env(values, "dappier_api_key", "DAPPIER_API_KEY") | ||
) | ||
return values | ||
|
||
@staticmethod | ||
def get_user_agent() -> str: | ||
from langchain_community import __version__ | ||
|
||
return f"langchain/{__version__}" | ||
|
||
@property | ||
def _llm_type(self) -> str: | ||
"""Return type of chat model.""" | ||
return "dappier-realtimesearch-chat" | ||
|
||
@property | ||
def _api_key(self) -> str: | ||
if self.dappier_api_key: | ||
return self.dappier_api_key.get_secret_value() | ||
return "" | ||
|
||
def _generate( | ||
self, | ||
messages: List[BaseMessage], | ||
stop: Optional[List[str]] = None, | ||
run_manager: Optional[CallbackManagerForLLMRun] = None, | ||
**kwargs: Any, | ||
) -> ChatResult: | ||
url = f"{self.dappier_endpoint}" | ||
headers = { | ||
"Authorization": f"Bearer {self._api_key}", | ||
"User-Agent": self.get_user_agent(), | ||
} | ||
user_query = _format_dappier_messages(messages=messages) | ||
payload: Dict[str, Any] = { | ||
"model": self.dappier_model, | ||
"conversation": user_query, | ||
} | ||
|
||
request = Requests(headers=headers) | ||
response = request.post(url=url, data=payload) | ||
response.raise_for_status() | ||
|
||
data = response.json() | ||
|
||
message_response = data["message"] | ||
|
||
return ChatResult( | ||
generations=[ChatGeneration(message=AIMessage(content=message_response))] | ||
) | ||
|
||
async def _agenerate( | ||
self, | ||
messages: List[BaseMessage], | ||
stop: Optional[List[str]] = None, | ||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, | ||
**kwargs: Any, | ||
) -> ChatResult: | ||
url = f"{self.dappier_endpoint}" | ||
headers = { | ||
"Authorization": f"Bearer {self._api_key}", | ||
"User-Agent": self.get_user_agent(), | ||
} | ||
user_query = _format_dappier_messages(messages=messages) | ||
payload: Dict[str, Any] = { | ||
"model": self.dappier_model, | ||
"conversation": user_query, | ||
} | ||
|
||
async with ClientSession() as session: | ||
async with session.post(url, json=payload, headers=headers) as response: | ||
response.raise_for_status() | ||
data = await response.json() | ||
message_response = data["message"] | ||
|
||
return ChatResult( | ||
generations=[ | ||
ChatGeneration(message=AIMessage(content=message_response)) | ||
] | ||
) |
58 changes: 58 additions & 0 deletions
58
libs/community/tests/integration_tests/chat_models/test_dappier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import List | ||
|
||
import pytest | ||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage | ||
from langchain_core.outputs import ChatGeneration, LLMResult | ||
|
||
from langchain_community.chat_models.dappier import ( | ||
ChatDappierAI, | ||
) | ||
|
||
|
||
@pytest.mark.scheduled | ||
def test_dappier_chat() -> None: | ||
"""Test ChatDappierAI wrapper.""" | ||
chat = ChatDappierAI( | ||
dappier_endpoint="https://api.dappier.com/app/datamodelconversation", | ||
dappier_model="dm_01hpsxyfm2fwdt2zet9cg6fdxt", | ||
) | ||
message = HumanMessage(content="Who are you ?") | ||
response = chat([message]) | ||
assert isinstance(response, AIMessage) | ||
assert isinstance(response.content, str) | ||
|
||
|
||
@pytest.mark.scheduled | ||
def test_dappier_generate() -> None: | ||
"""Test generate method of Dappier AI.""" | ||
chat = ChatDappierAI( | ||
dappier_endpoint="https://api.dappier.com/app/datamodelconversation", | ||
dappier_model="dm_01hpsxyfm2fwdt2zet9cg6fdxt", | ||
) | ||
chat_messages: List[List[BaseMessage]] = [ | ||
[HumanMessage(content="Who won the last super bowl?")], | ||
] | ||
messages_copy = [messages.copy() for messages in chat_messages] | ||
result: LLMResult = chat.generate(chat_messages) | ||
assert isinstance(result, LLMResult) | ||
for response in result.generations[0]: | ||
assert isinstance(response, ChatGeneration) | ||
assert isinstance(response.text, str) | ||
assert response.text == response.message.content | ||
assert chat_messages == messages_copy | ||
|
||
|
||
@pytest.mark.scheduled | ||
async def test_dappier_agenerate() -> None: | ||
"""Test async generation.""" | ||
chat = ChatDappierAI( | ||
dappier_endpoint="https://api.dappier.com/app/datamodelconversation", | ||
dappier_model="dm_01hpsxyfm2fwdt2zet9cg6fdxt", | ||
) | ||
message = HumanMessage(content="Who won the last super bowl?") | ||
result: LLMResult = await chat.agenerate([[message], [message]]) | ||
assert isinstance(result, LLMResult) | ||
for response in result.generations[0]: | ||
assert isinstance(response, ChatGeneration) | ||
assert isinstance(response.text, str) | ||
assert response.text == response.message.content |
Oops, something went wrong.