
Commit

Support the refactored API in litgpt serve (#1668)
rasbt authored Aug 12, 2024
1 parent 3eab461 commit 2433eaf
Showing 3 changed files with 51 additions and 0 deletions.
litgpt/api.py: 3 additions & 0 deletions
@@ -239,6 +239,9 @@ def distribute(
                 "Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'."
             )
 
+        if precision is None:
+            precision = get_default_supported_precision(training=False)
+
         plugins = None
         if quantize is not None and quantize.startswith("bnb."):
             if "mixed" in precision:
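With this fallback, callers of distribute() no longer need to pass a precision explicitly. For orientation, get_default_supported_precision (from litgpt.utils) picks the best half-precision mode the hardware supports; the sketch below is a simplified illustration of that idea under the name default_inference_precision (a hypothetical helper, not the library's exact source, which also covers Apple MPS and training-time "-mixed" variants):

    import torch

    def default_inference_precision() -> str:
        # Illustrative only: prefer bfloat16 on GPUs that support it,
        # otherwise fall back to float16 ("16-true").
        if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
            return "16-true"
        return "bf16-true"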
litgpt/deploy/serve.py: 4 additions & 0 deletions
@@ -53,6 +53,10 @@ def setup(self, device: str) -> None:
         print("Initializing model...")
         self.llm = LLM.load(
             model=self.checkpoint_dir,
+            distribute=None
+        )
+
+        self.llm.distribute(
             accelerator=accelerator,
             quantize=self.quantize,
             precision=self.precision
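In effect, serving now uses the refactored two-step API: load the checkpoint with distribute=None, then place it on hardware with an explicit distribute() call. A minimal usage sketch, assuming a downloaded checkpoint; the model name, prompt, and argument values are illustrative:

    from litgpt.api import LLM

    # Step 1: load weights only; distribute=None skips device placement.
    llm = LLM.load(model="EleutherAI/pythia-14m", distribute=None)

    # Step 2: place the model explicitly, optionally quantized.
    llm.distribute(accelerator="cuda", quantize="bnb.nf4", precision="bf16-true")

    print(llm.generate("What do llamas eat?"))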
tests/test_serve.py: 44 additions & 0 deletions
@@ -6,6 +6,7 @@
 import torch
 import requests
 import subprocess
+from tests.conftest import RunIf
 import threading
 import time
 import yaml
@@ -54,3 +55,46 @@ def run_server():
         if process:
             process.kill()
     server_thread.join()
+
+
+@RunIf(min_cuda_gpus=1)
+def test_quantize(tmp_path):
+    seed_everything(123)
+    ours_config = Config.from_name("pythia-14m")
+    download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
+    ours_model = GPT(ours_config)
+    checkpoint_path = tmp_path / "lit_model.pth"
+    torch.save(ours_model.state_dict(), checkpoint_path)
+    config_path = tmp_path / "model_config.yaml"
+    with open(config_path, "w", encoding="utf-8") as fp:
+        yaml.dump(asdict(ours_config), fp)
+
+    run_command = [
+        "litgpt", "serve", tmp_path, "--quantize", "bnb.nf4"
+    ]
+
+    process = None
+
+    def run_server():
+        nonlocal process
+        try:
+            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+            stdout, stderr = process.communicate(timeout=10)
+        except subprocess.TimeoutExpired:
+            print('Server start-up timeout expired')
+
+    server_thread = threading.Thread(target=run_server)
+    server_thread.start()
+
+    time.sleep(10)
+
+    try:
+        response = requests.get("http://127.0.0.1:8000")
+        print(response.status_code)
+        assert response.status_code == 200, "Server did not respond as expected."
+    finally:
+        if process:
+            process.kill()
+        server_thread.join()
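The new test only checks that the root URL answers with HTTP 200 once the quantized server is up. For an actual generation request, the litgpt serve docs use a /predict route; a sketch along those lines (the endpoint path and the "output" response key follow those docs, so treat them as assumptions here):

    import requests

    # Ask a running `litgpt serve` instance for a completion.
    response = requests.post(
        "http://127.0.0.1:8000/predict",
        json={"prompt": "What do llamas eat?"},
    )
    print(response.json()["output"])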
