Skip to content

Commit

Permalink
Support SearchedDoc input type in LLM for No Rerank Pipeline (#541)
Browse files Browse the repository at this point in the history
Signed-off-by: letonghan <[email protected]>
  • Loading branch information
letonghan authored Aug 21, 2024
1 parent 04ff8bf commit 3c29fb4
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 15 deletions.
12 changes: 6 additions & 6 deletions comps/llms/text-generation/tgi/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ pip install -r requirements.txt

```bash
export HF_TOKEN=${your_hf_api_token}
export LANGCHAIN_TRACING_V2=true
export LANGCHAIN_API_KEY=${your_langchain_api_key}
export LANGCHAIN_PROJECT="opea/gen-ai-comps:llms"
docker run -p 8008:80 -v ./data:/data --name tgi_service --shm-size 1g ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id ${your_hf_llm_model}
```

Expand Down Expand Up @@ -50,9 +47,6 @@ In order to start TGI and LLM services, you need to setup the following environm
export HF_TOKEN=${your_hf_api_token}
export TGI_LLM_ENDPOINT="http://${your_ip}:8008"
export LLM_MODEL_ID=${your_hf_llm_model}
export LANGCHAIN_TRACING_V2=true
export LANGCHAIN_API_KEY=${your_langchain_api_key}
export LANGCHAIN_PROJECT="opea/llms"
```

### 2.2 Build Docker Image
Expand Down Expand Up @@ -116,6 +110,12 @@ curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \
-H 'Content-Type: application/json'

# consume with SearchedDoc
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
-H 'Content-Type: application/json'
```

### 4. Validated Model
Expand Down
8 changes: 7 additions & 1 deletion comps/llms/text-generation/tgi/docker_compose_llm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ services:
volumes:
- "./data:/data"
shm_size: 1g
environment:
no_proxy: ${no_proxy}
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
HF_TOKEN: ${HF_TOKEN}
HF_HUB_DISABLE_PROGRESS_BARS: 1
HF_HUB_ENABLE_HF_TRANSFER: 0
command: --model-id ${LLM_MODEL_ID}
llm:
image: opea/llm-tgi:latest
Expand All @@ -25,7 +32,6 @@ services:
https_proxy: ${https_proxy}
TGI_LLM_ENDPOINT: ${TGI_LLM_ENDPOINT}
HF_TOKEN: ${HF_TOKEN}
LANGCHAIN_API_KEY: ${LANGCHAIN_API_KEY}
restart: unless-stopped

networks:
Expand Down
71 changes: 63 additions & 8 deletions comps/llms/text-generation/tgi/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CustomLogger,
GeneratedDoc,
LLMParamsDoc,
SearchedDoc,
ServiceType,
opea_microservices,
register_microservice,
Expand All @@ -41,26 +42,76 @@
port=9000,
)
@register_statistics(names=["opea_service@llm_tgi"])
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]):
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
if logflag:
logger.info(input)
prompt_template = None
if input.chat_template:
if not isinstance(input, SearchedDoc) and input.chat_template:
prompt_template = PromptTemplate.from_template(input.chat_template)
input_variables = prompt_template.input_variables

stream_gen_time = []
start = time.time()

if isinstance(input, LLMParamsDoc):
if isinstance(input, SearchedDoc):
if logflag:
logger.info("[ SearchedDoc ] input from retriever microservice")
prompt = input.initial_query
if input.retrieved_docs:
if logflag:
logger.info(f"[ SearchedDoc ] retrieved docs: {input.retrieved_docs}")
for doc in input.retrieved_docs:
logger.info(f"[ SearchedDoc ] {doc}")
prompt = ChatTemplate.generate_rag_prompt(input.initial_query, input.retrieved_docs[0].text)
# use default llm parameters for inferencing
new_input = LLMParamsDoc(query=prompt)
if logflag:
logger.info(f"[ SearchedDoc ] final input: {new_input}")
text_generation = await llm.text_generation(
prompt=prompt,
stream=new_input.streaming,
max_new_tokens=new_input.max_new_tokens,
repetition_penalty=new_input.repetition_penalty,
temperature=new_input.temperature,
top_k=new_input.top_k,
top_p=new_input.top_p,
)
if new_input.streaming:

async def stream_generator():
chat_response = ""
async for text in text_generation:
stream_gen_time.append(time.time() - start)
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
if logflag:
logger.info(f"[ SearchedDoc ] chunk:{chunk_repr}")
yield f"data: {chunk_repr}\n\n"
if logflag:
logger.info(f"[ SearchedDoc ] stream response: {chat_response}")
statistics_dict["opea_service@llm_tgi"].append_latency(stream_gen_time[-1], stream_gen_time[0])
yield "data: [DONE]\n\n"

return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
statistics_dict["opea_service@llm_tgi"].append_latency(time.time() - start, None)
if logflag:
logger.info(text_generation)
return GeneratedDoc(text=text_generation, prompt=new_input.query)

elif isinstance(input, LLMParamsDoc):
if logflag:
logger.info("[ LLMParamsDoc ] input from rerank microservice")
prompt = input.query
if prompt_template:
if sorted(input_variables) == ["context", "question"]:
prompt = prompt_template.format(question=input.query, context="\n".join(input.documents))
elif input_variables == ["question"]:
prompt = prompt_template.format(question=input.query)
else:
logger.info(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']")
logger.info(
f"[ LLMParamsDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
)
else:
if input.documents:
# use rag default template
Expand All @@ -84,10 +135,10 @@ async def stream_generator():
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
if logflag:
logger.info(f"[llm - chat_stream] chunk:{chunk_repr}")
logger.info(f"[ LLMParamsDoc ] chunk:{chunk_repr}")
yield f"data: {chunk_repr}\n\n"
if logflag:
logger.info(f"[llm - chat_stream] stream response: {chat_response}")
logger.info(f"[ LLMParamsDoc ] stream response: {chat_response}")
statistics_dict["opea_service@llm_tgi"].append_latency(stream_gen_time[-1], stream_gen_time[0])
yield "data: [DONE]\n\n"

Expand All @@ -99,6 +150,8 @@ async def stream_generator():
return GeneratedDoc(text=text_generation, prompt=input.query)

else:
if logflag:
logger.info("[ ChatCompletionRequest ] input in opea format")
client = OpenAI(
api_key="EMPTY",
base_url=llm_endpoint + "/v1",
Expand All @@ -113,7 +166,7 @@ async def stream_generator():
prompt = prompt_template.format(question=input.messages)
else:
logger.info(
f"{prompt_template} not used, we only support 2 input variables ['question', 'context']"
f"[ ChatCompletionRequest ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
)
else:
if input.documents:
Expand Down Expand Up @@ -152,7 +205,9 @@ async def stream_generator():
if input_variables == ["context"]:
system_prompt = prompt_template.format(context="\n".join(input.documents))
else:
logger.info(f"{prompt_template} not used, only support 1 input variables ['context']")
logger.info(
f"[ ChatCompletionRequest ] {prompt_template} not used, only support 1 input variables ['context']"
)

input.messages.insert(0, {"role": "system", "content": system_prompt})

Expand Down

0 comments on commit 3c29fb4

Please sign in to comment.