diff --git a/litgpt/api.py b/litgpt/api.py
index 6eb0281fea..2ad0bc926f 100644
--- a/litgpt/api.py
+++ b/litgpt/api.py
@@ -239,6 +239,9 @@ def distribute(
                 "Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'."
             )
 
+        if precision is None:
+            precision = get_default_supported_precision(training=False)
+
         plugins = None
         if quantize is not None and quantize.startswith("bnb."):
             if "mixed" in precision:
diff --git a/litgpt/deploy/serve.py b/litgpt/deploy/serve.py
index c86d2abbc5..262aee4252 100644
--- a/litgpt/deploy/serve.py
+++ b/litgpt/deploy/serve.py
@@ -53,6 +53,10 @@ def setup(self, device: str) -> None:
         print("Initializing model...")
         self.llm = LLM.load(
             model=self.checkpoint_dir,
+            distribute=None
+        )
+
+        self.llm.distribute(
             accelerator=accelerator,
             quantize=self.quantize,
             precision=self.precision
diff --git a/tests/test_serve.py b/tests/test_serve.py
index 23eefdb8c7..8c29f90486 100644
--- a/tests/test_serve.py
+++ b/tests/test_serve.py
@@ -6,6 +6,7 @@
 import torch
 import requests
 import subprocess
+from tests.conftest import RunIf
 import threading
 import time
 import yaml
@@ -54,3 +55,46 @@ def run_server():
         if process:
             process.kill()
         server_thread.join()
+
+
+@RunIf(min_cuda_gpus=1)
+def test_quantize(tmp_path):
+    seed_everything(123)
+    ours_config = Config.from_name("pythia-14m")
+    download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
+    ours_model = GPT(ours_config)
+    checkpoint_path = tmp_path / "lit_model.pth"
+    torch.save(ours_model.state_dict(), checkpoint_path)
+    config_path = tmp_path / "model_config.yaml"
+    with open(config_path, "w", encoding="utf-8") as fp:
+        yaml.dump(asdict(ours_config), fp)
+
+    run_command = [
+        "litgpt", "serve", tmp_path, "--quantize", "bnb.nf4"
+    ]
+
+    process = None
+
+    def run_server():
+        nonlocal process
+        try:
+            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+            stdout, stderr = process.communicate(timeout=10)
+        except subprocess.TimeoutExpired:
+            print('Server start-up timeout expired')
+
+    server_thread = threading.Thread(target=run_server)
+    server_thread.start()
+
+    time.sleep(10)
+
+    try:
+        response = requests.get("http://127.0.0.1:8000")
+        print(response.status_code)
+        assert response.status_code == 200, "Server did not respond as expected."
+    finally:
+        if process:
+            process.kill()
+        server_thread.join()
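
Usage note (not part of the diff): a minimal sketch of the load/distribute flow that the serve.py change above relies on, assuming a locally downloaded checkpoint at a hypothetical path. With distribute=None, LLM.load() skips device placement, and the follow-up llm.distribute() call applies the accelerator, quantization, and precision settings; per the api.py change, passing precision=None now falls back to get_default_supported_precision(training=False) instead of failing on the "mixed" in precision check.

    from litgpt.api import LLM

    # Load weights only; distribute=None defers device placement (mirrors serve.py above).
    llm = LLM.load(model="checkpoints/EleutherAI/pythia-14m", distribute=None)

    # Place and quantize the model; "bnb.nf4" needs a CUDA GPU with bitsandbytes installed.
    llm.distribute(
        accelerator="cuda",
        quantize="bnb.nf4",
        precision=None,  # resolved to a supported default by the new api.py fallback
    )

    print(llm.generate("What do llamas eat?"))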