Skip to content

Commit

Permalink
fix: add guidance parameters for LC wrapper models (#255)
Browse files: browse the repository at this point in the history
* fix: add docstring to LC wrapper models

* fix: fix metadata passing with LC embedding wrapper
  • Loading branch information
taprosoft authored Sep 9, 2024
1 parent ce48972 commit 96d2086
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 18 deletions.
1 change: 1 addition & 0 deletions flowsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@
"__type__": "kotaemon.embeddings.LCCohereEmbeddings",
"model": "embed-multilingual-v2.0",
"cohere_api_key": "your-key",
"user_agent": "default",
},
"default": False,
}
Expand Down
44 changes: 28 additions & 16 deletions libs/kotaemon/kotaemon/embeddings/langchain_based.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from kotaemon.base import Document, DocumentWithEmbedding
from kotaemon.base import DocumentWithEmbedding, Param

from .base import BaseEmbeddings

Expand All @@ -19,25 +19,14 @@ def __init__(self, **params):
super().__init__()

def run(self, text):
input_: list[str] = []
if not isinstance(text, list):
text = [text]

for item in text:
if isinstance(item, str):
input_.append(item)
elif isinstance(item, Document):
input_.append(item.text)
else:
raise ValueError(
f"Invalid input type {type(item)}, should be str or Document"
)
input_docs = self.prepare_input(text)
input_ = [doc.text for doc in input_docs]

embeddings = self._obj.embed_documents(input_)

return [
DocumentWithEmbedding(text=each_text, embedding=each_embedding)
for each_text, each_embedding in zip(input_, embeddings)
DocumentWithEmbedding(content=doc, embedding=each_embedding)
for doc, each_embedding in zip(input_docs, embeddings)
]

def __repr__(self):
Expand Down Expand Up @@ -162,6 +151,20 @@ def _get_lc_class(self):
class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""

cohere_api_key: str = Param(
help="API key (https://dashboard.cohere.com/api-keys)",
default=None,
required=True,
)
model: str = Param(
help="Model name to use (https://docs.cohere.com/docs/models)",
default=None,
required=True,
)
user_agent: str = Param(
help="User agent (leave default)", default="default", required=True
)

def __init__(
self,
model: str = "embed-english-v2.0",
Expand Down Expand Up @@ -190,6 +193,15 @@ def _get_lc_class(self):
class LCHuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""

model_name: str = Param(
help=(
"Model name to use (https://huggingface.co/models?"
"pipeline_tag=sentence-similarity&sort=trending)"
),
default=None,
required=True,
)

def __init__(
self,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
Expand Down
24 changes: 23 additions & 1 deletion libs/kotaemon/kotaemon/llms/chats/langchain_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import AsyncGenerator, Iterator

from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param

from .base import ChatLLM

Expand Down Expand Up @@ -224,6 +224,17 @@ def _get_lc_class(self):


class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore
api_key: str = Param(
help="API key (https://console.anthropic.com/settings/keys)", required=True
)
model_name: str = Param(
help=(
"Model name to use "
"(https://docs.anthropic.com/en/docs/about-claude/models)"
),
required=True,
)

def __init__(
self,
api_key: str | None = None,
Expand All @@ -248,6 +259,17 @@ def _get_lc_class(self):


class LCGeminiChat(LCChatMixin, ChatLLM): # type: ignore
api_key: str = Param(
help="API key (https://aistudio.google.com/app/apikey)", required=True
)
model_name: str = Param(
help=(
"Model name to use (https://cloud.google"
".com/vertex-ai/generative-ai/docs/learn/models)"
),
required=True,
)

def __init__(
self,
api_key: str | None = None,
Expand Down
1 change: 1 addition & 0 deletions libs/ktem/ktem/embeddings/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def load(self):
}
if item.default:
self._default = item.name
self._models["default"] = self._models[item.name]

def load_vendors(self):
from kotaemon.embeddings import (
Expand Down
2 changes: 1 addition & 1 deletion libs/ktem/ktem/index/file/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def get_user_settings(self):
def get_admin_settings(cls):
from ktem.embeddings.manager import embedding_models_manager

embedding_default = embedding_models_manager.get_default_name()
embedding_default = "default"
embedding_choices = list(embedding_models_manager.options().keys())

return {
Expand Down

0 comments on commit 96d2086

Please sign in to comment.