Align parameters for "max_tokens, repetition_penalty, presence_penalty, frequency_penalty" #608

Merged
merged 33 commits on Sep 18, 2024
Changes from all commits

Commits (33)
56e3791
align max_tokens
XinyaoWa Sep 4, 2024
ddea461
align repetition_penalty
XinyaoWa Sep 4, 2024
4a19aaa
Merge branch 'main' into align_param
XinyaoWa Sep 4, 2024
3249b55
Merge branch 'main' into align_param
kevinintel Sep 8, 2024
19d63aa
fix conflict
XinyaoWa Sep 12, 2024
2f1e157
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 12, 2024
badd896
fix bug
XinyaoWa Sep 12, 2024
2b08614
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 13, 2024
ab09c7a
align penalty parameters
XinyaoWa Sep 13, 2024
02bb1f1
fix bug
XinyaoWa Sep 13, 2024
ce32dab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2024
1b5c53d
Merge remote-tracking branch 'origin/align_param' into align_param
XinyaoWa Sep 13, 2024
3acd922
fix bug
XinyaoWa Sep 13, 2024
2a6f15f
fix bug
XinyaoWa Sep 13, 2024
4f5d3b3
fix bug
XinyaoWa Sep 13, 2024
8c7adf1
fix bug
XinyaoWa Sep 13, 2024
9ff5eef
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 13, 2024
5f15e0a
merge conflict
XinyaoWa Sep 14, 2024
bce4f77
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 18, 2024
3fe4d96
align max_tokens
XinyaoWa Sep 18, 2024
9ceed11
fix bug
XinyaoWa Sep 18, 2024
e4a1826
fix bug
XinyaoWa Sep 18, 2024
f0231c5
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 18, 2024
b0a8e97
debug
XinyaoWa Sep 18, 2024
dcc6c28
Merge remote-tracking branch 'origin/main' into align_param
XinyaoWa Sep 18, 2024
9513e6f
debug
XinyaoWa Sep 18, 2024
8923c56
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2024
2d766b0
fix conflict
XinyaoWa Sep 18, 2024
9af395b
fix langchain version bug
XinyaoWa Sep 18, 2024
daec3ed
fix langchain version bug
XinyaoWa Sep 18, 2024
9128f0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2024
0adfbcb
Merge remote-tracking branch 'origin/align_param' into align_param
XinyaoWa Sep 18, 2024
ec6ab0e
Merge branch 'main' into align_param
lvliang-intel Sep 18, 2024
48 changes: 33 additions & 15 deletions comps/cores/mega/gateway.py
@@ -160,11 +160,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
@@ -214,11 +216,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -350,11 +354,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -399,11 +405,13 @@ async def handle_request(self, request: Request):
chat_request = AudioChatCompletionRequest.parse_obj(data)
parameters = LLMParams(
# relatively lower max_tokens for audio conversation
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=False, # TODO add streaming LLM output as input to TTS
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -428,11 +436,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -472,11 +482,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -520,7 +532,9 @@ async def handle_request(self, request: Request):
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -569,7 +583,9 @@ async def handle_request(self, request: Request):
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -758,7 +774,9 @@ async def handle_request(self, request: Request):
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
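
Taken together, the gateway hunks above route the OpenAI-style request fields into `LLMParams` with fixed fallbacks. A minimal sketch of that mapping pattern, assuming the `ChatCompletionRequest` and `LLMParams` classes touched by this PR (the helper name `build_llm_params` is hypothetical, and the import paths are inferred from the file locations in this diff):

```python
# Hedged sketch of the parameter-alignment pattern used in the gateway hunks above.
# build_llm_params is a hypothetical helper; defaults mirror the values in the diff.
from comps.cores.proto.api_protocol import ChatCompletionRequest
from comps.cores.proto.docarray import LLMParams


def build_llm_params(chat_request: ChatCompletionRequest, stream_opt: bool) -> LLMParams:
    # Note: `x if x else default` also replaces falsy caller values (0 and 0.0)
    # with the fallback; for the two additive penalties the fallback is 0.0,
    # so the result is unchanged in that case.
    return LLMParams(
        max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
        top_k=chat_request.top_k if chat_request.top_k else 10,
        top_p=chat_request.top_p if chat_request.top_p else 0.95,
        temperature=chat_request.temperature if chat_request.temperature else 0.01,
        frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
        presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
        repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
        streaming=stream_opt,
    )
```
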
4 changes: 3 additions & 1 deletion comps/cores/proto/api_protocol.py
@@ -285,8 +285,9 @@ class AudioChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = 1024
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
presence_penalty: Optional[float] = 1.03
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.03
user: Optional[str] = None


@@ -345,6 +346,7 @@ class CompletionRequest(BaseModel):
echo: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.03
user: Optional[str] = None
use_beam_search: Optional[bool] = False
best_of: Optional[int] = None
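
The schema hunks above align the penalty fields with their OpenAI counterparts: `presence_penalty` now defaults to 0.0 like `frequency_penalty`, and the TGI-style `repetition_penalty` is added with a default of 1.03. A minimal, self-contained sketch of just that trio (plain pydantic here; the real request classes in `api_protocol.py` carry many more fields):

```python
# Self-contained sketch of the aligned penalty trio and its defaults.
# Only these three fields are reproduced; the real request classes have more.
from typing import Optional

from pydantic import BaseModel


class PenaltyParams(BaseModel):
    presence_penalty: Optional[float] = 0.0     # OpenAI-style, additive
    frequency_penalty: Optional[float] = 0.0    # OpenAI-style, additive
    repetition_penalty: Optional[float] = 1.03  # TGI/HF-style, multiplicative


# Unset fields fall back to the defaults above.
print(PenaltyParams().repetition_penalty)                    # 1.03
print(PenaltyParams(presence_penalty=0.5).presence_penalty)  # 0.5
```
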
6 changes: 6 additions & 0 deletions comps/cores/proto/docarray.py
@@ -145,11 +145,14 @@ class RerankedDoc(BaseDoc):
class LLMParamsDoc(BaseDoc):
model: Optional[str] = None # for openai and ollama
query: str
max_tokens: int = 1024
max_new_tokens: int = 1024
top_k: int = 10
top_p: float = 0.95
typical_p: float = 0.95
temperature: float = 0.01
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
repetition_penalty: float = 1.03
streaming: bool = True

@@ -179,11 +182,14 @@ def chat_template_must_contain_variables(cls, v):


class LLMParams(BaseDoc):
max_tokens: int = 1024
max_new_tokens: int = 1024
top_k: int = 10
top_p: float = 0.95
typical_p: float = 0.95
temperature: float = 0.01
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
repetition_penalty: float = 1.03
streaming: bool = True

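
Because the `docarray.py` hunks are pure additions (6 additions, 0 deletions), `LLMParamsDoc` and `LLMParams` now expose both the new `max_tokens` and the legacy `max_new_tokens`. A hedged sketch of how a consumer might prefer the new field while still honoring the old one; `resolve_max_tokens` is a hypothetical helper, not part of this PR, and the import path is inferred from the file location:

```python
# Hedged sketch: prefer the new OpenAI-style max_tokens, but fall back to the
# legacy max_new_tokens when a caller only sets that field. Hypothetical helper.
from comps.cores.proto.docarray import LLMParamsDoc

DEFAULT_MAX = 1024


def resolve_max_tokens(params: LLMParamsDoc) -> int:
    if params.max_tokens != DEFAULT_MAX:
        return params.max_tokens
    if params.max_new_tokens != DEFAULT_MAX:
        return params.max_new_tokens
    return DEFAULT_MAX


print(resolve_max_tokens(LLMParamsDoc(query="What is Deep Learning?", max_new_tokens=17)))  # 17
```
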
2 changes: 1 addition & 1 deletion comps/llms/faq-generation/tgi/langchain/llm.py
@@ -40,7 +40,7 @@ def llm_generate(input: LLMParamsDoc):
llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = HuggingFaceEndpoint(
endpoint_url=llm_endpoint,
max_new_tokens=input.max_new_tokens,
max_new_tokens=input.max_tokens,
top_k=input.top_k,
top_p=input.top_p,
typical_p=input.typical_p,
3 changes: 3 additions & 0 deletions comps/llms/faq-generation/tgi/langchain/requirements.txt
@@ -2,7 +2,10 @@ docarray[full]
fastapi
huggingface_hub
langchain
langchain-huggingface
langchain-openai
langchain_community
langchainhub
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
2 changes: 1 addition & 1 deletion comps/llms/summarization/tgi/langchain/llm.py
@@ -39,7 +39,7 @@ def llm_generate(input: LLMParamsDoc):
llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = HuggingFaceEndpoint(
endpoint_url=llm_endpoint,
max_new_tokens=input.max_new_tokens,
max_new_tokens=input.max_tokens,
top_k=input.top_k,
top_p=input.top_p,
typical_p=input.typical_p,
6 changes: 3 additions & 3 deletions comps/llms/text-generation/README.md
@@ -374,7 +374,7 @@ curl http://${your_ip}:8008/v1/chat/completions \

### 3.3 Consume LLM Service

You can set the following model parameters according to your actual needs, such as `max_new_tokens`, `streaming`.
You can set the following model parameters according to your actual needs, such as `max_tokens`, `streaming`.

The `streaming` parameter determines the format of the data returned by the API: with `streaming=false` it returns a plain text string, and with `streaming=true` it returns a streaming text flow.

@@ -385,7 +385,7 @@ curl http://${your_ip}:9000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"query":"What is Deep Learning?",
"max_new_tokens":17,
"max_tokens":17,
"top_k":10,
"top_p":0.95,
"typical_p":0.95,
@@ -401,7 +401,7 @@ curl http://${your_ip}:9000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"query":"What is Deep Learning?",
"max_new_tokens":17,
"max_tokens":17,
"top_k":10,
"top_p":0.95,
"typical_p":0.95,
2 changes: 1 addition & 1 deletion comps/llms/text-generation/ollama/langchain/README.md
@@ -70,5 +70,5 @@ docker run --network host -e http_proxy=$http_proxy -e https_proxy=$https_proxy
## Consume the Ollama Microservice

```bash
curl http://127.0.0.1:9000/v1/chat/completions -X POST -d '{"model": "llama3", "query":"What is Deep Learning?","max_new_tokens":32,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' -H 'Content-Type: application/json'
curl http://127.0.0.1:9000/v1/chat/completions -X POST -d '{"model": "llama3", "query":"What is Deep Learning?","max_tokens":32,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' -H 'Content-Type: application/json'
```
2 changes: 1 addition & 1 deletion comps/llms/text-generation/ollama/langchain/llm.py
@@ -25,7 +25,7 @@ def llm_generate(input: LLMParamsDoc):
ollama = Ollama(
base_url=ollama_endpoint,
model=input.model if input.model else model_name,
num_predict=input.max_new_tokens,
num_predict=input.max_tokens,
top_k=input.top_k,
top_p=input.top_p,
temperature=input.temperature,
4 changes: 2 additions & 2 deletions comps/llms/text-generation/predictionguard/README.md
@@ -29,7 +29,7 @@ curl -X POST http://localhost:9000/v1/chat/completions \
-d '{
"model": "Hermes-2-Pro-Llama-3-8B",
"query": "Tell me a joke.",
"max_new_tokens": 100,
"max_tokens": 100,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 50,
@@ -45,7 +45,7 @@ curl -N -X POST http://localhost:9000/v1/chat/completions \
-d '{
"model": "Hermes-2-Pro-Llama-3-8B",
"query": "Tell me a joke.",
"max_new_tokens": 100,
"max_tokens": 100,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 50,
@@ -49,7 +49,7 @@ async def stream_generator():
for res in client.chat.completions.create(
model=input.model,
messages=messages,
max_tokens=input.max_new_tokens,
max_tokens=input.max_tokens,
temperature=input.temperature,
top_p=input.top_p,
top_k=input.top_k,
@@ -69,7 +69,7 @@ async def stream_generator():
response = client.chat.completions.create(
model=input.model,
messages=messages,
max_tokens=input.max_new_tokens,
max_tokens=input.max_tokens,
temperature=input.temperature,
top_p=input.top_p,
top_k=input.top_k,
2 changes: 1 addition & 1 deletion comps/llms/text-generation/ray_serve/llm.py
@@ -47,7 +47,7 @@ def llm_generate(input: LLMParamsDoc):
openai_api_base=llm_endpoint + "/v1",
model_name=llm_model,
openai_api_key=os.getenv("OPENAI_API_KEY", "not_needed"),
max_tokens=input.max_new_tokens,
max_tokens=input.max_tokens,
temperature=input.temperature,
streaming=input.streaming,
request_timeout=600,
20 changes: 13 additions & 7 deletions comps/llms/text-generation/tgi/README.md
@@ -88,36 +88,42 @@ curl http://${your_ip}:9000/v1/health_check\

### 3.2 Consume LLM Service

You can set the following model parameters according to your actual needs, such as `max_new_tokens`, `streaming`.
You can set the following model parameters according to your actual needs, such as `max_tokens`, `streaming`.

The `streaming` parameter determines the format of the data returned by the API: with `streaming=false` it returns a plain text string, and with `streaming=true` it returns a streaming text flow.

```bash
# non-streaming mode
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":false}' \
-d '{"query":"What is Deep Learning?","max_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":false}' \
-H 'Content-Type: application/json'

# streaming mode
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' \
-d '{"query":"What is Deep Learning?","max_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' \
-H 'Content-Type: application/json'

# custom chat template
# consume with SearchedDoc
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \
-d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
-H 'Content-Type: application/json'
```

# consume with SearchedDoc
For the parameters used in the modes above, please refer to the [HuggingFace InferenceClient API](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation) (note that we rename `max_new_tokens` to `max_tokens`).

```bash
# custom chat template
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
-d '{"query":"What is Deep Learning?","max_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"presence_penalty":1.03", frequency_penalty":0.0, "streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \
-H 'Content-Type: application/json'
```

For parameters in Chat mode, please refer to the [OpenAI API](https://platform.openai.com/docs/api-reference/chat/create).

### 4. Validated Model

| Model | TGI |
4 changes: 2 additions & 2 deletions comps/llms/text-generation/tgi/llm.py
@@ -69,7 +69,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, Searche
text_generation = await llm.text_generation(
prompt=prompt,
stream=new_input.streaming,
max_new_tokens=new_input.max_new_tokens,
max_new_tokens=new_input.max_tokens,
repetition_penalty=new_input.repetition_penalty,
temperature=new_input.temperature,
top_k=new_input.top_k,
@@ -119,7 +119,7 @@ async def stream_generator():
text_generation = await llm.text_generation(
prompt=prompt,
stream=input.streaming,
max_new_tokens=input.max_new_tokens,
max_new_tokens=input.max_tokens,
repetition_penalty=input.repetition_penalty,
temperature=input.temperature,
top_k=input.top_k,
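
The `tgi/llm.py` hunks keep TGI's native `max_new_tokens` keyword but feed it from the request's OpenAI-style `max_tokens`. A minimal sketch of that boundary, assuming `huggingface_hub`'s `AsyncInferenceClient` and a placeholder endpoint URL:

```python
# Hedged sketch of the rename at the TGI boundary: the service-level field is
# max_tokens, while huggingface_hub's text_generation keyword stays max_new_tokens.
from huggingface_hub import AsyncInferenceClient

from comps.cores.proto.docarray import LLMParamsDoc


async def generate(params: LLMParamsDoc) -> str:
    client = AsyncInferenceClient(model="http://localhost:8080")  # placeholder TGI endpoint
    return await client.text_generation(
        prompt=params.query,
        max_new_tokens=params.max_tokens,  # OpenAI-style field in, TGI-style keyword out
        repetition_penalty=params.repetition_penalty,
        temperature=params.temperature,
        top_k=params.top_k,
        top_p=params.top_p,
    )


# Run with: asyncio.run(generate(LLMParamsDoc(query="What is Deep Learning?")))
```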