Align parameters for "max_tokens, repetition_penalty, presence_penalty, frequency_penalty" #608

Merged 33 commits on Sep 18, 2024
Changes from 18 commits
Commits (33)
56e3791
align max_tokens
XinyaoWa Sep 4, 2024
ddea461
align repetition_penalty
XinyaoWa Sep 4, 2024
4a19aaa
Merge branch 'main' into align_param
XinyaoWa Sep 4, 2024
3249b55
Merge branch 'main' into align_param
kevinintel Sep 8, 2024
19d63aa
fix conflict
XinyaoWa Sep 12, 2024
2f1e157
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 12, 2024
badd896
fix bug
XinyaoWa Sep 12, 2024
2b08614
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 13, 2024
ab09c7a
align penalty parameters
XinyaoWa Sep 13, 2024
02bb1f1
fix bug
XinyaoWa Sep 13, 2024
ce32dab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2024
1b5c53d
Merge remote-tracking branch 'origin/align_param' into align_param
XinyaoWa Sep 13, 2024
3acd922
fix bug
XinyaoWa Sep 13, 2024
2a6f15f
fix bug
XinyaoWa Sep 13, 2024
4f5d3b3
fix bug
XinyaoWa Sep 13, 2024
8c7adf1
fix bug
XinyaoWa Sep 13, 2024
9ff5eef
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 13, 2024
5f15e0a
merge conflict
XinyaoWa Sep 14, 2024
bce4f77
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 18, 2024
3fe4d96
align max_tokens
XinyaoWa Sep 18, 2024
9ceed11
fix bug
XinyaoWa Sep 18, 2024
e4a1826
fix bug
XinyaoWa Sep 18, 2024
f0231c5
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 18, 2024
b0a8e97
debug
XinyaoWa Sep 18, 2024
dcc6c28
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 18, 2024
9513e6f
debug
XinyaoWa Sep 18, 2024
8923c56
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2024
2d766b0
fix conflict
XinyaoWa Sep 18, 2024
9af395b
fix langchain version bug
XinyaoWa Sep 18, 2024
daec3ed
fix langchain version bug
XinyaoWa Sep 18, 2024
9128f0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2024
0adfbcb
Merge remote-tracking branch 'origin/align_param' into align_param
XinyaoWa Sep 18, 2024
ec6ab0e
Merge branch 'main' into align_param
lvliang-intel Sep 18, 2024
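In short, the diff below makes the gateways read `max_new_tokens` instead of `max_tokens`, and forwards `frequency_penalty`, `presence_penalty`, and `repetition_penalty` as three separate fields rather than reusing `presence_penalty` as `repetition_penalty`. The following is a minimal, self-contained sketch of that mapping, not the project's actual classes; the model names are stand-ins and the defaults are copied from the diff.

```python
# Minimal sketch of the aligned parameter mapping; these models are stand-ins
# for the project's ChatCompletionRequest/LLMParams, with defaults taken from the diff.
from typing import Optional

from pydantic import BaseModel


class ChatRequestSketch(BaseModel):
    max_new_tokens: Optional[int] = 1024
    top_k: Optional[int] = 10
    top_p: Optional[float] = 0.95
    temperature: Optional[float] = 0.01
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.03
    stream: Optional[bool] = False


class LLMParamsSketch(BaseModel):
    max_new_tokens: int = 1024
    top_k: int = 10
    top_p: float = 0.95
    temperature: float = 0.01
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    repetition_penalty: float = 1.03
    streaming: bool = True


def to_llm_params(req: ChatRequestSketch, stream_opt: bool) -> LLMParamsSketch:
    # Each penalty is forwarded to its own field; presence_penalty is no longer
    # reused as repetition_penalty.
    return LLMParamsSketch(
        max_new_tokens=req.max_new_tokens if req.max_new_tokens else 1024,
        top_k=req.top_k if req.top_k else 10,
        top_p=req.top_p if req.top_p else 0.95,
        temperature=req.temperature if req.temperature else 0.01,
        frequency_penalty=req.frequency_penalty if req.frequency_penalty else 0.0,
        presence_penalty=req.presence_penalty if req.presence_penalty else 0.0,
        repetition_penalty=req.repetition_penalty if req.repetition_penalty else 1.03,
        streaming=stream_opt,
    )


print(to_llm_params(ChatRequestSketch(max_new_tokens=32), stream_opt=False))
```

Keeping the penalties distinct matters because `repetition_penalty` (TGI-style, multiplicative, default 1.03) and `presence_penalty`/`frequency_penalty` (OpenAI-style, additive, default 0.0) have different semantics.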
56 changes: 37 additions & 19 deletions comps/cores/mega/gateway.py
@@ -160,11 +160,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_new_tokens=chat_request.max_new_tokens if chat_request.max_new_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
@@ -214,11 +216,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_new_tokens=chat_request.max_new_tokens if chat_request.max_new_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -350,11 +354,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_new_tokens=chat_request.max_new_tokens if chat_request.max_new_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -398,12 +404,14 @@ async def handle_request(self, request: Request):

chat_request = AudioChatCompletionRequest.parse_obj(data)
parameters = LLMParams(
# relatively lower max_tokens for audio conversation
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
# relatively lower max_new_tokens for audio conversation
max_new_tokens=chat_request.max_new_tokens if chat_request.max_new_tokens else 128,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=False, # TODO add streaming LLM output as input to TTS
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -428,11 +436,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_new_tokens=chat_request.max_new_tokens if chat_request.max_new_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -472,11 +482,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_new_tokens=chat_request.max_new_tokens if chat_request.max_new_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -516,11 +528,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt, images = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_new_tokens=chat_request.max_new_tokens if chat_request.max_new_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -565,11 +579,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_new_tokens=chat_request.max_new_tokens if chat_request.max_new_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -754,11 +770,13 @@ async def handle_request(self, request: Request):
initial_inputs = {"text": prompt}

parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_new_tokens=chat_request.max_new_tokens if chat_request.max_new_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
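For reference, a request that exercises the aligned fields against a deployed gateway might look like the sketch below; the host, port, and route are placeholders rather than values taken from this repository.

```python
# Hedged example: host, port, and route are placeholders for a deployed megaservice gateway.
import requests

payload = {
    "messages": [{"role": "user", "content": "What is Deep Learning?"}],
    "max_new_tokens": 32,
    "top_k": 10,
    "top_p": 0.95,
    "temperature": 0.01,
    "frequency_penalty": 0.0,
    "presence_penalty": 0.0,
    "repetition_penalty": 1.03,
    "stream": False,
}

resp = requests.post("http://localhost:8888/v1/chatqna", json=payload, timeout=120)
print(resp.json())
```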
16 changes: 9 additions & 7 deletions comps/cores/proto/api_protocol.py
@@ -63,7 +63,7 @@
class TokenCheckRequestItem(BaseModel):
model: str
prompt: str
max_tokens: int
max_new_tokens: int


class TokenCheckRequest(BaseModel):
@@ -160,7 +160,7 @@
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
max_tokens: Optional[int] = 1024 # use https://platform.openai.com/docs/api-reference/completions/create
max_new_tokens: Optional[int] = 1024 # use https://platform.openai.com/docs/api-reference/completions/create
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
@@ -282,11 +282,12 @@
top_p: Optional[float] = 0.95
top_k: Optional[int] = 10
n: Optional[int] = 1
max_tokens: Optional[int] = 1024
max_new_tokens: Optional[int] = 1024
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
presence_penalty: Optional[float] = 1.03
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.03
user: Optional[str] = None


@@ -336,7 +337,7 @@
suffix: Optional[str] = None
temperature: Optional[float] = 0.7
n: Optional[int] = 1
max_tokens: Optional[int] = 16
max_new_tokens: Optional[int] = 16
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
top_p: Optional[float] = 1.0
@@ -345,6 +346,7 @@
echo: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.03
user: Optional[str] = None
use_beam_search: Optional[bool] = False
best_of: Optional[int] = None
@@ -497,10 +499,10 @@

def check_requests(request) -> Optional[JSONResponse]:
# Check all params
if request.max_tokens is not None and request.max_tokens <= 0:
if request.max_new_tokens is not None and request.max_new_tokens <= 0:

Codecov / codecov/patch annotation: added line comps/cores/proto/api_protocol.py#L502 was not covered by tests.
return create_error_response(
ApiErrorCode.PARAM_OUT_OF_RANGE,
f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
f"{request.max_new_tokens} is less than the minimum of 1 - 'max_new_tokens'",
)
if request.n is not None and request.n <= 0:
return create_error_response(
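The renamed bound check above can be exercised on its own; the sketch below mirrors its logic with a plain function and a stand-in error string instead of the project's `create_error_response`/`ApiErrorCode`.

```python
# Sketch of the max_new_tokens bound check; the error helper is a stand-in,
# not the project's create_error_response/ApiErrorCode.
from typing import Optional


def check_max_new_tokens(max_new_tokens: Optional[int]) -> Optional[str]:
    # Mirror the check in check_requests(): reject values that are set but not positive.
    if max_new_tokens is not None and max_new_tokens <= 0:
        return f"{max_new_tokens} is less than the minimum of 1 - 'max_new_tokens'"
    return None  # no error


assert check_max_new_tokens(0) == "0 is less than the minimum of 1 - 'max_new_tokens'"
assert check_max_new_tokens(32) is None
assert check_max_new_tokens(None) is None
```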
4 changes: 4 additions & 0 deletions comps/cores/proto/docarray.py
@@ -150,6 +150,8 @@ class LLMParamsDoc(BaseDoc):
top_p: float = 0.95
typical_p: float = 0.95
temperature: float = 0.01
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
repetition_penalty: float = 1.03
streaming: bool = True

@@ -184,6 +186,8 @@ class LLMParams(BaseDoc):
top_p: float = 0.95
typical_p: float = 0.95
temperature: float = 0.01
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
repetition_penalty: float = 1.03
streaming: bool = True

4 changes: 2 additions & 2 deletions comps/llms/text-generation/README.md
@@ -59,7 +59,7 @@ curl http://${your_ip}:8008/v1/completions \
-d '{
"model": ${your_hf_llm_model},
"prompt": "What is Deep Learning?",
"max_tokens": 32,
"max_new_tokens": 32,
"temperature": 0
}'
```
@@ -75,7 +75,7 @@ curl http://${your_ip}:8008/v1/chat/completions \
{"role": "assistant", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is Deep Learning?"},
],
"max_tokens": 32,
"max_new_tokens": 32,
"stream": True
}'
```
14 changes: 10 additions & 4 deletions comps/llms/text-generation/tgi/README.md
@@ -105,19 +105,25 @@ curl http://${your_ip}:9000/v1/chat/completions \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' \
-H 'Content-Type: application/json'

# custom chat template
# consume with SearchedDoc
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \
-d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
-H 'Content-Type: application/json'
```

# consume with SearchedDoc
For the parameters available in the above modes, please refer to the [HuggingFace InferenceClient API](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation).

```bash
# custom chat template
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"presence_penalty":1.03", frequency_penalty":0.0, "streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \
-H 'Content-Type: application/json'
```

For parameters in Chat mode, please refer to the [OpenAI API](https://platform.openai.com/docs/api-reference/chat/create).

### 4. Validated Model

| Model | TGI |
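The README changes above point to `huggingface_hub`'s `InferenceClient.text_generation` for the non-OpenAI modes. A rough Python equivalent of the first curl example is sketched below; the endpoint URL is a placeholder and keyword support can vary with the installed `huggingface_hub` version.

```python
# Hedged sketch: the TGI endpoint URL is a placeholder; parameter support may
# vary with the installed huggingface_hub version.
from huggingface_hub import InferenceClient

client = InferenceClient(model="http://localhost:8008")  # placeholder TGI endpoint

text = client.text_generation(
    "What is Deep Learning?",
    max_new_tokens=17,
    top_k=10,
    top_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
    stream=False,
)
print(text)
```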
4 changes: 2 additions & 2 deletions comps/llms/text-generation/tgi/llm.py
@@ -180,7 +180,7 @@ async def stream_generator():
frequency_penalty=input.frequency_penalty,
logit_bias=input.logit_bias,
logprobs=input.logprobs,
max_tokens=input.max_tokens,
max_tokens=input.max_new_tokens,
n=input.n,
presence_penalty=input.presence_penalty,
seed=input.seed,
@@ -217,7 +217,7 @@ async def stream_generator():
logit_bias=input.logit_bias,
logprobs=input.logprobs,
top_logprobs=input.top_logprobs,
max_tokens=input.max_tokens,
max_tokens=input.max_new_tokens,
n=input.n,
presence_penalty=input.presence_penalty,
response_format=input.response_format,
4 changes: 3 additions & 1 deletion comps/llms/text-generation/vllm/langchain/README.md
@@ -203,7 +203,7 @@ User can set the following model parameters according to needs:
# 1. Non-streaming mode
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_p":0.95,"temperature":0.01,"streaming":false}' \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_p":1,"temperature":0.7,"frequency_penalty":0,"presence_penalty":0, "streaming":false}' \
-H 'Content-Type: application/json'

# 2. Streaming mode
@@ -224,3 +224,5 @@ curl http://${your_ip}:9000/v1/chat/completions \
-d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
-H 'Content-Type: application/json'
```

For parameters, refer to the [LangChain VLLMOpenAI API](https://api.python.langchain.com/en/latest/llms/langchain_community.llms.vllm.VLLMOpenAI.html).
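As context for the `llm.py` change in the next file, passing the two penalties through LangChain's `VLLMOpenAI` wrapper might look roughly like the sketch below; the endpoint and model name are placeholders and exact constructor fields depend on the installed `langchain-community` version.

```python
# Hedged sketch: endpoint and model name are placeholders; field names follow
# the langchain_community VLLMOpenAI wrapper referenced above.
from langchain_community.llms import VLLMOpenAI

llm = VLLMOpenAI(
    openai_api_key="EMPTY",
    openai_api_base="http://localhost:8008/v1",  # placeholder vLLM endpoint
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    max_tokens=17,
    top_p=1.0,
    temperature=0.7,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    streaming=False,
)

print(llm.invoke("What is Deep Learning?"))
```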
4 changes: 4 additions & 0 deletions comps/llms/text-generation/vllm/langchain/llm.py
@@ -83,6 +83,8 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
model_name=model_name,
top_p=new_input.top_p,
temperature=new_input.temperature,
frequency_penalty=new_input.frequency_penalty,
presence_penalty=new_input.presence_penalty,
streaming=new_input.streaming,
)

@@ -136,6 +138,8 @@ def stream_generator():
model_name=model_name,
top_p=input.top_p,
temperature=input.temperature,
frequency_penalty=input.frequency_penalty,
presence_penalty=input.presence_penalty,
streaming=input.streaming,
)

2 changes: 1 addition & 1 deletion comps/llms/text-generation/vllm/langchain/query.sh
@@ -15,5 +15,5 @@ curl http://${your_ip}:8008/v1/completions \
##query microservice
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_p":0.95,"temperature":0.01,"streaming":false}' \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_p":1,"temperature":0.7,"frequency_penalty":0,"presence_penalty":0, "streaming":false}' \
-H 'Content-Type: application/json'
4 changes: 3 additions & 1 deletion comps/llms/text-generation/vllm/ray/README.md
@@ -82,6 +82,8 @@ bash ./launch_microservice.sh
```bash
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":false}' \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_p":1,"temperature":0.7,"frequency_penalty":0,"presence_penalty":0, "streaming":false}' \
-H 'Content-Type: application/json'
```

For parameters, refer to the [LangChain ChatOpenAI API](https://python.langchain.com/v0.2/api_reference/openai/chat_models/langchain_openai.chat_models.base.ChatOpenAI.html).