server : add more test cases (#10569)
* server : add split model test

* add test speculative

* add invalid cases
ngxson authored Nov 29, 2024
1 parent 3a8e9af commit b782e5c
Showing 6 changed files with 186 additions and 1 deletion.
14 changes: 14 additions & 0 deletions examples/server/tests/unit/test_basic.py
@@ -32,3 +32,17 @@ def test_server_models():
assert res.status_code == 200
assert len(res.body["data"]) == 1
assert res.body["data"][0]["id"] == server.model_alias

def test_load_split_model():
global server
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
server.model_alias = "tinyllama-split"
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 16,
"prompt": "Hello",
"temperature": 0.0,
})
assert res.status_code == 200
assert match_regex("(little|girl)+", res.body["content"])
19 changes: 19 additions & 0 deletions examples/server/tests/unit/test_chat_completion.py
@@ -127,3 +127,22 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
assert res.status_code != 200
assert "error" in res.body


@pytest.mark.parametrize("messages", [
None,
"string",
[123],
[{}],
[{"role": 123}],
[{"role": "system", "content": 123}],
# [{"content": "hello"}], # TODO: should not be a valid case
[{"role": "system", "content": "test"}, {}],
])
def test_invalid_chat_completion_req(messages):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"messages": messages,
})
assert res.status_code == 400 or res.status_code == 500
assert "error" in res.body
22 changes: 22 additions & 0 deletions examples/server/tests/unit/test_infill.py
@@ -8,6 +8,7 @@ def create_server():
global server
server = ServerPreset.tinyllama_infill()


def test_infill_without_input_extra():
global server
server.start()
@@ -19,6 +20,7 @@ def test_infill_without_input_extra():
assert res.status_code == 200
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])


def test_infill_with_input_extra():
global server
server.start()
@@ -33,3 +35,23 @@ def test_infill_with_input_extra():
})
assert res.status_code == 200
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])


@pytest.mark.parametrize("input_extra", [
{},
{"filename": "ok"},
{"filename": 123},
{"filename": 123, "text": "abc"},
{"filename": 123, "text": 456},
])
def test_invalid_input_extra_req(input_extra):
global server
server.start()
res = server.make_request("POST", "/infill", data={
"prompt": "Complete this",
"input_extra": [input_extra],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 400
assert "error" in res.body
17 changes: 17 additions & 0 deletions examples/server/tests/unit/test_rerank.py
@@ -36,3 +36,20 @@ def test_rerank():
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3


@pytest.mark.parametrize("documents", [
[],
None,
123,
[1, 2, 3],
])
def test_invalid_rerank_req(documents):
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"documents": documents,
})
assert res.status_code == 400
assert "error" in res.body
103 changes: 103 additions & 0 deletions examples/server/tests/unit/test_speculative.py
@@ -0,0 +1,103 @@
import pytest
from utils import *

# We use an F16 MoE GGUF as the main model and a q4_0 GGUF as the draft model

server = ServerPreset.stories15m_moe()

MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"

def create_server():
global server
server = ServerPreset.stories15m_moe()
# download draft model file if needed
file_name = MODEL_DRAFT_FILE_URL.split('/').pop()
model_draft_file = f'../../../{file_name}'
if not os.path.exists(model_draft_file):
print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
with open(model_draft_file, 'wb') as f:
f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
print(f"Done downloading draft model file")
# set default values
server.model_draft = model_draft_file
server.draft_min = 4
server.draft_max = 8


@pytest.fixture(scope="module", autouse=True)
def fixture_create_server():
return create_server()


def test_with_and_without_draft():
global server
server.model_draft = None # disable draft model
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
})
assert res.status_code == 200
content_no_draft = res.body["content"]
server.stop()

# create new server with draft model
create_server()
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
})
assert res.status_code == 200
content_draft = res.body["content"]

assert content_no_draft == content_draft


def test_different_draft_min_draft_max():
global server
test_values = [
(1, 2),
(1, 4),
(4, 8),
(4, 12),
(8, 16),
]
last_content = None
for draft_min, draft_max in test_values:
server.stop()
server.draft_min = draft_min
server.draft_max = draft_max
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
})
assert res.status_code == 200
if last_content is not None:
assert last_content == res.body["content"]
last_content = res.body["content"]


@pytest.mark.parametrize("n_slots,n_requests", [
(1, 2),
(2, 2),
])
def test_multi_requests_parallel(n_slots: int, n_requests: int):
global server
server.n_slots = n_slots
server.start()
tasks = []
for _ in range(n_requests):
tasks.append((server.make_request, ("POST", "/completion", {
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
})))
results = parallel_function_calls(tasks)
for res in results:
assert res.status_code == 200
assert match_regex("(wise|kind|owl|answer)+", res.body["content"])
12 changes: 11 additions & 1 deletion examples/server/tests/utils.py
@@ -46,6 +46,7 @@ class ServerProcess:
model_alias: str | None = None
model_url: str | None = None
model_file: str | None = None
model_draft: str | None = None
n_threads: int | None = None
n_gpu_layer: int | None = None
n_batch: int | None = None
@@ -68,6 +69,8 @@ class ServerProcess:
response_format: str | None = None
lora_files: List[str] | None = None
disable_ctx_shift: int | None = False
draft_min: int | None = None
draft_max: int | None = None

# session variables
process: subprocess.Popen | None = None
@@ -102,6 +105,8 @@ def start(self, timeout_seconds: int = 10) -> None:
server_args.extend(["--model", self.model_file])
if self.model_url:
server_args.extend(["--model-url", self.model_url])
if self.model_draft:
server_args.extend(["--model-draft", self.model_draft])
if self.model_hf_repo:
server_args.extend(["--hf-repo", self.model_hf_repo])
if self.model_hf_file:
@@ -147,6 +152,10 @@ def start(self, timeout_seconds: int = 10) -> None:
server_args.extend(["--no-context-shift"])
if self.api_key:
server_args.extend(["--api-key", self.api_key])
if self.draft_max:
server_args.extend(["--draft-max", self.draft_max])
if self.draft_min:
server_args.extend(["--draft-min", self.draft_min])

args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
@@ -185,7 +194,8 @@ def start(self, timeout_seconds: int = 10) -> None:
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")

def stop(self) -> None:
server_instances.remove(self)
if self in server_instances:
server_instances.remove(self)
if self.process:
print(f"Stopping server with pid={self.process.pid}")
self.process.kill()
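
The stop() guard makes shutdown idempotent: test_with_and_without_draft stops the server mid-test and the module teardown stops it again, which previously hit an unconditional remove(). A minimal illustration of the failure mode the membership check avoids (collection type assumed for illustration):

instances = set()

class Proc:
    def stop(self):
        # Without the membership check, a second stop() raises
        # (KeyError for a set, ValueError for a list).
        if self in instances:
            instances.remove(self)

p = Proc()
instances.add(p)
p.stop()
p.stop()  # now a harmless no-op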