Skip to content

Commit

Permalink
Add support for using dynamic threads (#756)
Browse files — browse the repository at this point in the history
  • Loading branch information
gaby authored Sep 20, 2023
1 parent b4190e5 commit 646daa6
Show file tree
Hide file tree
Showing 7 changed files with 5 additions and 43 deletions.
1 change: 0 additions & 1 deletion api/src/serge/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class ChatParameters(BaseModel):
# logits_all: bool
# vocab_only: bool
# use_mlock: bool
n_threads: int
# n_batch: int
last_n_tokens_size: int
max_tokens: int
Expand Down
5 changes: 3 additions & 2 deletions api/src/serge/routers/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from typing import Optional
from fastapi import APIRouter
from langchain.memory import RedisChatMessageHistory
Expand Down Expand Up @@ -28,7 +30,6 @@ async def create_new_chat(
repeat_last_n: int = 64,
repeat_penalty: float = 1.3,
init_prompt: str = "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
n_threads: int = 4,
):
try:
client = Llama(
Expand All @@ -51,7 +52,7 @@ async def create_new_chat(
n_gpu_layers=gpu_layers,
last_n_tokens_size=repeat_last_n,
repeat_penalty=repeat_penalty,
n_threads=n_threads,
n_threads=len(os.sched_getaffinity(0)),
init_prompt=init_prompt,
)
# create the chat
Expand Down
1 change: 0 additions & 1 deletion api/src/serge/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def _default_params(self) -> dict[str, Any]:
"stop_sequences": self.stop_sequences,
"repeat_penalty": self.repeat_penalty,
"top_k": self.top_k,
"n_threads": self.n_threads,
"n_ctx": self.n_ctx,
"n_gpu_layers": self.n_gpu_layers,
"n_parts": self.n_parts,
Expand Down
2 changes: 1 addition & 1 deletion web/src/routes/+layout.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
`/api/chat/?model=${dataCht.params.model_path}&temperature=${dataCht.params.temperature}&top_k=${dataCht.params.top_k}` +
`&top_p=${dataCht.params.top_p}&max_length=${dataCht.params.max_tokens}&context_window=${dataCht.params.n_ctx}` +
`&repeat_last_n=${dataCht.params.last_n_tokens_size}&repeat_penalty=${dataCht.params.repeat_penalty}` +
`&n_threads=${dataCht.params.n_threads}&init_prompt=${dataCht.history[0].data.content}` +
`&init_prompt=${dataCht.history[0].data.content}` +
`&gpu_layers=${dataCht.params.n_gpu_layers}`,
{
Expand Down
15 changes: 0 additions & 15 deletions web/src/routes/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
let init_prompt =
"Below is an instruction that describes a task. Write a response that appropriately completes the request.";
let n_threads = 4;
let context_window = 2048;
let gpu_layers = 0;
Expand Down Expand Up @@ -226,20 +225,6 @@
{/each}
</select>
</div>
<div
class="tooltip flex flex-col"
data-tip="Number of threads to run LLaMA on."
>
<label for="n_threads" class="label-text pb-1">n_threads</label>
<input
class="input-bordered input w-full max-w-xs"
name="n_threads"
type="number"
bind:value={n_threads}
min="0"
max="64"
/>
</div>
<div
class="tooltip flex flex-col"
data-tip="The weight of the penalty to avoid repeating the last repeat_last_n tokens."
Expand Down
23 changes: 1 addition & 22 deletions web/src/routes/chat/[id]/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
`/api/chat/?model=${data.chat.params.model_path}&temperature=${data.chat.params.temperature}&top_k=${data.chat.params.top_k}` +
`&top_p=${data.chat.params.top_p}&max_length=${data.chat.params.max_tokens}&context_window=${data.chat.params.n_ctx}` +
`&repeat_last_n=${data.chat.params.last_n_tokens_size}&repeat_penalty=${data.chat.params.repeat_penalty}` +
`&n_threads=${data.chat.params.n_threads}&init_prompt=${data.chat.history[0].data.content}` +
`&init_prompt=${data.chat.history[0].data.content}` +
`&gpu_layers=${data.chat.params.n_gpu_layers}`,
{
Expand Down Expand Up @@ -337,27 +337,6 @@
{data.chat.params.n_ctx}/{data.chat.params.max_tokens}
</span>
</div>
{#if data.chat.params.n_threads > 0}
<div class="pl-4 hidden sm:flex flex-row items-center justify-center">
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-4 h-4"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M8.25 3v1.5M4.5 8.25H3m18 0h-1.5M4.5 12H3m18 0h-1.5m-15 3.75H3m18 0h-1.5M8.25 19.5V21M12 3v1.5m0 15V21m3.75-18v1.5m0 15V21m-9-1.5h10.5a2.25 2.25 0 002.25-2.25V6.75a2.25 2.25 0 00-2.25-2.25H6.75A2.25 2.25 0 004.5 6.75v10.5a2.25 2.25 0 002.25 2.25zm.75-12h9v9h-9v-9z"
/>
</svg>
<span class="ml-2 inline-block text-center text-sm font-semibold">
{data.chat.params.n_threads}
</span>
</div>
{/if}
{#if data.chat.params.n_gpu_layers > 0}
<div class="pl-4 hidden sm:flex flex-row items-center justify-center">
<svg
Expand Down
1 change: 0 additions & 1 deletion web/src/routes/chat/[id]/+page.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ interface Params {
model_path: string;
n_ctx: number;
n_gpu_layers: number;
n_threads: number;
last_n_tokens_size: number;
max_tokens: number;
temperature: number;
Expand Down

0 comments on commit 646daa6

Please sign in to comment.