Support SearchedDoc input type in LLM for No Rerank Pipeline #541

Merged
merged 2 commits on Aug 21, 2024
Changes from all commits
12 changes: 6 additions & 6 deletions comps/llms/text-generation/tgi/README.md
@@ -16,9 +16,6 @@ pip install -r requirements.txt

```bash
export HF_TOKEN=${your_hf_api_token}
export LANGCHAIN_TRACING_V2=true
export LANGCHAIN_API_KEY=${your_langchain_api_key}
export LANGCHAIN_PROJECT="opea/gen-ai-comps:llms"
docker run -p 8008:80 -v ./data:/data --name tgi_service --shm-size 1g ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id ${your_hf_llm_model}
```

@@ -50,9 +47,6 @@ In order to start TGI and LLM services, you need to setup the following environm
export HF_TOKEN=${your_hf_api_token}
export TGI_LLM_ENDPOINT="http://${your_ip}:8008"
export LLM_MODEL_ID=${your_hf_llm_model}
export LANGCHAIN_TRACING_V2=true
export LANGCHAIN_API_KEY=${your_langchain_api_key}
export LANGCHAIN_PROJECT="opea/llms"
```

### 2.2 Build Docker Image
@@ -116,6 +110,12 @@ curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \
-H 'Content-Type: application/json'

# consume with SearchedDoc
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
-H 'Content-Type: application/json'
```

### 4. Validated Model
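For reference, a minimal Python sketch of the same SearchedDoc request shown in the README curl example above. This is not part of the diff; it assumes the microservice is reachable on port 9000 (replace `localhost` with `${your_ip}` as in the docs) and that the default parameters stream the response as SSE `data:` chunks.

```python
# Minimal sketch: send a SearchedDoc-style payload to the LLM microservice.
# Assumes the service runs on <your_ip>:9000 as in the README; adjust as needed.
import requests

payload = {
    "initial_query": "What is Deep Learning?",
    "retrieved_docs": [
        {"text": "Deep Learning is a ..."},
        {"text": "Deep Learning is b ..."},
    ],
}

resp = requests.post(
    "http://localhost:9000/v1/chat/completions",  # replace localhost with ${your_ip}
    json=payload,
    stream=True,   # the service answers with "data: ..." chunks when streaming
    timeout=600,
)
for line in resp.iter_lines():
    if line:
        print(line.decode("utf-8"))  # ends with "data: [DONE]"
```
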
8 changes: 7 additions & 1 deletion comps/llms/text-generation/tgi/docker_compose_llm.yaml
@@ -12,6 +12,13 @@ services:
volumes:
- "./data:/data"
shm_size: 1g
environment:
no_proxy: ${no_proxy}
http_proxy: ${http_proxy}
https_proxy: ${https_proxy}
HF_TOKEN: ${HF_TOKEN}
HF_HUB_DISABLE_PROGRESS_BARS: 1
HF_HUB_ENABLE_HF_TRANSFER: 0
command: --model-id ${LLM_MODEL_ID}
llm:
image: opea/llm-tgi:latest
@@ -25,7 +32,6 @@
https_proxy: ${https_proxy}
TGI_LLM_ENDPOINT: ${TGI_LLM_ENDPOINT}
HF_TOKEN: ${HF_TOKEN}
LANGCHAIN_API_KEY: ${LANGCHAIN_API_KEY}
restart: unless-stopped

networks:
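Not part of the diff, but since the compose file now passes HF_TOKEN and proxy settings to the TGI container, a quick readiness check can save confusion while the model downloads. A hedged sketch, assuming TGI's `/health` route and the default 8008 host port from this compose setup:

```python
# Sketch: wait for the TGI container (port 8008 per docker_compose_llm.yaml)
# to report healthy before exercising the llm-tgi microservice.
import time
import requests

TGI_ENDPOINT = "http://localhost:8008"  # matches TGI_LLM_ENDPOINT in the compose setup

def wait_for_tgi(timeout_s: float = 300.0, interval_s: float = 5.0) -> bool:
    """Return True once GET /health responds with 200, False on timeout."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{TGI_ENDPOINT}/health", timeout=5).status_code == 200:
                return True
        except requests.RequestException:
            pass  # container may still be pulling or loading the model
        time.sleep(interval_s)
    return False

if __name__ == "__main__":
    print("TGI ready" if wait_for_tgi() else "TGI did not become ready in time")
```
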
71 changes: 63 additions & 8 deletions comps/llms/text-generation/tgi/llm.py
@@ -15,6 +15,7 @@
CustomLogger,
GeneratedDoc,
LLMParamsDoc,
SearchedDoc,
ServiceType,
opea_microservices,
register_microservice,
@@ -41,26 +42,76 @@
port=9000,
)
@register_statistics(names=["opea_service@llm_tgi"])
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]):
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
if logflag:
logger.info(input)
prompt_template = None
if input.chat_template:
if not isinstance(input, SearchedDoc) and input.chat_template:
prompt_template = PromptTemplate.from_template(input.chat_template)
input_variables = prompt_template.input_variables

stream_gen_time = []
start = time.time()

if isinstance(input, LLMParamsDoc):
if isinstance(input, SearchedDoc):
if logflag:
logger.info("[ SearchedDoc ] input from retriever microservice")
prompt = input.initial_query
if input.retrieved_docs:
if logflag:
logger.info(f"[ SearchedDoc ] retrieved docs: {input.retrieved_docs}")
for doc in input.retrieved_docs:
logger.info(f"[ SearchedDoc ] {doc}")
prompt = ChatTemplate.generate_rag_prompt(input.initial_query, input.retrieved_docs[0].text)
# use default llm parameters for inferencing
new_input = LLMParamsDoc(query=prompt)
if logflag:
logger.info(f"[ SearchedDoc ] final input: {new_input}")
text_generation = await llm.text_generation(
prompt=prompt,
stream=new_input.streaming,
max_new_tokens=new_input.max_new_tokens,
repetition_penalty=new_input.repetition_penalty,
temperature=new_input.temperature,
top_k=new_input.top_k,
top_p=new_input.top_p,
)
if new_input.streaming:

async def stream_generator():
chat_response = ""
async for text in text_generation:
stream_gen_time.append(time.time() - start)
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
if logflag:
logger.info(f"[ SearchedDoc ] chunk:{chunk_repr}")
yield f"data: {chunk_repr}\n\n"
if logflag:
logger.info(f"[ SearchedDoc ] stream response: {chat_response}")
statistics_dict["opea_service@llm_tgi"].append_latency(stream_gen_time[-1], stream_gen_time[0])
yield "data: [DONE]\n\n"

return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
statistics_dict["opea_service@llm_tgi"].append_latency(time.time() - start, None)
if logflag:
logger.info(text_generation)
return GeneratedDoc(text=text_generation, prompt=new_input.query)

elif isinstance(input, LLMParamsDoc):
if logflag:
logger.info("[ LLMParamsDoc ] input from rerank microservice")
prompt = input.query
if prompt_template:
if sorted(input_variables) == ["context", "question"]:
prompt = prompt_template.format(question=input.query, context="\n".join(input.documents))
elif input_variables == ["question"]:
prompt = prompt_template.format(question=input.query)
else:
logger.info(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']")
logger.info(
f"[ LLMParamsDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
)
else:
if input.documents:
# use rag default template
@@ -84,10 +135,10 @@ async def stream_generator():
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
if logflag:
logger.info(f"[llm - chat_stream] chunk:{chunk_repr}")
logger.info(f"[ LLMParamsDoc ] chunk:{chunk_repr}")
yield f"data: {chunk_repr}\n\n"
if logflag:
logger.info(f"[llm - chat_stream] stream response: {chat_response}")
logger.info(f"[ LLMParamsDoc ] stream response: {chat_response}")
statistics_dict["opea_service@llm_tgi"].append_latency(stream_gen_time[-1], stream_gen_time[0])
yield "data: [DONE]\n\n"

@@ -99,6 +150,8 @@ async def stream_generator():
return GeneratedDoc(text=text_generation, prompt=input.query)

else:
if logflag:
logger.info("[ ChatCompletionRequest ] input in opea format")
client = OpenAI(
api_key="EMPTY",
base_url=llm_endpoint + "/v1",
@@ -113,7 +166,7 @@ async def stream_generator():
prompt = prompt_template.format(question=input.messages)
else:
logger.info(
f"{prompt_template} not used, we only support 2 input variables ['question', 'context']"
f"[ ChatCompletionRequest ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
)
else:
if input.documents:
@@ -152,7 +205,9 @@ async def stream_generator():
if input_variables == ["context"]:
system_prompt = prompt_template.format(context="\n".join(input.documents))
else:
logger.info(f"{prompt_template} not used, only support 1 input variables ['context']")
logger.info(
f"[ ChatCompletionRequest ] {prompt_template} not used, only support 1 input variables ['context']"
)

input.messages.insert(0, {"role": "system", "content": system_prompt})

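To summarize the new control flow in llm.py, here is a condensed, non-authoritative sketch of how the SearchedDoc branch builds its prompt and falls back to default generation parameters. It is simplified from the diff above: the real handler also covers LLMParamsDoc and ChatCompletionRequest, calls ChatTemplate.generate_rag_prompt, and handles streaming; the template string below is an illustrative stand-in, not the actual one.

```python
# Condensed sketch of the SearchedDoc handling added in llm.py (simplified).
from typing import List, Optional


class Doc:
    """Stand-in for a retrieved document with a .text field."""

    def __init__(self, text: str):
        self.text = text


def build_searcheddoc_prompt(initial_query: str, retrieved_docs: Optional[List[Doc]]) -> str:
    """Mirror the diff's logic: use the raw query unless retrieved docs exist,
    in which case fold the top document into a default RAG template."""
    prompt = initial_query
    if retrieved_docs:
        # The diff calls ChatTemplate.generate_rag_prompt(query, docs[0].text);
        # this template is only an illustration of that step.
        prompt = (
            "### Context: " + retrieved_docs[0].text + "\n"
            "### Question: " + initial_query + "\n"
            "### Answer:"
        )
    return prompt


if __name__ == "__main__":
    docs = [Doc("Deep Learning is a ..."), Doc("Deep Learning is b ...")]
    # The handler then wraps this prompt in LLMParamsDoc(query=prompt) so that
    # default max_new_tokens, temperature, top_k, top_p, etc. are used.
    print(build_searcheddoc_prompt("What is Deep Learning?", docs))
```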