diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py index 84db5ca1ca192..d82d54a5a6f47 100644 --- a/examples/server/tests/unit/test_basic.py +++ b/examples/server/tests/unit/test_basic.py @@ -32,3 +32,17 @@ def test_server_models(): assert res.status_code == 200 assert len(res.body["data"]) == 1 assert res.body["data"][0]["id"] == server.model_alias + +def test_load_split_model(): + global server + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf" + server.model_alias = "tinyllama-split" + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": 16, + "prompt": "Hello", + "temperature": 0.0, + }) + assert res.status_code == 200 + assert match_regex("(little|girl)+", res.body["content"]) diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index d7aeb288d45cc..1048d6fcaf500 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -127,3 +127,22 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int assert res.status_code != 200 assert "error" in res.body + +@pytest.mark.parametrize("messages", [ + None, + "string", + [123], + [{}], + [{"role": 123}], + [{"role": "system", "content": 123}], + # [{"content": "hello"}], # TODO: should not be a valid case + [{"role": "system", "content": "test"}, {}], +]) +def test_invalid_chat_completion_req(messages): + global server + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "messages": messages, + }) + assert res.status_code == 400 or res.status_code == 500 + assert "error" in res.body diff --git a/examples/server/tests/unit/test_infill.py b/examples/server/tests/unit/test_infill.py index 38ce6c42954ed..6a6d40a1cbc8b 100644 --- a/examples/server/tests/unit/test_infill.py +++ b/examples/server/tests/unit/test_infill.py @@ -8,6 +8,7 @@ def create_server(): global server server = ServerPreset.tinyllama_infill() + def test_infill_without_input_extra(): global server server.start() @@ -19,6 +20,7 @@ def test_infill_without_input_extra(): assert res.status_code == 200 assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"]) + def test_infill_with_input_extra(): global server server.start() @@ -33,3 +35,23 @@ def test_infill_with_input_extra(): }) assert res.status_code == 200 assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"]) + + +@pytest.mark.parametrize("input_extra", [ + {}, + {"filename": "ok"}, + {"filename": 123}, + {"filename": 123, "text": "abc"}, + {"filename": 123, "text": 456}, +]) +def test_invalid_input_extra_req(input_extra): + global server + server.start() + res = server.make_request("POST", "/infill", data={ + "prompt": "Complete this", + "input_extra": [input_extra], + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 400 + assert "error" in res.body diff --git a/examples/server/tests/unit/test_rerank.py b/examples/server/tests/unit/test_rerank.py index 3a49fd3ac6bdf..189bc4c962329 100644 --- a/examples/server/tests/unit/test_rerank.py +++ b/examples/server/tests/unit/test_rerank.py @@ -36,3 +36,20 @@ def test_rerank(): assert most_relevant["relevance_score"] > least_relevant["relevance_score"] assert most_relevant["index"] == 2 assert least_relevant["index"] == 3 + + +@pytest.mark.parametrize("documents", [ + [], + None, + 123, + [1, 2, 3], +]) +def test_invalid_rerank_req(documents): + global server + server.start() + res = server.make_request("POST", "/rerank", data={ + "query": "Machine learning is", + "documents": documents, + }) + assert res.status_code == 400 + assert "error" in res.body diff --git a/examples/server/tests/unit/test_speculative.py b/examples/server/tests/unit/test_speculative.py new file mode 100644 index 0000000000000..982d6abb45f5f --- /dev/null +++ b/examples/server/tests/unit/test_speculative.py @@ -0,0 +1,103 @@ +import pytest +from utils import * + +# We use a F16 MOE gguf as main model, and q4_0 as draft model + +server = ServerPreset.stories15m_moe() + +MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf" + +def create_server(): + global server + server = ServerPreset.stories15m_moe() + # download draft model file if needed + file_name = MODEL_DRAFT_FILE_URL.split('/').pop() + model_draft_file = f'../../../{file_name}' + if not os.path.exists(model_draft_file): + print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}") + with open(model_draft_file, 'wb') as f: + f.write(requests.get(MODEL_DRAFT_FILE_URL).content) + print(f"Done downloading draft model file") + # set default values + server.model_draft = model_draft_file + server.draft_min = 4 + server.draft_max = 8 + + +@pytest.fixture(scope="module", autouse=True) +def fixture_create_server(): + return create_server() + + +def test_with_and_without_draft(): + global server + server.model_draft = None # disable draft model + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + content_no_draft = res.body["content"] + server.stop() + + # create new server with draft model + create_server() + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + content_draft = res.body["content"] + + assert content_no_draft == content_draft + + +def test_different_draft_min_draft_max(): + global server + test_values = [ + (1, 2), + (1, 4), + (4, 8), + (4, 12), + (8, 16), + ] + last_content = None + for draft_min, draft_max in test_values: + server.stop() + server.draft_min = draft_min + server.draft_max = draft_max + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + if last_content is not None: + assert last_content == res.body["content"] + last_content = res.body["content"] + + +@pytest.mark.parametrize("n_slots,n_requests", [ + (1, 2), + (2, 2), +]) +def test_multi_requests_parallel(n_slots: int, n_requests: int): + global server + server.n_slots = n_slots + server.start() + tasks = [] + for _ in range(n_requests): + tasks.append((server.make_request, ("POST", "/completion", { + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }))) + results = parallel_function_calls(tasks) + for res in results: + assert res.status_code == 200 + assert match_regex("(wise|kind|owl|answer)+", res.body["content"]) diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index a831f113f4161..e17a05ff6902a 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -46,6 +46,7 @@ class ServerProcess: model_alias: str | None = None model_url: str | None = None model_file: str | None = None + model_draft: str | None = None n_threads: int | None = None n_gpu_layer: int | None = None n_batch: int | None = None @@ -68,6 +69,8 @@ class ServerProcess: response_format: str | None = None lora_files: List[str] | None = None disable_ctx_shift: int | None = False + draft_min: int | None = None + draft_max: int | None = None # session variables process: subprocess.Popen | None = None @@ -102,6 +105,8 @@ def start(self, timeout_seconds: int = 10) -> None: server_args.extend(["--model", self.model_file]) if self.model_url: server_args.extend(["--model-url", self.model_url]) + if self.model_draft: + server_args.extend(["--model-draft", self.model_draft]) if self.model_hf_repo: server_args.extend(["--hf-repo", self.model_hf_repo]) if self.model_hf_file: @@ -147,6 +152,10 @@ def start(self, timeout_seconds: int = 10) -> None: server_args.extend(["--no-context-shift"]) if self.api_key: server_args.extend(["--api-key", self.api_key]) + if self.draft_max: + server_args.extend(["--draft-max", self.draft_max]) + if self.draft_min: + server_args.extend(["--draft-min", self.draft_min]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") @@ -185,7 +194,8 @@ def start(self, timeout_seconds: int = 10) -> None: raise TimeoutError(f"Server did not start within {timeout_seconds} seconds") def stop(self) -> None: - server_instances.remove(self) + if self in server_instances: + server_instances.remove(self) if self.process: print(f"Stopping server with pid={self.process.pid}") self.process.kill()