Commit

Merge branch 'langchain_opea' of https://github.com/yogeshmpandey/GenAIComps into langchain_opea
yogeshmpandey committed Jan 9, 2025
2 parents 1a23bbd + f91e3ea commit 9e7de88
Showing 4 changed files with 14 additions and 23 deletions.
1 change: 0 additions & 1 deletion comps/integrations/langchain/langchain_opea/__init__.py
@@ -5,7 +5,6 @@
from langchain_opea.embeddings import OPEAEmbeddings
from langchain_opea.llms import OPEALLM

-
__all__ = [
    "ChatOPEA",
    "OPEALLM",
24 changes: 6 additions & 18 deletions comps/integrations/langchain/langchain_opea/chat_models.py
@@ -1,12 +1,7 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

-from typing import (
-    Any,
-    Dict,
-    List,
-    Optional,
-)
+from typing import Any, Dict, List, Optional

import openai
from langchain_core.language_models.chat_models import LangSmithParams
@@ -17,6 +12,7 @@

DEFAULT_MODEL_ID = "Intel/neural-chat-7b-v3-3"

+
class ChatOPEA(BaseChatOpenAI):  # type: ignore[override]
    """OPEA OPENAI Compatible Chat large language models.
@@ -90,9 +86,7 @@ def _llm_type(self) -> str:
        """Return type of chat model."""
        return "opea-chat"

-    def _get_ls_params(
-        self, stop: Optional[List[str]] = None, **kwargs: Any
-    ) -> LangSmithParams:
+    def _get_ls_params(self, stop: Optional[List[str]] = None, **kwargs: Any) -> LangSmithParams:
        """Get the parameters used to invoke the model."""
        params = super()._get_ls_params(stop=stop, **kwargs)
        params["ls_provider"] = "opea"
@@ -107,9 +101,7 @@ def validate_environment(self) -> Self:
            raise ValueError("n must be 1 when streaming.")

        client_params: dict = {
-            "api_key": (
-                self.opea_api_key.get_secret_value() if self.opea_api_key else None
-            ),
+            "api_key": (self.opea_api_key.get_secret_value() if self.opea_api_key else None),
            "base_url": self.opea_api_base,
        }

@@ -121,14 +113,10 @@ def validate_environment(self) -> Self:

        if not (self.client or None):
            sync_specific: dict = {"http_client": self.http_client}
-            self.client = openai.OpenAI(
-                **client_params, **sync_specific
-            ).chat.completions
+            self.client = openai.OpenAI(**client_params, **sync_specific).chat.completions
        if not (self.async_client or None):
            async_specific: dict = {"http_client": self.http_async_client}
-            self.async_client = openai.AsyncOpenAI(
-                **client_params, **async_specific
-            ).chat.completions
+            self.async_client = openai.AsyncOpenAI(**client_params, **async_specific).chat.completions
        return self

    @property
6 changes: 4 additions & 2 deletions comps/integrations/langchain/langchain_opea/embeddings.py
@@ -32,7 +32,7 @@ class OPEAEmbeddings(OpenAIEmbeddings):
            .. code-block:: python
                [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
-    """
+    """

    model_name: str = Field(alias="model")
    """Model name to use."""
@@ -82,7 +82,9 @@ def validate_environment(self) -> Self:

    @property
    def _invocation_params(self) -> Dict[str, Any]:
-        openai_params = {"model": self.model_name,}
+        openai_params = {
+            "model": self.model_name,
+        }
        return {**openai_params, **super()._invocation_params}

    @property
6 changes: 4 additions & 2 deletions comps/integrations/langchain/langchain_opea/llms.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations  # type: ignore[import-not-found]
+
from typing import Any, Dict, List, Optional

import openai
@@ -12,6 +13,7 @@

DEFAULT_MODEL_ID = "Intel/neural-chat-7b-v3-3"

+
class OPEALLM(BaseOpenAI):  # type: ignore[override]
    """OPEA OPENAI Compatible LLM Endpoints.
@@ -80,8 +82,8 @@ def validate_environment(self) -> Self:

        if client_params["api_key"] is None:
            raise ValueError(
-            "OPEA_API_KEY is not set. Please set it in the `opea_api_key` field or "
-            "in the `OPEA_API_KEY` environment variable."
+                "OPEA_API_KEY is not set. Please set it in the `opea_api_key` field or "
+                "in the `OPEA_API_KEY` environment variable."
            )

        if not self.client:
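
For reference, here is a minimal usage sketch (not part of this commit) showing how a class touched by the diff might be instantiated. It relies only on names visible above (ChatOPEA, the DEFAULT_MODEL_ID, the opea_api_key and opea_api_base fields, and the OPEA_API_KEY environment variable); the endpoint URL and key are placeholder assumptions.

# Hypothetical usage sketch; the endpoint URL and API key below are placeholders.
from langchain_opea import ChatOPEA

chat = ChatOPEA(
    model="Intel/neural-chat-7b-v3-3",         # DEFAULT_MODEL_ID from the diff above
    opea_api_base="http://localhost:9009/v1",  # assumed local OPEA-compatible endpoint
    opea_api_key="my-opea-key",                # or rely on the OPEA_API_KEY env var
)
print(chat.invoke("Hello!").content)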
