Async support for some microservices (#763)
* Async support for some microservices

Signed-off-by: lvliang-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix issues

Signed-off-by: lvliang-intel <[email protected]>

* fix issues

Signed-off-by: lvliang-intel <[email protected]>

* fix import issue

Signed-off-by: lvliang-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add dependency library

Signed-off-by: lvliang-intel <[email protected]>

* fix issue

Signed-off-by: lvliang-intel <[email protected]>

* roll back pinecone change

Signed-off-by: lvliang-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: lvliang-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
lvliang-intel and pre-commit-ci[bot] authored Oct 10, 2024
1 parent ceba539 commit f3746dc
Showing 19 changed files with 64 additions and 61 deletions.
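Every handler touched below follows the same recipe: the function registered with @register_microservice becomes a coroutine, and each blocking LangChain call is replaced by its a-prefixed async counterpart (ainvoke, astream, aembed_query, asimilarity_search_*), so the FastAPI event loop can serve other requests while the backend responds. A minimal sketch of that pattern, with an illustrative service name and abbreviated decorator arguments (not copied from any one file in this commit):

import os

from comps import GeneratedDoc, LLMParamsDoc, opea_microservices, register_microservice
from langchain_community.llms import HuggingFaceEndpoint


@register_microservice(name="opea_service@example_llm", host="0.0.0.0", port=9000)
async def llm_generate(input: LLMParamsDoc) -> GeneratedDoc:
    llm = HuggingFaceEndpoint(
        endpoint_url=os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080"),
        max_new_tokens=input.max_tokens,
    )
    # ainvoke suspends this coroutine instead of blocking a worker thread
    response = await llm.ainvoke(input.query)
    return GeneratedDoc(text=response, prompt=input.query)


if __name__ == "__main__":
    opea_microservices["opea_service@example_llm"].start()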
4 changes: 2 additions & 2 deletions comps/embeddings/mosec/langchain/embedding_mosec.py
@@ -57,11 +57,11 @@ def empty_embedding() -> List[float]:
     output_datatype=EmbedDoc,
 )
 @register_statistics(names=["opea_service@embedding_mosec"])
-def embedding(input: TextDoc) -> EmbedDoc:
+async def embedding(input: TextDoc) -> EmbedDoc:
     if logflag:
         logger.info(input)
     start = time.time()
-    embed_vector = embeddings.embed_query(input.text)
+    embed_vector = await embeddings.aembed_query(input.text)
     res = EmbedDoc(text=input.text, embedding=embed_vector)
     statistics_dict["opea_service@embedding_mosec"].append_latency(time.time() - start, None)
     if logflag:
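The same swap works for any LangChain Embeddings implementation, since aembed_query is defined on the base interface; a hedged, self-contained sketch (the service's MosecEmbeddings class is replaced here by a stand-in):

import asyncio
from typing import List

from langchain_core.embeddings import Embeddings


class FakeEmbeddings(Embeddings):
    """Stand-in for the service's embeddings client; returns a constant vector."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[0.0, 1.0, 2.0] for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        return [0.0, 1.0, 2.0]


async def main() -> None:
    embeddings = FakeEmbeddings()
    # aembed_query comes from the Embeddings base class, so awaiting it works
    # even when a concrete class only implements the sync methods (the default
    # async implementation offloads to a thread).
    vector = await embeddings.aembed_query("What is Deep Learning?")
    print(len(vector))


if __name__ == "__main__":
    asyncio.run(main())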
4 changes: 2 additions & 2 deletions comps/intent_detection/langchain/intent_detection.py
@@ -16,7 +16,7 @@
     host="0.0.0.0",
     port=9000,
 )
-def llm_generate(input: LLMParamsDoc):
+async def llm_generate(input: LLMParamsDoc):
     llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
     llm = HuggingFaceEndpoint(
         endpoint_url=llm_endpoint,
@@ -35,7 +35,7 @@ def llm_generate(input: LLMParamsDoc):

     llm_chain = LLMChain(prompt=prompt, llm=llm)

-    response = llm_chain.invoke(input.query)
+    response = await llm_chain.ainvoke(input.query)
     response = response["text"]
     print("response", response)
     return GeneratedDoc(text=response, prompt=input.query)
11 changes: 5 additions & 6 deletions comps/llms/faq-generation/tgi/langchain/llm.py
@@ -34,10 +34,9 @@ def post_process_text(text: str):
     host="0.0.0.0",
     port=9000,
 )
-def llm_generate(input: LLMParamsDoc):
+async def llm_generate(input: LLMParamsDoc):
     if logflag:
         logger.info(input)
-    llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
     llm = HuggingFaceEndpoint(
         endpoint_url=llm_endpoint,
         max_new_tokens=input.max_tokens,
@@ -54,9 +53,6 @@ def llm_generate(input: LLMParamsDoc):
     """
     PROMPT = PromptTemplate.from_template(templ)
     llm_chain = load_summarize_chain(llm=llm, prompt=PROMPT)
-
-    # Split text
-    text_splitter = CharacterTextSplitter()
     texts = text_splitter.split_text(input.query)

     # Create multiple documents
@@ -77,12 +73,15 @@ async def stream_generator():

         return StreamingResponse(stream_generator(), media_type="text/event-stream")
     else:
-        response = llm_chain.invoke(docs)
+        response = await llm_chain.ainvoke(docs)
         response = response["output_text"]
         if logflag:
             logger.info(response)
         return GeneratedDoc(text=response, prompt=input.query)


 if __name__ == "__main__":
+    llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
+    # Split text
+    text_splitter = CharacterTextSplitter()
     opea_microservices["opea_service@llm_faqgen"].start()
12 changes: 6 additions & 6 deletions comps/llms/summarization/tgi/langchain/llm.py
@@ -33,10 +33,10 @@ def post_process_text(text: str):
     host="0.0.0.0",
     port=9000,
 )
-def llm_generate(input: LLMParamsDoc):
+async def llm_generate(input: LLMParamsDoc):
     if logflag:
         logger.info(input)
-    llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
+
     llm = HuggingFaceEndpoint(
         endpoint_url=llm_endpoint,
         max_new_tokens=input.max_tokens,
@@ -48,9 +48,6 @@ def llm_generate(input: LLMParamsDoc):
         streaming=input.streaming,
     )
     llm_chain = load_summarize_chain(llm=llm, chain_type="map_reduce")
-
-    # Split text
-    text_splitter = CharacterTextSplitter()
     texts = text_splitter.split_text(input.query)

     # Create multiple documents
@@ -71,12 +68,15 @@ async def stream_generator():

         return StreamingResponse(stream_generator(), media_type="text/event-stream")
     else:
-        response = llm_chain.invoke(docs)
+        response = await llm_chain.ainvoke(docs)
         response = response["output_text"]
         if logflag:
             logger.info(response)
         return GeneratedDoc(text=response, prompt=input.query)


 if __name__ == "__main__":
+    llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
+    # Split text
+    text_splitter = CharacterTextSplitter()
     opea_microservices["opea_service@llm_docsum"].start()
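Both summarization-style services (FAQ generation above and document summarization here) now share the same layout: one-time setup lives at module level or in the __main__ block, and the async handler only awaits the chain, as sketched below. Names, the chain type, and the endpoint are illustrative assumptions, not copied from the files:

import asyncio

from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.llms import HuggingFaceEndpoint
from langchain_core.documents import Document

text_splitter = CharacterTextSplitter()


async def summarize(query: str, llm: HuggingFaceEndpoint) -> str:
    texts = text_splitter.split_text(query)
    docs = [Document(page_content=t) for t in texts]
    chain = load_summarize_chain(llm=llm, chain_type="map_reduce")
    # ainvoke accepts the document list as the chain's single input
    # ("input_documents") and returns a dict with "output_text"
    result = await chain.ainvoke(docs)
    return result["output_text"]


if __name__ == "__main__":
    llm = HuggingFaceEndpoint(endpoint_url="http://localhost:8080")  # assumed TGI endpoint
    print(asyncio.run(summarize("Text to summarize ..." * 20, llm)))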
4 changes: 2 additions & 2 deletions comps/llms/text-generation/ollama/langchain/llm.py
@@ -19,7 +19,7 @@
     host="0.0.0.0",
     port=9000,
 )
-def llm_generate(input: LLMParamsDoc):
+async def llm_generate(input: LLMParamsDoc):
     if logflag:
         logger.info(input)
     ollama = Ollama(
@@ -48,7 +48,7 @@ async def stream_generator():

         return StreamingResponse(stream_generator(), media_type="text/event-stream")
     else:
-        response = ollama.invoke(input.query)
+        response = await ollama.ainvoke(input.query)
         if logflag:
             logger.info(response)
         return GeneratedDoc(text=response, prompt=input.query)
4 changes: 2 additions & 2 deletions comps/llms/text-generation/ray_serve/llm.py
@@ -38,7 +38,7 @@ def post_process_text(text: str):
     host="0.0.0.0",
     port=9000,
 )
-def llm_generate(input: LLMParamsDoc):
+async def llm_generate(input: LLMParamsDoc):
     llm_endpoint = os.getenv("RAY_Serve_ENDPOINT", "http://localhost:8080")
     llm_model = os.getenv("LLM_MODEL", "Llama-2-7b-chat-hf")
     if "/" in llm_model:
@@ -73,7 +73,7 @@ async def stream_generator():

         return StreamingResponse(stream_generator(), media_type="text/event-stream")
     else:
-        response = llm.invoke(input.query)
+        response = await llm.ainvoke(input.query)
         response = response.content
         return GeneratedDoc(text=response, prompt=input.query)

6 changes: 3 additions & 3 deletions comps/llms/text-generation/vllm/langchain/llm.py
@@ -48,7 +48,7 @@ def post_process_text(text: str):
     host="0.0.0.0",
     port=9000,
 )
-def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
+async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
     if logflag:
         logger.info(input)

@@ -102,7 +102,7 @@ async def stream_generator():
            return StreamingResponse(stream_generator(), media_type="text/event-stream")

        else:
-            response = llm.invoke(new_input.query, **parameters)
+            response = await llm.ainvoke(new_input.query, **parameters)
            if logflag:
                logger.info(response)

@@ -153,7 +153,7 @@ async def stream_generator():
            return StreamingResponse(stream_generator(), media_type="text/event-stream")

        else:
-            response = llm.invoke(prompt, **parameters)
+            response = await llm.ainvoke(prompt, **parameters)
            if logflag:
                logger.info(response)

8 changes: 4 additions & 4 deletions comps/llms/text-generation/vllm/llama_index/llm.py
@@ -39,7 +39,7 @@ def post_process_text(text: str):
     host="0.0.0.0",
     port=9000,
 )
-def llm_generate(input: LLMParamsDoc):
+async def llm_generate(input: LLMParamsDoc):
     if logflag:
         logger.info(input)
     llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8008")
@@ -56,8 +56,8 @@ def llm_generate(input: LLMParamsDoc):

     if input.streaming:

-        def stream_generator():
-            for text in llm.stream_complete(input.query):
+        async def stream_generator():
+            async for text in llm.astream_complete(input.query):
                 output = text.text
                 yield f"data: {output}\n\n"
                 if logflag:
@@ -66,7 +66,7 @@ def stream_generator():

         return StreamingResponse(stream_generator(), media_type="text/event-stream")
     else:
-        response = llm.complete(input.query).text
+        response = await llm.acomplete(input.query).text
         if logflag:
             logger.info(response)
         return GeneratedDoc(text=response, prompt=input.query)
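For the llama-index client used above, the async completion APIs return awaitables: acomplete resolves to a CompletionResponse whose .text is read after the await, and astream_complete resolves to an async generator. A hedged sketch of that call pattern, with the endpoint and model name assumed rather than taken from this file:

import asyncio

from llama_index.llms.openai_like import OpenAILike


async def main() -> None:
    llm = OpenAILike(
        api_key="fake",
        api_base="http://localhost:8008/v1",  # assumed vLLM endpoint
        model="Intel/neural-chat-7b-v3-3",    # assumed model name
        max_tokens=128,
    )
    # Await first, then read .text from the resolved CompletionResponse.
    response = await llm.acomplete("What is Deep Learning?")
    print(response.text)

    # astream_complete is awaited once to obtain the async generator.
    gen = await llm.astream_complete("What is Deep Learning?")
    async for chunk in gen:
        print(chunk.delta, end="")


if __name__ == "__main__":
    asyncio.run(main())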
8 changes: 4 additions & 4 deletions comps/llms/text-generation/vllm/ray/llm.py
@@ -30,7 +30,7 @@
     host="0.0.0.0",
     port=9000,
 )
-def llm_generate(input: LLMParamsDoc):
+async def llm_generate(input: LLMParamsDoc):
     if logflag:
         logger.info(input)
     llm_endpoint = os.getenv("vLLM_RAY_ENDPOINT", "http://localhost:8006")
@@ -50,9 +50,9 @@ def llm_generate(input: LLMParamsDoc):

     if input.streaming:

-        def stream_generator():
+        async def stream_generator():
             chat_response = ""
-            for text in llm.stream(input.query):
+            for text in llm.astream(input.query):
                 text = text.content
                 chat_response += text
                 chunk_repr = repr(text.encode("utf-8"))
@@ -63,7 +63,7 @@ def stream_generator():

         return StreamingResponse(stream_generator(), media_type="text/event-stream")
     else:
-        response = llm.invoke(input.query)
+        response = await llm.ainvoke(input.query)
         response = response.content
         if logflag:
             logger.info(response)
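The streaming branches above keep their server-sent-events shape; with LangChain chat models, astream returns an async generator of chunks that is consumed with async for inside the generator handed to StreamingResponse. A hedged sketch of that pattern (the ChatOpenAI stand-in, endpoint, and model name are assumptions, not the exact client used above):

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from langchain_openai import ChatOpenAI

app = FastAPI()


@app.post("/v1/chat/completions")
async def generate(query: str) -> StreamingResponse:
    llm = ChatOpenAI(
        openai_api_key="EMPTY",
        openai_api_base="http://localhost:8006/v1",  # assumed vLLM-on-Ray endpoint
        model="Intel/neural-chat-7b-v3-3",           # assumed model name
        streaming=True,
    )

    async def stream_generator():
        # each chat chunk carries the newly generated text in .content
        async for chunk in llm.astream(query):
            yield f"data: {repr(chunk.content.encode('utf-8'))}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream_generator(), media_type="text/event-stream")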
2 changes: 1 addition & 1 deletion comps/lvms/predictionguard/lvm.py
@@ -28,7 +28,7 @@ class LVMDoc(BaseDoc):
     output_datatype=TextDoc,
 )
 @register_statistics(names=["opea_service@lvm_predictionguard"])
-async def lvm(request: LVMDoc) -> TextDoc:
+def lvm(request: LVMDoc) -> TextDoc:
     start = time.time()

     # make a request to the Prediction Guard API using the LlaVa model
10 changes: 5 additions & 5 deletions comps/retrievers/milvus/langchain/retriever_milvus.py
@@ -67,7 +67,7 @@ def empty_embedding() -> List[float]:
     port=7000,
 )
 @register_statistics(names=["opea_service@retriever_milvus"])
-def retrieve(input: EmbedDoc) -> SearchedDoc:
+async def retrieve(input: EmbedDoc) -> SearchedDoc:
     if logflag:
         logger.info(input)
     vector_db = Milvus(
@@ -77,20 +77,20 @@ def retrieve(input: EmbedDoc) -> SearchedDoc:
     )
     start = time.time()
     if input.search_type == "similarity":
-        search_res = vector_db.similarity_search_by_vector(embedding=input.embedding, k=input.k)
+        search_res = await vector_db.asimilarity_search_by_vector(embedding=input.embedding, k=input.k)
     elif input.search_type == "similarity_distance_threshold":
         if input.distance_threshold is None:
             raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever")
-        search_res = vector_db.similarity_search_by_vector(
+        search_res = await vector_db.asimilarity_search_by_vector(
             embedding=input.embedding, k=input.k, distance_threshold=input.distance_threshold
         )
     elif input.search_type == "similarity_score_threshold":
-        docs_and_similarities = vector_db.similarity_search_with_relevance_scores(
+        docs_and_similarities = await vector_db.asimilarity_search_with_relevance_scores(
             query=input.text, k=input.k, score_threshold=input.score_threshold
         )
         search_res = [doc for doc, _ in docs_and_similarities]
     elif input.search_type == "mmr":
-        search_res = vector_db.max_marginal_relevance_search(
+        search_res = await vector_db.amax_marginal_relevance_search(
             query=input.text, k=input.k, fetch_k=input.fetch_k, lambda_mult=input.lambda_mult
         )
     searched_docs = []
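The retriever changes are mechanical for the same reason: asimilarity_search_by_vector, asimilarity_search_with_relevance_scores, and amax_marginal_relevance_search are all defined on LangChain's VectorStore base class, so every backend gains the async path. A hedged sketch of the similarity branch against Milvus, with embedding model, collection name, and connection details assumed:

import asyncio

from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Milvus


async def retrieve_similar(embedding: list[float], k: int = 4):
    vector_db = Milvus(
        embedding_function=HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5"),
        collection_name="rag_milvus",                             # assumed collection name
        connection_args={"host": "localhost", "port": "19530"},   # assumed connection
    )
    # The a*-prefixed methods mirror the sync API and return Documents.
    return await vector_db.asimilarity_search_by_vector(embedding=embedding, k=k)


if __name__ == "__main__":
    docs = asyncio.run(retrieve_similar([0.0] * 768))
    print([d.page_content for d in docs])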
10 changes: 5 additions & 5 deletions comps/retrievers/multimodal/redis/langchain/retriever_redis.py
@@ -34,7 +34,7 @@
     port=7000,
 )
 @register_statistics(names=["opea_service@multimodal_retriever_redis"])
-def retrieve(
+async def retrieve(
     input: Union[EmbedMultimodalDoc, RetrievalRequest, ChatCompletionRequest]
 ) -> Union[SearchedMultimodalDoc, RetrievalResponse, ChatCompletionRequest]:

@@ -45,20 +45,20 @@ def retrieve(
     else:
         # if the Redis index has data, perform the search
         if input.search_type == "similarity":
-            search_res = vector_db.similarity_search_by_vector(embedding=input.embedding, k=input.k)
+            search_res = await vector_db.asimilarity_search_by_vector(embedding=input.embedding, k=input.k)
         elif input.search_type == "similarity_distance_threshold":
             if input.distance_threshold is None:
                 raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever")
-            search_res = vector_db.similarity_search_by_vector(
+            search_res = await vector_db.asimilarity_search_by_vector(
                 embedding=input.embedding, k=input.k, distance_threshold=input.distance_threshold
             )
         elif input.search_type == "similarity_score_threshold":
-            docs_and_similarities = vector_db.similarity_search_with_relevance_scores(
+            docs_and_similarities = await vector_db.asimilarity_search_with_relevance_scores(
                 query=input.text, k=input.k, score_threshold=input.score_threshold
             )
             search_res = [doc for doc, _ in docs_and_similarities]
         elif input.search_type == "mmr":
-            search_res = vector_db.max_marginal_relevance_search(
+            search_res = await vector_db.amax_marginal_relevance_search(
                 query=input.text, k=input.k, fetch_k=input.fetch_k, lambda_mult=input.lambda_mult
             )
         else:
12 changes: 7 additions & 5 deletions comps/retrievers/neo4j/langchain/retriever_neo4j.py
@@ -40,7 +40,7 @@
     port=7000,
 )
 @register_statistics(names=["opea_service@retriever_neo4j"])
-def retrieve(
+async def retrieve(
     input: Union[EmbedDoc, RetrievalRequest, ChatCompletionRequest]
 ) -> Union[SearchedDoc, RetrievalResponse, ChatCompletionRequest]:
     if logflag:
@@ -54,20 +54,22 @@ def retrieve(
         query = input.input

     if input.search_type == "similarity":
-        search_res = vector_db.similarity_search_by_vector(embedding=input.embedding, query=input.text, k=input.k)
+        search_res = await vector_db.asimilarity_search_by_vector(
+            embedding=input.embedding, query=input.text, k=input.k
+        )
     elif input.search_type == "similarity_distance_threshold":
         if input.distance_threshold is None:
             raise ValueError("distance_threshold must be provided for " + "similarity_distance_threshold retriever")
-        search_res = vector_db.similarity_search_by_vector(
+        search_res = await vector_db.asimilarity_search_by_vector(
             embedding=input.embedding, query=input.text, k=input.k, distance_threshold=input.distance_threshold
         )
     elif input.search_type == "similarity_score_threshold":
-        docs_and_similarities = vector_db.similarity_search_with_relevance_scores(
+        docs_and_similarities = await vector_db.asimilarity_search_with_relevance_scores(
             query=input.text, k=input.k, score_threshold=input.score_threshold
         )
         search_res = [doc for doc, _ in docs_and_similarities]
     elif input.search_type == "mmr":
-        search_res = vector_db.max_marginal_relevance_search(
+        search_res = await vector_db.amax_marginal_relevance_search(
             query=input.text, k=input.k, fetch_k=input.fetch_k, lambda_mult=input.lambda_mult
         )
     else:
4 changes: 2 additions & 2 deletions comps/retrievers/pgvector/langchain/retriever_pgvector.py
@@ -34,11 +34,11 @@
     port=PORT,
 )
 @register_statistics(names=["opea_service@retriever_pgvector"])
-def retrieve(input: EmbedDoc) -> SearchedDoc:
+async def retrieve(input: EmbedDoc) -> SearchedDoc:
     if logflag:
         logger.info(input)
     start = time.time()
-    search_res = vector_db.similarity_search_by_vector(embedding=input.embedding)
+    search_res = await vector_db.asimilarity_search_by_vector(embedding=input.embedding)
     searched_docs = []
     for r in search_res:
         searched_docs.append(TextDoc(text=r.page_content))
1 change: 1 addition & 0 deletions comps/retrievers/redis/langchain/requirements.txt
@@ -2,6 +2,7 @@ docarray[full]
 easyocr
 fastapi
 langchain_community
+langchain_huggingface
 opentelemetry-api
 opentelemetry-exporter-otlp
 opentelemetry-sdk