serving cli

MasterJH5574 committed Mar 28, 2024
1 parent 34497ea commit 2ed5c93
Showing 5 changed files with 167 additions and 6 deletions.
7 changes: 6 additions & 1 deletion python/mlc_llm/__main__.py
@@ -1,4 +1,5 @@
"""Entrypoint of all CLI commands from MLC LLM"""

import sys

from mlc_llm.support import logging
@@ -13,7 +14,7 @@ def main():
parser.add_argument(
"subcommand",
type=str,
choices=["compile", "convert_weight", "gen_config", "chat", "bench"],
choices=["compile", "convert_weight", "gen_config", "chat", "serve", "bench"],
help="Subcommand to to run. (choices: %(choices)s)",
)
parsed = parser.parse_args(sys.argv[1:2])
@@ -33,6 +34,10 @@ def main():
elif parsed.subcommand == "chat":
from mlc_llm.cli import chat as cli

cli.main(sys.argv[2:])
elif parsed.subcommand == "serve":
from mlc_llm.cli import serve as cli

cli.main(sys.argv[2:])
elif parsed.subcommand == "bench":
from mlc_llm.cli import bench as cli
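The dispatch above parses only the subcommand token and forwards the remaining arguments to the matching CLI module. A minimal, self-contained sketch of this two-stage argparse pattern (the `_serve_main` stub below is a stand-in for illustration, not part of the diff):

import argparse
import sys


def _serve_main(argv):
    # Stand-in for mlc_llm.cli.serve.main: just echo the forwarded arguments.
    print("serve received:", argv)


def main():
    # First stage: parse only the subcommand (sys.argv[1:2]).
    parser = argparse.ArgumentParser("demo")
    parser.add_argument("subcommand", choices=["serve"])
    parsed = parser.parse_args(sys.argv[1:2])
    # Second stage: hand everything after the subcommand to its handler.
    if parsed.subcommand == "serve":
        _serve_main(sys.argv[2:])


if __name__ == "__main__":
    main()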
72 changes: 72 additions & 0 deletions python/mlc_llm/cli/serve.py
@@ -0,0 +1,72 @@
"""Command line entrypoint of serve."""

import json

from mlc_llm.help import HELP
from mlc_llm.interface.serve import serve
from mlc_llm.support.argparse import ArgumentParser


def main(argv):
"""Parse command line arguments and call `mlc_llm.interface.chat`."""
parser = ArgumentParser("MLC LLM Chat CLI")

parser.add_argument(
"model",
type=str,
help=HELP["model"] + " (required)",
)
parser.add_argument(
"--opt",
type=str,
default="O2",
help=HELP["opt"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--device",
type=str,
default="auto",
help=HELP["device_deploy"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--model-lib-path",
type=str,
default=None,
help=HELP["model_lib_path"] + ' (default: "%(default)s")',
)
# Todo: help
parser.add_argument(
"--max-batch-size",
type=int,
default=80,
help=HELP["max_batch_size"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"]
)
parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"])
parser.add_argument("--enable-tracing", action="store_true", help=HELP["enable_tracing_serve"])
parser.add_argument("--host", type=str, default="127.0.0.1", help="host name")
parser.add_argument("--port", type=int, default=8000, help="port")
parser.add_argument("--allow-credentials", action="store_true", help="allow credentials")
parser.add_argument("--allow-origins", type=json.loads, default=["*"], help="allowed origins")
parser.add_argument("--allow-methods", type=json.loads, default=["*"], help="allowed methods")
parser.add_argument("--allow-headers", type=json.loads, default=["*"], help="allowed headers")
parsed = parser.parse_args(argv)

serve(
model=parsed.model,
device=parsed.device,
opt=parsed.opt,
model_lib_path=parsed.model_lib_path,
max_batch_size=parsed.max_batch_size,
max_total_sequence_length=parsed.max_total_seq_length,
prefill_chunk_size=parsed.prefill_chunk_size,
enable_tracing=parsed.enable_tracing,
host=parsed.host,
port=parsed.port,
allow_credentials=parsed.allow_credentials,
allow_origins=parsed.allow_origins,
allow_methods=parsed.allow_methods,
allow_headers=parsed.allow_headers,
)
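For reference, the new entrypoint can also be exercised directly from Python by passing an argv list to `main`. A minimal sketch, assuming a compiled model at `dist/llama` and its library at `dist/llama/lib.so` (both paths are placeholders):

from mlc_llm.cli import serve as cli

# Same effect as invoking the "serve" subcommand from the command line.
cli.main(
    [
        "dist/llama",
        "--model-lib-path", "dist/llama/lib.so",
        "--device", "auto",
        "--host", "127.0.0.1",
        "--port", "8000",
    ]
)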
20 changes: 19 additions & 1 deletion python/mlc_llm/help.py
@@ -1,4 +1,5 @@
"""Help message for CLI arguments."""

HELP = {
"config": (
"""
@@ -111,7 +112,7 @@
the number of sinks is 4. This flag is subject to future refactoring.
""".strip(),
"max_batch_size": """
The maximum allowed batch size set for batch prefill/decode function.
The maximum allowed batch size, i.e., the maximum number of sequences the KV cache can support concurrently.
""".strip(),
"""tensor_parallel_shards""": """
Number of shards to split the model into for tensor parallelism in multi-GPU inference.
@@ -138,5 +139,22 @@
""".strip(),
"generate_length": """
The target length of the text generation.
""".strip(),
"max_total_sequence_length_serve": """
The KV cache total token capacity, i.e., the maximum total number of tokens that
the KV cache supports. This decides the amount of GPU memory the KV cache consumes.
If not specified, the system automatically estimates the maximum capacity based
on the available GPU VRAM.
""".strip(),
"prefill_chunk_size_serve": """
The maximum number of tokens the model processes in a single prefill step.
It should not exceed the prefill chunk size in the model config.
If not specified, it defaults to the prefill chunk size in the model config.
""".strip(),
"enable_tracing_serve": """
Enable Chrome Tracing for the server.
After enabling, you can send a POST request to the "debug/dump_event_trace" entrypoint
to retrieve the Chrome Trace. For example:
curl -X POST http://127.0.0.1:8000/debug/dump_event_trace -H "Content-Type: application/json" -d '{"model": "dist/llama"}'
""".strip(),
}
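The curl command in the `enable_tracing_serve` help text maps to the following Python sketch, assuming the server was started with `--enable-tracing` at the default host and port and serves a model registered as "dist/llama":

import requests

# Request the Chrome Trace dump from a running server started with --enable-tracing.
response = requests.post(
    "http://127.0.0.1:8000/debug/dump_event_trace",
    json={"model": "dist/llama"},
    timeout=60,
)
print(response.status_code)
print(response.text[:200])  # first part of the trace JSON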
67 changes: 67 additions & 0 deletions python/mlc_llm/interface/serve.py
@@ -0,0 +1,67 @@
"""Python entrypoint of serve."""

import dataclasses
import json
from typing import Any, List, Optional, Union

import fastapi
import uvicorn
from fastapi.middleware.cors import CORSMiddleware

from mlc_llm.serve import async_engine, config
from mlc_llm.serve.server import ServerContext


def serve(
model: str,
device: str,
opt: str,
model_lib_path: Optional[str],
max_batch_size: int,
max_total_sequence_length: Optional[int],
prefill_chunk_size: Optional[int],
enable_tracing: bool,
host: str,
port: int,
allow_credentials: bool,
allow_origins: Any,
allow_methods: Any,
allow_headers: Any,
):
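    """Serve the model with the specified configuration."""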

# Initialize model loading info and KV cache config
# Todo: JIT
model_info = async_engine.ModelInfo(
model=model,
model_lib_path=model_lib_path,
device=device,
)
kv_cache_config = config.KVCacheConfig(
max_num_sequence=max_batch_size,
max_total_sequence_length=max_total_sequence_length,
prefill_chunk_size=prefill_chunk_size,
)
# Create engine and start the background loop
engine = async_engine.AsyncThreadedEngine(
model_info, kv_cache_config, enable_tracing=enable_tracing
)

# Todo: context-based
ServerContext.add_model(model, engine)

app = fastapi.FastAPI()
app.add_middleware(
CORSMiddleware,
allow_credentials=allow_credentials,
allow_origins=allow_origins,
allow_methods=allow_methods,
allow_headers=allow_headers,
)

# Include the routers from subdirectories.
# Todo: move out?
from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints

app.include_router(openai_entrypoints.app)
app.include_router(debug_entrypoints.app)
uvicorn.run(app, host=host, port=port, log_level="info")
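Because `serve` is a plain function, the server can also be launched from Python without the CLI wrapper. A minimal sketch that mirrors the CLI defaults above (model and library paths are placeholders):

from mlc_llm.interface.serve import serve

# Start the OpenAI-compatible server with the same defaults the CLI uses.
serve(
    model="dist/llama",
    device="auto",
    opt="O2",
    model_lib_path="dist/llama/lib.so",
    max_batch_size=80,
    max_total_sequence_length=None,
    prefill_chunk_size=None,
    enable_tracing=False,
    host="127.0.0.1",
    port=8000,
    allow_credentials=False,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)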
7 changes: 3 additions & 4 deletions python/mlc_llm/serve/server/popen_server.py
@@ -26,8 +26,8 @@ def __init__( # pylint: disable=too-many-arguments
host: str = "127.0.0.1",
port: int = 8000,
) -> None:
"""Please check out `python/mlc_llm/serve/server/__main__.py`
for the server arguments."""
# Todo
"""Please check out `python/mlc_llm/cli/serve.py` for the server arguments."""
self.model = model
self.model_lib_path = model_lib_path
self.device = device
@@ -43,8 +43,7 @@ def start(self) -> None:
Wait until the server becomes ready before returning.
"""
cmd = [sys.executable]
cmd += ["-m", "mlc_llm.serve.server"]
cmd += ["--model", self.model]
cmd += ["-m", "mlc_llm", "serve", self.model]
cmd += ["--model-lib-path", self.model_lib_path]
cmd += ["--device", self.device]
cmd += ["--max-batch-size", str(self.max_batch_size)]
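With this change, `PopenServer.start()` spawns the new CLI subcommand instead of the old `mlc_llm.serve.server` module. A sketch of the command it assembles, using placeholder values and only the flags visible in the hunk above (flags beyond the truncated part are not reproduced):

import sys

# Illustrative reconstruction of the spawned command line.
cmd = [sys.executable, "-m", "mlc_llm", "serve", "dist/llama"]
cmd += ["--model-lib-path", "dist/llama/lib.so"]
cmd += ["--device", "auto"]
cmd += ["--max-batch-size", str(80)]
print(" ".join(cmd))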
