Skip to content

Commit

Permalink
partner-upstage[patch]: embeddings empty list bug (#22057)
Browse files Browse the repository at this point in the history
Fixed an error in `embed_documents` when the input was given as an empty
list. And I have revised the document.
  • Loading branch information
JuHyung-Son authored May 23, 2024
1 parent 2df8ac4 commit d9eff44
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
"\n",
"docs = loader.load()\n",
"\n",
"vectorstore = DocArrayInMemorySearch.from_documents(docs, embedding=UpstageEmbeddings())\n",
"vectorstore = DocArrayInMemorySearch.from_documents(\n",
" docs, embedding=UpstageEmbeddings(model=\"solar-embedding-1-large\")\n",
")\n",
"retriever = vectorstore.as_retriever()\n",
"\n",
"template = \"\"\"Answer the question based only on the following context:\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/docs/integrations/providers/upstage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@
"source": [
"from langchain_upstage import UpstageEmbeddings\n",
"\n",
"embeddings = UpstageEmbeddings()\n",
"embeddings = UpstageEmbeddings(model=\"solar-embedding-1-large\")\n",
"doc_result = embeddings.embed_documents(\n",
" [\"Sam is a teacher.\", \"This is another document\"]\n",
" [\"Sung is a professor.\", \"This is another document\"]\n",
")\n",
"print(doc_result)\n",
"\n",
"query_result = embeddings.embed_query(\"What does Sam do?\")\n",
"query_result = embeddings.embed_query(\"What does Sung do?\")\n",
"print(query_result)"
]
},
Expand Down
8 changes: 4 additions & 4 deletions docs/docs/integrations/text_embedding/upstage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
"source": [
"from langchain_upstage import UpstageEmbeddings\n",
"\n",
"embeddings = UpstageEmbeddings()"
"embeddings = UpstageEmbeddings(model=\"solar-embedding-1-large\")"
]
},
{
Expand All @@ -101,7 +101,7 @@
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents(\n",
" [\"Sam is a teacher.\", \"This is another document\"]\n",
" [\"Sung is a professor.\", \"This is another document\"]\n",
")\n",
"print(doc_result)"
]
Expand All @@ -123,7 +123,7 @@
},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(\"What does Sam do?\")\n",
"query_result = embeddings.embed_query(\"What does Sung do?\")\n",
"print(query_result)"
]
},
Expand Down Expand Up @@ -184,7 +184,7 @@
"\n",
"vectorstore = DocArrayInMemorySearch.from_texts(\n",
" [\"harrison worked at kensho\", \"bears like to eat honey\"],\n",
" embedding=UpstageEmbeddings(),\n",
" embedding=UpstageEmbeddings(model=\"solar-embedding-1-large\"),\n",
")\n",
"retriever = vectorstore.as_retriever()\n",
"docs = retriever.invoke(\"Where did Harrison work?\")\n",
Expand Down
2 changes: 1 addition & 1 deletion libs/partners/upstage/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ See a [usage example](https://python.langchain.com/docs/integrations/chat/upstag

See a [usage example](https://python.langchain.com/docs/integrations/text_embedding/upstage)

Use `solar-1-mini-embedding` as the default model for embeddings. Do not add suffixes such as `-query` or `-passage` to the model name.
Use `solar-embedding-1-large` model for embeddings. Do not add suffixes such as `-query` or `-passage` to the model name.
`UpstageEmbeddings` will automatically add the suffixes based on the method called.
6 changes: 5 additions & 1 deletion libs/partners/upstage/langchain_upstage/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class UpstageEmbeddings(BaseModel, Embeddings):
from langchain_upstage import UpstageEmbeddings
model = UpstageEmbeddings()
model = UpstageEmbeddings(model='solar-embedding-1-large')
"""

client: Any = Field(default=None, exclude=True) #: :meta private:
Expand Down Expand Up @@ -200,6 +200,8 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
assert (
self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
if not texts:
return []
params = self._invocation_params
params["model"] = params["model"] + "-passage"
embeddings = []
Expand Down Expand Up @@ -242,6 +244,8 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
assert (
self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
if not texts:
return []
params = self._invocation_params
params["model"] = params["model"] + "-passage"
embeddings = []
Expand Down
14 changes: 14 additions & 0 deletions libs/partners/upstage/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,17 @@ async def test_langchain_upstage_aembed_query() -> None:
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
output = await embedding.aembed_query(query)
assert len(output) > 0


def test_langchain_upstage_embed_documents_with_empty_list() -> None:
"""Test Upstage embeddings with empty list."""
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
output = embedding.embed_documents([])
assert len(output) == 0


async def test_langchain_upstage_aembed_documents_with_empty_list() -> None:
"""Test Upstage embeddings asynchronous with empty list."""
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
output = await embedding.aembed_documents([])
assert len(output) == 0

0 comments on commit d9eff44

Please sign in to comment.