Skip to content

Commit

Permalink
Support async for embedding micorservice (opea-project#742)
Browse files Browse the repository at this point in the history
* Support async in embedding micorservice

Signed-off-by: lvliang-intel <[email protected]>
  • Loading branch information
lvliang-intel authored Sep 26, 2024
1 parent 2159f9a commit 2867295
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
6 changes: 3 additions & 3 deletions comps/embeddings/tei/langchain/embedding_tei.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@
port=6000,
)
@register_statistics(names=["opea_service@embedding_tei_langchain"])
def embedding(
async def embedding(
input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
start = time.time()
if logflag:
logger.info(input)
if isinstance(input, TextDoc):
embed_vector = embeddings.embed_query(input.text)
embed_vector = await embeddings.aembed_query(input.text)
res = EmbedDoc(text=input.text, embedding=embed_vector)
else:
embed_vector = embeddings.embed_query(input.input)
embed_vector = await embeddings.aembed_query(input.input)
if input.dimensions is not None:
embed_vector = embed_vector[: input.dimensions]

Expand Down
4 changes: 2 additions & 2 deletions comps/embeddings/tei/langchain/local_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
output_datatype=EmbedDoc,
)
@opea_telemetry
def embedding(input: TextDoc) -> EmbedDoc:
async def embedding(input: TextDoc) -> EmbedDoc:
if logflag:
logger.info(input)
embed_vector = embeddings.embed_query(input.text)
embed_vector = await embeddings.aembed_query(input.text)
res = EmbedDoc(text=input.text, embedding=embed_vector)
if logflag:
logger.info(res)
Expand Down
4 changes: 2 additions & 2 deletions comps/embeddings/tei/langchain/local_embedding_768.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
output_datatype=EmbedDoc768,
)
@opea_telemetry
def embedding(input: TextDoc) -> EmbedDoc768:
embed_vector = embeddings.embed_query(input.text)
async def embedding(input: TextDoc) -> EmbedDoc768:
embed_vector = await embeddings.aembed_query(input.text)
res = EmbedDoc768(text=input.text, embedding=embed_vector)
return res

Expand Down
4 changes: 2 additions & 2 deletions comps/embeddings/tei/llama_index/embedding_tei.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
input_datatype=TextDoc,
output_datatype=EmbedDoc,
)
def embedding(input: TextDoc) -> EmbedDoc:
async def embedding(input: TextDoc) -> EmbedDoc:
if logflag:
logger.info(input)
embed_vector = embeddings._get_query_embedding(input.text)
embed_vector = await embeddings.aget_query_embedding(input.text)
res = EmbedDoc(text=input.text, embedding=embed_vector)
if logflag:
logger.info(res)
Expand Down
4 changes: 2 additions & 2 deletions comps/embeddings/tei/llama_index/local_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
input_datatype=TextDoc,
output_datatype=EmbedDoc,
)
def embedding(input: TextDoc) -> EmbedDoc:
async def embedding(input: TextDoc) -> EmbedDoc:
if logflag:
logger.info(input)
embed_vector = embeddings.get_text_embedding(input.text)
embed_vector = await embeddings.aget_query_embedding(input.text)
res = EmbedDoc(text=input.text, embedding=embed_vector)
if logflag:
logger.info(res)
Expand Down

0 comments on commit 2867295

Please sign in to comment.