diff --git a/docs/docs/integrations/chat/premai.ipynb b/docs/docs/integrations/chat/premai.ipynb new file mode 100644 index 0000000000000..13a2ece273304 --- /dev/null +++ b/docs/docs/integrations/chat/premai.ipynb @@ -0,0 +1,286 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "---\n", + "sidebar_label: PremAI\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ChatPremAI\n", + "\n", + ">[PremAI](https://app.premai.io) is a unified platform that lets you build powerful production-ready GenAI-powered applications with the least effort so that you can focus more on user experience and overall growth. \n", + "\n", + "\n", + "This example goes over how to use LangChain to interact with `ChatPremAI`. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation and setup\n", + "\n", + "We start by installing langchain and premai-sdk. You can type the following command to install:\n", + "\n", + "```bash\n", + "pip install premai langchain\n", + "```\n", + "\n", + "Before proceeding further, please make sure that you have made an account on PremAI and already started a project. If not, then here's how you can start for free:\n", + "\n", + "1. Sign in to [PremAI](https://app.premai.io/accounts/login/), if you are coming for the first time and create your API key [here](https://app.premai.io/api_keys/).\n", + "\n", + "2. Go to [app.premai.io](https://app.premai.io) and this will take you to the project's dashboard. \n", + "\n", + "3. Create a project and this will generate a project-id (written as ID). This ID will help you to interact with your deployed application. \n", + "\n", + "4. Head over to LaunchPad (the one with 🚀 icon). And there deploy your model of choice. Your default model will be `gpt-4`. You can also set and fix different generation parameters (like max-tokens, temperature, etc) and also pre-set your system prompt. \n", + "\n", + "Congratulations on creating your first deployed application on PremAI 🎉 Now we can use langchain to interact with our application. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models import ChatPremAI\n", + "from langchain_core.messages import HumanMessage, SystemMessage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup ChatPremAI instance in LangChain \n", + "\n", + "Once we import our required modules, let's set up our client. For now, let's assume that our `project_id` is 8. But make sure you use your project-id, otherwise, it will throw an error.\n", + "\n", + "To use langchain with prem, you do not need to pass any model name or set any parameters with our chat client. All of those will use the default model name and parameters of the LaunchPad model. \n", + "\n", + "`NOTE:` If you change the `model_name` or any other parameter like `temperature` while setting the client, it will override existing default configurations. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "# First step is to set up the env variable.\n", + "# you can also pass the API key while instantiating the model but this\n", + "# comes under a best practices to set it as env variable.\n", + "\n", + "if os.environ.get(\"PREMAI_API_KEY\") is None:\n", + " os.environ[\"PREMAI_API_KEY\"] = getpass.getpass(\"PremAI API Key:\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# By default it will use the model which was deployed through the platform\n", + "# in my case it will is \"claude-3-haiku\"\n", + "\n", + "chat = ChatPremAI(project_id=8)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Calling the Model\n", + "\n", + "Now you are all set. We can now start by interacting with our application. `ChatPremAI` supports two methods `invoke` (which is the same as `generate`) and `stream`. \n", + "\n", + "The first one will give us a static result. Whereas the second one will stream tokens one by one. Here's how you can generate chat-like completions. \n", + "\n", + "### Generation" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "I am an artificial intelligence created by Anthropic. I'm here to help with a wide variety of tasks, from research and analysis to creative projects and open-ended conversation. I have general knowledge and capabilities, but I'm not a real person - I'm an AI assistant. Please let me know if you have any other questions!\n" + ] + } + ], + "source": [ + "human_message = HumanMessage(content=\"Who are you?\")\n", + "\n", + "response = chat.invoke([human_message])\n", + "print(response.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Above looks interesting right? I set my default lanchpad system-prompt as: `Always sound like a pirate` You can also, override the default system prompt if you need to. Here's how you can do it. " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"I am an artificial intelligence created by Anthropic. My purpose is to assist and converse with humans in a friendly and helpful way. I have a broad knowledge base that I can use to provide information, answer questions, and engage in discussions on a wide range of topics. Please let me know if you have any other questions - I'm here to help!\")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "system_message = SystemMessage(content=\"You are a friendly assistant.\")\n", + "human_message = HumanMessage(content=\"Who are you?\")\n", + "\n", + "chat.invoke([system_message, human_message])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also change generation parameters while calling the model. 
Here's how you can do that" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='I am an artificial intelligence created by Anthropic')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat.invoke([system_message, human_message], temperature=0.7, max_tokens=10, top_p=0.95)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Important notes:\n", + "\n", + "Before proceeding further, please note that the current version of ChatPremAI does not support the parameters [n](https://platform.openai.com/docs/api-reference/chat/create#chat-create-n) and [stop](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). \n", + "\n", + "We will provide support for these two parameters in later versions. \n", + "\n", + "### Streaming\n", + "\n", + "And finally, here's how you do token streaming for dynamic chat-like applications. " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Hello! As an AI language model, I don't have feelings or a physical state, but I'm functioning properly and ready to assist you with any questions or tasks you might have. How can I help you today?" + ] + } + ], + "source": [ + "import sys\n", + "\n", + "for chunk in chat.stream(\"hello how are you\"):\n", + " sys.stdout.write(chunk.content)\n", + " sys.stdout.flush()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Similar to above, if you want to override the system prompt and the generation parameters, here's how you can do it. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Hello! As an AI language model, I don't have feelings or a physical form, but I'm functioning properly and ready to assist you. How can I help you today?" + ] + } + ], + "source": [ + "import sys\n", + "\n", + "# For experimental purposes you can also override the system prompt here.\n", + "# However, it is not recommended to override the system prompt\n", + "# of an already deployed model.\n", + "\n", + "for chunk in chat.stream(\n", + " \"hello how are you\",\n", + " system_prompt=\"act like a dog\",\n", + " temperature=0.7,\n", + " max_tokens=200,\n", + "):\n", + " sys.stdout.write(chunk.content)\n", + " sys.stdout.flush()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/docs/integrations/providers/premai.md b/docs/docs/integrations/providers/premai.md new file mode 100644 index 0000000000000..7ec2ca5fa84a7 --- /dev/null +++ b/docs/docs/integrations/providers/premai.md @@ -0,0 +1,181 @@ +# PremAI + +>[PremAI](https://app.premai.io) is a unified platform that lets you build powerful production-ready GenAI-powered applications with the least effort so that you can focus more on user experience and overall growth. 
+ + +## ChatPremAI + +This example goes over how to use LangChain to interact with different chat models using `ChatPremAI`. + +### Installation and setup + +We start by installing langchain and premai-sdk. You can type the following command to install: + +```bash +pip install premai langchain +``` + +Before proceeding further, please make sure that you have made an account on PremAI and already started a project. If not, then here's how you can start for free: + +1. Sign in to [PremAI](https://app.premai.io/accounts/login/), if you are coming for the first time and create your API key [here](https://app.premai.io/api_keys/). + +2. Go to [app.premai.io](https://app.premai.io) and this will take you to the project's dashboard. + +3. Create a project and this will generate a project-id (written as ID). This ID will help you to interact with your deployed application. + +4. Head over to LaunchPad (the one with 🚀 icon). And there deploy your model of choice. Your default model will be `gpt-4`. You can also set and fix different generation parameters (like max-tokens, temperature, etc) and also pre-set your system prompt. + +Congratulations on creating your first deployed application on PremAI 🎉 Now we can use langchain to interact with our application. + +```python +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_community.chat_models import ChatPremAI +``` + +### Setup ChatPremAI instance in LangChain + +Once we import our required modules, let's set up our client. For now, let's assume that our `project_id` is 8. But make sure you use your project-id, otherwise, it will throw an error. + +To use langchain with prem, you do not need to pass any model name or set any parameters with our chat client. All of those will use the default model name and parameters of the LaunchPad model. + +`NOTE:` If you change the `model_name` or any other parameter like `temperature` while setting the client, it will override existing default configurations. + +```python +import os +import getpass + +if "PREMAI_API_KEY" not in os.environ: + os.environ["PREMAI_API_KEY"] = getpass.getpass("PremAI API Key:") + +chat = ChatPremAI(project_id=8) +``` + +### Calling the Model + +Now you are all set. We can start interacting with our application. `ChatPremAI` supports two methods: `invoke` (which is the same as `generate`) and `stream`. + +The first one gives us a static result, whereas the second one streams tokens one by one. Here's how you can generate chat-like completions. + +### Generation + +```python +human_message = HumanMessage(content="Who are you?") + +chat.invoke([human_message]) +``` + +The above looks interesting, right? I set my default LaunchPad system prompt to: `Always sound like a pirate`. You can also override the default system prompt if you need to. Here's how you can do it. + +```python +system_message = SystemMessage(content="You are a friendly assistant.") +human_message = HumanMessage(content="Who are you?") + +chat.invoke([system_message, human_message]) +``` + +You can also change generation parameters while calling the model. 
Here's how you can do that: + +```python +chat.invoke( + [system_message, human_message], + temperature=0.7, max_tokens=20, top_p=0.95 +) +``` + + +### Important notes: + +Before proceeding further, please note that the current version of ChatPremAI does not support the parameters [n](https://platform.openai.com/docs/api-reference/chat/create#chat-create-n) and [stop](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). + +We will provide support for these two parameters in later versions. + +### Streaming + +And finally, here's how you do token streaming for dynamic chat-like applications. + +```python +import sys + +for chunk in chat.stream("hello how are you"): + sys.stdout.write(chunk.content) + sys.stdout.flush() +``` + +Similar to above, if you want to override the system prompt and the generation parameters, here's how you can do it. + +```python +import sys + +for chunk in chat.stream( + "hello how are you", + system_prompt="You are a helpful assistant", temperature=0.7, max_tokens=20 +): + sys.stdout.write(chunk.content) + sys.stdout.flush() +``` + +## Embedding + +In this section, we are going to discuss how we can get access to different embedding models using `PremAIEmbeddings`. Let's start by doing some imports and defining our embedding object. + +```python +from langchain_community.embeddings import PremAIEmbeddings +``` + +Once we import our required modules, let's set up our client. For now, let's assume that our `project_id` is 8. But make sure you use your project-id, otherwise, it will throw an error. + + +```python + +import os +import getpass + +if os.environ.get("PREMAI_API_KEY") is None: + os.environ["PREMAI_API_KEY"] = getpass.getpass("PremAI API Key:") + +# Define a model as a required parameter here since there is no default embedding model + +model = "text-embedding-3-large" +embedder = PremAIEmbeddings(project_id=8, model=model) +``` + +We have defined our embedding model. We support many embedding models. Here is a table that shows the embedding models we support. + + +| Provider | Slug | Context Tokens | +|-------------|------------------------------------------|----------------| +| cohere | embed-english-v3.0 | N/A | +| openai | text-embedding-3-small | 8191 | +| openai | text-embedding-3-large | 8191 | +| openai | text-embedding-ada-002 | 8191 | +| replicate | replicate/all-mpnet-base-v2 | N/A | +| together | togethercomputer/Llama-2-7B-32K-Instruct | N/A | +| mistralai | mistral-embed | 4096 | + +To change the model, you simply need to copy the `slug` and access your embedding model. 
Now let's start using our embedding model with a single query, followed by multiple queries (also called a document). + +```python +query = "Hello, this is a test query" +query_result = embedder.embed_query(query) + +# Let's print the first five elements of the query embedding vector + +print(query_result[:5]) +``` + +Finally, let's embed a document + +```python +documents = [ + "This is document1", + "This is document2", + "This is document3" +] + +doc_result = embedder.embed_documents(documents) + +# Similar to the previous result, let's print the first five elements +# of the first document vector + +print(doc_result[0][:5]) +``` \ No newline at end of file diff --git a/docs/docs/integrations/text_embedding/premai.ipynb b/docs/docs/integrations/text_embedding/premai.ipynb new file mode 100644 index 0000000000000..d8bf54fd43f53 --- /dev/null +++ b/docs/docs/integrations/text_embedding/premai.ipynb @@ -0,0 +1,166 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PremAI\n", + "\n", + ">[PremAI](https://app.premai.io) is a unified platform that lets you build powerful production-ready GenAI-powered applications with the least effort, so that you can focus more on user experience and overall growth. In this section we are going to discuss how we can get access to different embedding models using `PremAIEmbeddings`.\n", + "\n", + "## Installation and Setup\n", + "\n", + "We start by installing langchain and premai-sdk. You can type the following command to install:\n", + "\n", + "```bash\n", + "pip install premai langchain\n", + "```\n", + "\n", + "Before proceeding further, please make sure that you have made an account on Prem and already started a project. If not, then here's how you can start for free:\n", + "\n", + "1. Sign in to [PremAI](https://app.premai.io/accounts/login/), if you are coming for the first time and create your API key [here](https://app.premai.io/api_keys/).\n", + "\n", + "2. Go to [app.premai.io](https://app.premai.io) and this will take you to the project's dashboard. \n", + "\n", + "3. Create a project and this will generate a project-id (written as ID). This ID will help you to interact with your deployed application. \n", + "\n", + "Congratulations on creating your first deployed application on Prem 🎉 Now we can use langchain to interact with our application. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's start by doing some imports and defining our embedding object\n", + "\n", + "from langchain_community.embeddings import PremAIEmbeddings" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we have imported our required modules, let's set up our client. For now, let's assume that our `project_id` is 8. But make sure you use your project-id, otherwise it will throw an error.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if os.environ.get(\"PREMAI_API_KEY\") is None:\n", + " os.environ[\"PREMAI_API_KEY\"] = getpass.getpass(\"PremAI API Key:\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "model = \"text-embedding-3-large\"\n", + "embedder = PremAIEmbeddings(project_id=8, model=model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have defined our embedding model. We support many embedding models. 
Here is a table that shows the number of embedding models we support. \n", + "\n", + "\n", + "| Provider | Slug | Context Tokens |\n", + "|-------------|------------------------------------------|----------------|\n", + "| cohere | embed-english-v3.0 | N/A |\n", + "| openai | text-embedding-3-small | 8191 |\n", + "| openai | text-embedding-3-large | 8191 |\n", + "| openai | text-embedding-ada-002 | 8191 |\n", + "| replicate | replicate/all-mpnet-base-v2 | N/A |\n", + "| together | togethercomputer/Llama-2-7B-32K-Instruct | N/A |\n", + "| mistralai | mistral-embed | 4096 |\n", + "\n", + "To change the model, you simply need to copy the `slug` and access your embedding model. Now let's start using our embedding model with a single query followed by multiple queries (which is also called as a document)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[-0.02129288576543331, 0.0008162345038726926, -0.004556538071483374, 0.02918623760342598, -0.02547479420900345]\n" + ] + } + ], + "source": [ + "query = \"Hello, this is a test query\"\n", + "query_result = embedder.embed_query(query)\n", + "\n", + "# Let's print the first five elements of the query embedding vector\n", + "\n", + "print(query_result[:5])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally let's embed a document" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[-0.0030691148713231087, -0.045334383845329285, -0.0161729846149683, 0.04348714277148247, -0.0036920777056366205]\n" + ] + } + ], + "source": [ + "documents = [\"This is document1\", \"This is document2\", \"This is document3\"]\n", + "\n", + "doc_result = embedder.embed_documents(documents)\n", + "\n", + "# Similar to previous result, let's print the first five element\n", + "# of the first document vector\n", + "\n", + "print(doc_result[0][:5])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 76ede64dc1a1e..18f202813b599 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -64,6 +64,7 @@ "PromptLayerChatOpenAI": "langchain_community.chat_models.promptlayer_openai", "QianfanChatEndpoint": "langchain_community.chat_models.baidu_qianfan_endpoint", "VolcEngineMaasChat": "langchain_community.chat_models.volcengine_maas", + "ChatPremAI": "langchain_community.chat_models.premai", } diff --git a/libs/community/langchain_community/chat_models/premai.py b/libs/community/langchain_community/chat_models/premai.py new file mode 100644 index 0000000000000..b0e9c83cf389c --- /dev/null +++ b/libs/community/langchain_community/chat_models/premai.py @@ -0,0 +1,416 @@ +"""Wrapper around Prem's Chat API.""" + +from __future__ import annotations + +import logging +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + List, + 
Optional, + Tuple, + Type, + Union, +) + +from langchain_core.callbacks import ( + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.llms import create_base_retry_decorator +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + ChatMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator +from langchain_core.utils import get_from_dict_or_env + +if TYPE_CHECKING: + from premai.api.chat_completions.v1_chat_completions_create import ( + ChatCompletionResponseStream, + ) + from premai.models.chat_completion_response import ChatCompletionResponse + +logger = logging.getLogger(__name__) + + +class ChatPremAPIError(Exception): + """Error with the `PremAI` API.""" + + +def _truncate_at_stop_tokens( + text: str, + stop: Optional[List[str]], +) -> str: + """Truncates text at the earliest stop token found.""" + if stop is None: + return text + + for stop_token in stop: + stop_token_idx = text.find(stop_token) + if stop_token_idx != -1: + text = text[:stop_token_idx] + return text + + +def _response_to_result( + response: ChatCompletionResponse, + stop: Optional[List[str]], +) -> ChatResult: + """Converts a Prem API response into a LangChain result""" + + if not response.choices: + raise ChatPremAPIError("ChatResponse must have at least one candidate") + generations: List[ChatGeneration] = [] + for choice in response.choices: + role = choice.message.role + if role is None: + raise ChatPremAPIError(f"ChatResponse {choice} must have a role.") + + # If content is None then it will be replaced by "" + content = _truncate_at_stop_tokens(text=choice.message.content or "", stop=stop) + if content is None: + raise ChatPremAPIError(f"ChatResponse must have a content: {content}") + + if role == "assistant": + generations.append( + ChatGeneration(text=content, message=AIMessage(content=content)) + ) + elif role == "user": + generations.append( + ChatGeneration(text=content, message=HumanMessage(content=content)) + ) + else: + generations.append( + ChatGeneration( + text=content, message=ChatMessage(role=role, content=content) + ) + ) + return ChatResult(generations=generations) + + +def _convert_delta_response_to_message_chunk( + response: ChatCompletionResponseStream, default_class: Type[BaseMessageChunk] +) -> Tuple[ + Union[BaseMessageChunk, HumanMessageChunk, AIMessageChunk, SystemMessageChunk], + Optional[str], +]: + """Converts delta response to message chunk""" + _delta = response.choices[0].delta # type: ignore + role = _delta.get("role", "") # type: ignore + content = _delta.get("content", "") # type: ignore + additional_kwargs: Dict = {} + + if role is None or role == "": + raise ChatPremAPIError("Role can not be None. 
Please check the response") + + finish_reasons: Optional[str] = response.choices[0].finish_reason + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content), finish_reasons + elif role == "assistant" or default_class == AIMessageChunk: + return ( + AIMessageChunk(content=content, additional_kwargs=additional_kwargs), + finish_reasons, + ) + elif role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content), finish_reasons + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role), finish_reasons + else: + return default_class(content=content), finish_reasons + + +def _messages_to_prompt_dict( + input_messages: List[BaseMessage], +) -> Tuple[Optional[str], List[Dict[str, str]]]: + """Converts a list of LangChain Messages into a simple dict + which is the message structure in Prem""" + + system_prompt: Optional[str] = None + examples_and_messages: List[Dict[str, str]] = [] + + for input_msg in input_messages: + if isinstance(input_msg, SystemMessage): + system_prompt = str(input_msg.content) + elif isinstance(input_msg, HumanMessage): + examples_and_messages.append( + {"role": "user", "content": str(input_msg.content)} + ) + elif isinstance(input_msg, AIMessage): + examples_and_messages.append( + {"role": "assistant", "content": str(input_msg.content)} + ) + else: + raise ChatPremAPIError("No such role explicitly exists") + return system_prompt, examples_and_messages + + +class ChatPremAI(BaseChatModel, BaseModel): + """Use any LLM provider with Prem and Langchain. + + To use, you will need to have an API key. You can find your existing API Key + or generate a new one here: https://app.premai.io/api_keys/ + """ + + # TODO: Need to add the default parameters through prem-sdk here + + project_id: int + """The project ID in which the experiments or deployments are carried out. + You can find all your projects here: https://app.premai.io/projects/""" + premai_api_key: Optional[SecretStr] = None + """Prem AI API Key. Get it here: https://app.premai.io/api_keys/""" + + model: Optional[str] = None + """Name of the model. This is an optional parameter. + The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad + If model name is other than default model then it will override the calls + from the model deployed from launchpad.""" + + session_id: Optional[str] = None + """The ID of the session to use. It helps to track the chat history.""" + + temperature: Optional[float] = None + """Model temperature. Value should be >= 0 and <= 1.0""" + + top_p: Optional[float] = None + """top_p adjusts the number of choices for each predicted tokens based on + cumulative probabilities. Value should be ranging between 0.0 and 1.0. + """ + + max_tokens: Optional[int] = None + """The maximum number of tokens to generate""" + + max_retries: int = 1 + """Max number of retries to call the API""" + + system_prompt: Optional[str] = "" + """Acts like a default instruction that helps the LLM act or generate + in a specific way.This is an Optional Parameter. By default the + system prompt would be using Prem's Launchpad models system prompt. + Changing the system prompt would override the default system prompt. + """ + + streaming: Optional[bool] = False + """Whether to stream the responses or not.""" + + tools: Optional[Dict[str, Any]] = None + """A list of tools the model may call. 
Currently, only functions are + supported as a tool""" + + frequency_penalty: Optional[float] = None + """Number between -2.0 and 2.0. Positive values penalize new tokens based""" + + presence_penalty: Optional[float] = None + """Number between -2.0 and 2.0. Positive values penalize new tokens based + on whether they appear in the text so far.""" + + logit_bias: Optional[dict] = None + """JSON object that maps tokens to an associated bias value from -100 to 100.""" + + stop: Optional[Union[str, List[str]]] = None + """Up to 4 sequences where the API will stop generating further tokens.""" + + seed: Optional[int] = None + """This feature is in Beta. If specified, our system will make a best effort + to sample deterministically.""" + + client: Any + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environments(cls, values: Dict) -> Dict: + """Validate that the package is installed and that the API token is valid""" + try: + from premai import Prem + except ImportError as error: + raise ImportError( + "Could not import Prem Python package." + "Please install it with: `pip install premai`" + ) from error + + try: + premai_api_key = get_from_dict_or_env( + values, "premai_api_key", "PREMAI_API_KEY" + ) + values["client"] = Prem(api_key=premai_api_key) + except Exception as error: + raise ValueError("Your API Key is incorrect. Please try again.") from error + return values + + @property + def _llm_type(self) -> str: + return "premai" + + @property + def _default_params(self) -> Dict[str, Any]: + # FIXME: n and stop is not supported, so hardcoding to current default value + return { + "model": self.model, + "system_prompt": self.system_prompt, + "top_p": self.top_p, + "temperature": self.temperature, + "logit_bias": self.logit_bias, + "max_tokens": self.max_tokens, + "presence_penalty": self.presence_penalty, + "frequency_penalty": self.frequency_penalty, + "seed": self.seed, + "stop": None, + } + + def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]: + all_kwargs = {**self._default_params, **kwargs} + for key in list(self._default_params.keys()): + if all_kwargs.get(key) is None or all_kwargs.get(key) == "": + all_kwargs.pop(key, None) + return all_kwargs + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore + + kwargs["stop"] = stop + if system_prompt is not None and system_prompt != "": + kwargs["system_prompt"] = system_prompt + + all_kwargs = self._get_all_kwargs(**kwargs) + response = chat_with_retry( + self, + project_id=self.project_id, + messages=messages_to_pass, + stream=False, + run_manager=run_manager, + **all_kwargs, + ) + + return _response_to_result(response=response, stop=stop) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) + kwargs["stop"] = stop + + if "system_prompt" not in kwargs: + if system_prompt is not None and system_prompt != "": + kwargs["system_prompt"] = system_prompt + + all_kwargs = self._get_all_kwargs(**kwargs) + + default_chunk_class = AIMessageChunk + + for streamed_response in chat_with_retry( + self, + project_id=self.project_id, + 
messages=messages_to_pass, + stream=True, + run_manager=run_manager, + **all_kwargs, + ): + try: + chunk, finish_reason = _convert_delta_response_to_message_chunk( + response=streamed_response, default_class=default_chunk_class + ) + generation_info = ( + dict(finish_reason=finish_reason) + if finish_reason is not None + else None + ) + cg_chunk = ChatGenerationChunk( + message=chunk, generation_info=generation_info + ) + if run_manager: + run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk) + yield cg_chunk + except Exception as _: + continue + + +def create_prem_retry_decorator( + llm: ChatPremAI, + *, + max_retries: int = 1, + run_manager: Optional[Union[CallbackManagerForLLMRun]] = None, +) -> Callable[[Any], Any]: + import premai.models + + errors = [ + premai.models.api_response_validation_error.APIResponseValidationError, + premai.models.conflict_error.ConflictError, + premai.models.model_not_found_error.ModelNotFoundError, + premai.models.permission_denied_error.PermissionDeniedError, + premai.models.provider_api_connection_error.ProviderAPIConnectionError, + premai.models.provider_api_status_error.ProviderAPIStatusError, + premai.models.provider_api_timeout_error.ProviderAPITimeoutError, + premai.models.provider_internal_server_error.ProviderInternalServerError, + premai.models.provider_not_found_error.ProviderNotFoundError, + premai.models.rate_limit_error.RateLimitError, + premai.models.unprocessable_entity_error.UnprocessableEntityError, + premai.models.validation_error.ValidationError, + ] + + decorator = create_base_retry_decorator( + error_types=errors, max_retries=max_retries, run_manager=run_manager + ) + return decorator + + +def chat_with_retry( + llm: ChatPremAI, + project_id: int, + messages: List[dict], + stream: bool = False, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Using tenacity for retry in completion call""" + retry_decorator = create_prem_retry_decorator( + llm, max_retries=llm.max_retries, run_manager=run_manager + ) + + @retry_decorator + def _completion_with_retry( + project_id: int, + messages: List[dict], + stream: Optional[bool] = False, + **kwargs: Any, + ) -> Any: + response = llm.client.chat.completions.create( + project_id=project_id, + messages=messages, + stream=stream, + **kwargs, + ) + return response + + return _completion_with_retry( + project_id=project_id, + messages=messages, + stream=stream, + **kwargs, + ) diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index 93d4b817bc8f1..b68c5115599c1 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -80,6 +80,7 @@ "VolcanoEmbeddings": "langchain_community.embeddings.volcengine", "VoyageEmbeddings": "langchain_community.embeddings.voyageai", "XinferenceEmbeddings": "langchain_community.embeddings.xinference", + "PremAIEmbeddings": "langchain_community.embeddings.premai", "YandexGPTEmbeddings": "langchain_community.embeddings.yandex", } diff --git a/libs/community/langchain_community/embeddings/premai.py b/libs/community/langchain_community/embeddings/premai.py new file mode 100644 index 0000000000000..e811b1bae49f4 --- /dev/null +++ b/libs/community/langchain_community/embeddings/premai.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, List, Optional, Union + +from langchain_core.embeddings import Embeddings +from 
langchain_core.language_models.llms import create_base_retry_decorator +from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator +from langchain_core.utils import get_from_dict_or_env + +logger = logging.getLogger(__name__) + + +class PremAIEmbeddings(BaseModel, Embeddings): + """Prem's Embedding APIs""" + + project_id: int + """The project ID in which the experiments or deployments are carried out. + You can find all your projects here: https://app.premai.io/projects/""" + + premai_api_key: Optional[SecretStr] = None + """Prem AI API Key. Get it here: https://app.premai.io/api_keys/""" + + model: str + """The Embedding model to choose from""" + + show_progress_bar: bool = False + """Whether to show a tqdm progress bar. Must have `tqdm` installed.""" + + max_retries: int = 1 + """Max number of retries for tenacity""" + + client: Any + + @root_validator() + def validate_environments(cls, values: Dict) -> Dict: + """Validate that the package is installed and that the API token is valid""" + try: + from premai import Prem + except ImportError as error: + raise ImportError( + "Could not import Prem Python package." + "Please install it with: `pip install premai`" + ) from error + + try: + premai_api_key = get_from_dict_or_env( + values, "premai_api_key", "PREMAI_API_KEY" + ) + values["client"] = Prem(api_key=premai_api_key) + except Exception as error: + raise ValueError("Your API Key is incorrect. Please try again.") from error + return values + + def embed_query(self, text: str) -> List[float]: + """Embed query text""" + embeddings = embed_with_retry( + self, model=self.model, project_id=self.project_id, input=text + ) + return embeddings.data[0].embedding + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + embeddings = embed_with_retry( + self, model=self.model, project_id=self.project_id, input=texts + ).data + + return [embedding.embedding for embedding in embeddings] + + +def create_prem_retry_decorator( + embedder: PremAIEmbeddings, + *, + max_retries: int = 1, +) -> Callable[[Any], Any]: + import premai.models + + errors = [ + premai.models.api_response_validation_error.APIResponseValidationError, + premai.models.conflict_error.ConflictError, + premai.models.model_not_found_error.ModelNotFoundError, + premai.models.permission_denied_error.PermissionDeniedError, + premai.models.provider_api_connection_error.ProviderAPIConnectionError, + premai.models.provider_api_status_error.ProviderAPIStatusError, + premai.models.provider_api_timeout_error.ProviderAPITimeoutError, + premai.models.provider_internal_server_error.ProviderInternalServerError, + premai.models.provider_not_found_error.ProviderNotFoundError, + premai.models.rate_limit_error.RateLimitError, + premai.models.unprocessable_entity_error.UnprocessableEntityError, + premai.models.validation_error.ValidationError, + ] + + decorator = create_base_retry_decorator( + error_types=errors, max_retries=max_retries, run_manager=None + ) + return decorator + + +def embed_with_retry( + embedder: PremAIEmbeddings, + model: str, + project_id: int, + input: Union[str, List[str]], +) -> Any: + """Using tenacity for retry in embedding calls""" + retry_decorator = create_prem_retry_decorator( + embedder, max_retries=embedder.max_retries + ) + + @retry_decorator + def _embed_with_retry( + embedder: PremAIEmbeddings, + project_id: int, + model: str, + input: Union[str, List[str]], + ) -> Any: + embedding_response = embedder.client.embeddings.create( + project_id=project_id, model=model, input=input + ) + return 
embedding_response + + return _embed_with_retry(embedder, project_id=project_id, model=model, input=input) diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock index e78b06a98c6cc..afcfe6f35ae99 100644 --- a/libs/community/poetry.lock +++ b/libs/community/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aenum" @@ -1167,10 +1167,7 @@ files = [ [package.dependencies] jmespath = ">=0.7.1,<2.0.0" python-dateutil = ">=2.1,<3.0.0" -urllib3 = [ - {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, - {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}, -] +urllib3 = {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""} [package.extras] crt = ["awscrt (==0.19.19)"] @@ -2689,6 +2686,37 @@ uvicorn = ">=0.23.2,<0.24.0" [package.extras] mllib = ["accelerate (==0.21.0)", "datasets (==2.16.0)", "einops (>=0.6.1,<0.7.0)", "h5py (>=3.9.0,<4.0.0)", "peft (==0.6.0)", "transformers (==4.36.2)"] +[[package]] +name = "friendli-client" +version = "1.3.1" +description = "Client of Friendli Suite." +optional = true +python-versions = "<4.0.0,>=3.8.1" +files = [ + {file = "friendli_client-1.3.1-py3-none-any.whl", hash = "sha256:1a77b046c57b0d70bac8d13ac6ecc861f8fc84d3c63e39b34543f862373a670b"}, + {file = "friendli_client-1.3.1.tar.gz", hash = "sha256:85f87976f7bb75eb424f384e3e73ac3256b7aad477361b51341e520c2aed3a0e"}, +] + +[package.dependencies] +fastapi = ">=0.104.0,<0.105.0" +gql = ">=3.4.1,<4.0.0" +httpx = ">=0.24.1,<0.25.0" +injector = ">=0.21.0,<0.22.0" +jsonschema = ">=4.17.3,<5.0.0" +pathspec = ">=0.9.0,<0.10.0" +protobuf = ">=4.24.2,<5.0.0" +pydantic = {version = ">=1.9.0,<3", extras = ["email"]} +PyYaml = ">=6.0.1,<7.0.0" +requests = ">=2,<3" +rich = ">=12.2.0,<13.0.0" +tqdm = ">=4.48.0,<5.0.0" +typer = ">=0.9.0,<0.10.0" +types-protobuf = ">=4.24.0.1,<5.0.0.0" +uvicorn = ">=0.23.2,<0.24.0" + +[package.extras] +mllib = ["accelerate (==0.21.0)", "datasets (==2.16.0)", "einops (>=0.6.1,<0.7.0)", "h5py (>=3.9.0,<4.0.0)", "peft (==0.6.0)", "transformers (==4.36.2)"] + [[package]] name = "frozenlist" version = "1.4.1" @@ -5965,6 +5993,23 @@ dev = ["packaging", "prawcore[lint]", "prawcore[test]"] lint = ["pre-commit", "ruff (>=0.0.291)"] test = ["betamax (>=0.8,<0.9)", "pytest (>=2.7.3)", "urllib3 (==1.26.*)"] +[[package]] +name = "premai" +version = "0.3.25" +description = "A client library for accessing Prem APIs" +optional = true +python-versions = ">=3.8,<4.0" +files = [ + {file = "premai-0.3.25-py3-none-any.whl", hash = "sha256:bddace7340e1827f048b410748d365e8663e4bbeb6bf7e8b8657f3cc267f7f28"}, + {file = "premai-0.3.25.tar.gz", hash = "sha256:c387980ecf3bdcb07886dd4f7a1c0f0701df67e772e62f444394cea97d5970a0"}, +] + +[package.dependencies] +attrs = ">=21.3.0" +httpx = ">=0.20.0,<0.27.0" +python-dateutil = ">=2.8.0,<3.0.0" +typing_extensions = ">=4.9.0" + [[package]] name = "prometheus-client" version = "0.20.0" @@ -9071,20 +9116,6 @@ files = [ cryptography = ">=35.0.0" types-pyOpenSSL = "*" -[[package]] -name = "types-requests" -version = "2.31.0.6" -description = "Typing stubs for requests" -optional = false -python-versions = ">=3.7" -files = [ - {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"}, - {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = 
"sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"}, -] - -[package.dependencies] -types-urllib3 = "*" - [[package]] name = "types-requests" version = "2.31.0.20240311" @@ -9121,17 +9152,6 @@ files = [ {file = "types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d"}, ] -[[package]] -name = "types-urllib3" -version = "1.26.25.14" -description = "Typing stubs for urllib3" -optional = false -python-versions = "*" -files = [ - {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"}, - {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"}, -] - [[package]] name = "typing" version = "3.7.4.3" @@ -9228,22 +9248,6 @@ files = [ [package.extras] dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"] -[[package]] -name = "urllib3" -version = "1.26.18" -description = "HTTP library with thread-safe connection pooling, file post, and more." -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" -files = [ - {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"}, - {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"}, -] - -[package.extras] -brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] -secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] -socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] - [[package]] name = "urllib3" version = "2.0.7" @@ -9319,7 +9323,6 @@ files = [ [package.dependencies] PyYAML = "*" -urllib3 = {version = "<2", markers = "platform_python_implementation == \"PyPy\" or python_version < \"3.10\""} wrapt = "*" yarl = "*" @@ -9873,9 +9876,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] cli = ["typer"] -extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", 
"upstash-redis", "xata", "xmltodict", "zhipuai"] +extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "5b2a17ed079fa4cc1776f0474a9e73a428c10bf30a22b7185d2f7a77b2d146e5" +content-hash = "dcaae2110a70843fa3cb375618bebbe16b3da9bfdbc1e471e57f144d0906f58b" diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 4fc7bea786057..3fc06bfb4eb31 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -97,6 +97,7 @@ rdflib = {version = "7.0.0", optional = true} nvidia-riva-client = {version = "^2.14.0", optional = true} tidb-vector = {version = ">=0.0.3,<1.0.0", optional = true} friendli-client = {version = "^1.2.4", optional = true} +premai = {version = "^0.3.25", optional = true} [tool.poetry.group.test] optional = true @@ -267,7 +268,8 @@ extended_testing = [ "rdflib", "tidb-vector", "cloudpickle", - "friendli-client" + "friendli-client", + "premai" ] [tool.ruff] diff --git a/libs/community/tests/integration_tests/chat_models/test_premai.py b/libs/community/tests/integration_tests/chat_models/test_premai.py new file mode 100644 index 0000000000000..fae9b4135fb3c --- /dev/null +++ b/libs/community/tests/integration_tests/chat_models/test_premai.py @@ -0,0 +1,70 @@ +"""Test ChatPremAI from PremAI API wrapper. + +Note: This test must be run with the PREMAI_API_KEY environment variable set to a valid +API key and a valid project_id. 
+For this we need to have a project setup in PremAI's platform: https://app.premai.io +""" + +import pytest +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.outputs import ChatGeneration, LLMResult + +from langchain_community.chat_models import ChatPremAI + + +@pytest.fixture +def chat() -> ChatPremAI: + return ChatPremAI(project_id=8) + + +def test_chat_premai() -> None: + """Test ChatPremAI wrapper.""" + chat = ChatPremAI(project_id=8) + message = HumanMessage(content="Hello") + response = chat([message]) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + +def test_chat_prem_system_message() -> None: + """Test ChatPremAI wrapper for system message""" + chat = ChatPremAI(project_id=8) + system_message = SystemMessage(content="You are to chat with the user.") + human_message = HumanMessage(content="Hello") + response = chat([system_message, human_message]) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + +def test_chat_prem_model() -> None: + """Test ChatPremAI wrapper handles model_name.""" + chat = ChatPremAI(model="foo", project_id=8) + assert chat.model == "foo" + + +def test_chat_prem_generate() -> None: + """Test ChatPremAI wrapper with generate.""" + chat = ChatPremAI(project_id=8) + message = HumanMessage(content="Hello") + response = chat.generate([[message], [message]]) + assert isinstance(response, LLMResult) + assert len(response.generations) == 2 + for generations in response.generations: + for generation in generations: + assert isinstance(generation, ChatGeneration) + assert isinstance(generation.text, str) + assert generation.text == generation.message.content + + +async def test_prem_invoke(chat: ChatPremAI) -> None: + """Tests chat completion with invoke""" + result = chat.invoke("How is the weather in New York today?") + assert isinstance(result.content, str) + + +def test_prem_streaming() -> None: + """Test streaming tokens from Prem.""" + chat = ChatPremAI(project_id=8, streaming=True) + + for token in chat.stream("I'm Pickle Rick"): + assert isinstance(token.content, str) diff --git a/libs/community/tests/integration_tests/embeddings/test_premai.py b/libs/community/tests/integration_tests/embeddings/test_premai.py new file mode 100644 index 0000000000000..f0848760bfae9 --- /dev/null +++ b/libs/community/tests/integration_tests/embeddings/test_premai.py @@ -0,0 +1,40 @@ +"""Test PremAIEmbeddings from PremAI API wrapper. + +Note: This test must be run with the PREMAI_API_KEY environment variable set to a valid +API key and a valid project_id. This needs to setup a project in PremAI's platform. 
+You can check it out here: https://app.premai.io +""" + +import pytest + +from langchain_community.embeddings.premai import PremAIEmbeddings + + +@pytest.fixture +def embedder() -> PremAIEmbeddings: + return PremAIEmbeddings(project_id=8, model="text-embedding-3-small") + + +def test_prem_embedding_documents(embedder: PremAIEmbeddings) -> None: + """Test Prem embeddings.""" + documents = ["foo bar"] + output = embedder.embed_documents(documents) + assert len(output) == 1 + assert len(output[0]) == 1536 + + +def test_prem_embedding_documents_multiple(embedder: PremAIEmbeddings) -> None: + """Test prem embeddings for multiple queries or documents.""" + documents = ["foo bar", "bar foo", "foo"] + output = embedder.embed_documents(documents) + assert len(output) == 3 + assert len(output[0]) == 1536 + assert len(output[1]) == 1536 + assert len(output[2]) == 1536 + + +def test_prem_embedding_query(embedder: PremAIEmbeddings) -> None: + """Test Prem embeddings for single query""" + document = "foo bar" + output = embedder.embed_query(document) + assert len(output) == 1536 diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index 07f64c7ad6440..cca1330eaa599 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -44,6 +44,7 @@ "ChatPerplexity", "ChatKinetica", "ChatFriendli", + "ChatPremAI", ] diff --git a/libs/community/tests/unit_tests/chat_models/test_premai.py b/libs/community/tests/unit_tests/chat_models/test_premai.py new file mode 100644 index 0000000000000..c72d4f0ec865b --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_premai.py @@ -0,0 +1,47 @@ +"""Test PremChat model""" + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture + +from langchain_community.chat_models import ChatPremAI +from langchain_community.chat_models.premai import _messages_to_prompt_dict + + +@pytest.mark.requires("premai") +def test_api_key_is_string() -> None: + llm = ChatPremAI(premai_api_key="secret-api-key", project_id=8) + assert isinstance(llm.premai_api_key, SecretStr) + + +@pytest.mark.requires("premai") +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + llm = ChatPremAI(premai_api_key="secret-api-key", project_id=8) + print(llm.premai_api_key, end="") # noqa: T201 + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_messages_to_prompt_dict_with_valid_messages() -> None: + system_message, result = _messages_to_prompt_dict( + [ + SystemMessage(content="System Prompt"), + HumanMessage(content="User message #1"), + AIMessage(content="AI message #1"), + HumanMessage(content="User message #2"), + AIMessage(content="AI message #2"), + ] + ) + expected = [ + {"role": "user", "content": "User message #1"}, + {"role": "assistant", "content": "AI message #1"}, + {"role": "user", "content": "User message #2"}, + {"role": "assistant", "content": "AI message #2"}, + ] + + assert system_message == "System Prompt" + assert result == expected diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index 1ed14d49e0b0e..48e3b6cf65e06 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py 
@@ -65,6 +65,7 @@ "QuantizedBiEncoderEmbeddings", "NeMoEmbeddings", "SparkLLMTextEmbeddings", + "PremAIEmbeddings", "YandexGPTEmbeddings", ] diff --git a/libs/community/tests/unit_tests/embeddings/test_premai.py b/libs/community/tests/unit_tests/embeddings/test_premai.py new file mode 100644 index 0000000000000..3b06b19026ae2 --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_premai.py @@ -0,0 +1,28 @@ +"""Test PremAIEmbeddings embeddings""" + +import pytest +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture + +from langchain_community.embeddings import PremAIEmbeddings + + +@pytest.mark.requires("premai") +def test_api_key_is_string() -> None: + llm = PremAIEmbeddings( + premai_api_key="secret-api-key", project_id=8, model="fake-model" + ) + assert isinstance(llm.premai_api_key, SecretStr) + + +@pytest.mark.requires("premai") +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + llm = PremAIEmbeddings( + premai_api_key="secret-api-key", project_id=8, model="fake-model" + ) + print(llm.premai_api_key, end="") # noqa: T201 + captured = capsys.readouterr() + + assert captured.out == "**********"