diff --git a/litgpt/api.py b/litgpt/api.py
index c1e61b2a66..5996700a62 100644
--- a/litgpt/api.py
+++ b/litgpt/api.py
@@ -3,6 +3,7 @@
 # This file implements the LitGPT Python API
 from pathlib import Path
 import sys
+import time
 from typing import Any, List, Literal, Optional, Union
 
 import torch
@@ -228,9 +229,6 @@ def distribute(
         if generate_strategy == "sequential" and accelerator not in ("cuda", "gpu"):
             raise NotImplementedError("generate_strategy='sequential' is only supported for accelerator='cuda'|'gpu.")
 
-        #if generate_strategy == "sequential" and init != "pretrained":
-        #    raise NotImplementedError("generate_strategy='sequential' is only supported for init='pretrained'.")
-
         num_devices = calculate_number_of_devices(devices)
 
         if generate_strategy is None and num_devices > 1:
@@ -353,9 +351,7 @@ def generate(
         We plan to resolve this in the future.
         """
         assert self.model is not None
-
-        prompt = self.prompt_style.apply(prompt)
-        input_ids = self.preprocessor.encode(prompt)
+        input_ids = self._text_to_token_ids(prompt)
         prompt_length = input_ids.size(0)
         max_returned_tokens = prompt_length + max_new_tokens
 
@@ -389,7 +385,7 @@ def iterator():
                 yield from outputs
             else:
                 for output in outputs:
-                    yield self.preprocessor.tokenizer.decode(output)
+                    yield self.preprocessor.decode(output)
             return
 
         if stream:
@@ -411,7 +407,45 @@ def iterator():
         elif return_as_token_ids:
             return outputs
         else:
-            return self.preprocessor.tokenizer.decode(outputs)
+            return self.preprocessor.decode(outputs)
+
+    def _text_to_token_ids(self, prompt):
+        """Utility method to convert a prompt text to token IDs."""
+        prompt = self.prompt_style.apply(prompt)
+        input_ids = self.preprocessor.encode(prompt)
+        return input_ids
+
+    def benchmark(self, **kwargs):
+        """
+        A wrapper around the .generate() method to calculate runtime performance.
+
+        Arguments:
+            kwargs: Keyword arguments that are passed to the .generate() method.
+ """ + benchmark_dict = {} + + time_to_first_token = None + t0 = time.perf_counter() + outputs = self.generate(**kwargs) + + if kwargs.get("stream", False): + gen_outputs = [] + for e in outputs: + if time_to_first_token is None: + t1 = time.perf_counter() + time_to_first_token = t1 - t0 + gen_outputs.append(e) + outputs = "".join(gen_outputs) + else: + outputs = self.generate(**kwargs, ) + benchmark_dict["Seconds total"] = time.perf_counter() - t0 + benchmark_dict["Seconds to first token"] = time_to_first_token + benchmark_dict["Tokens generated"] = self.preprocessor.encode(outputs).size(0) - self._text_to_token_ids(kwargs.get("prompt")).size(0) + benchmark_dict["Inference speed in tokens/sec"] = benchmark_dict["Tokens generated"] / benchmark_dict["Seconds total"] + if self.fabric.device.type == "cuda": + benchmark_dict["Total GPU memory allocated in GB"] = torch.cuda.max_memory_allocated() / 1e9 + + return outputs, benchmark_dict class Preprocessor: diff --git a/tests/test_api.py b/tests/test_api.py index 8b968db030..bf43773553 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -145,7 +145,6 @@ def test_more_than_1_device_for_sequential_gpu(tmp_path): llm.distribute(devices=2, generate_strategy="sequential") assert isinstance(llm.generate("What do llamas eat?"), str) - with pytest.raises(NotImplementedError, match="Support for multiple devices is currently only implemented for generate_strategy='sequential'."): llm.distribute(devices=2) @@ -199,3 +198,15 @@ def test_invalid_accelerator(tmp_path): ) with pytest.raises(ValueError, match="Invalid accelerator"): llm.distribute(accelerator="invalid") + + +def test_returned_benchmark_dir(tmp_path): + llm = LLM.load( + model="EleutherAI/pythia-14m", + ) + + text, bench_d = llm.benchmark(prompt="hello world") + assert isinstance(bench_d["Inference speed in tokens/sec"], float) + + text, bench_d = llm.benchmark(prompt="hello world", stream=True) + assert isinstance(bench_d["Inference speed in tokens/sec"], float) diff --git a/tutorials/python-api.md b/tutorials/python-api.md index 3944f3c3f6..87860ce34d 100644 --- a/tutorials/python-api.md +++ b/tutorials/python-api.md @@ -122,4 +122,35 @@ print(text) ``` Llamas are herbivores and their diet consists mainly of grasses, plants, and leaves. +``` + +  +## Speed and resource estimates + +Use the `.benchmark()` method to compare the computational performance of different settings. The `.benchmark()` method takes the same arguments as the `.generate()` method. For example, we can estimate the speed and GPU memory consumption as follows (the resulting numbers were obtained on an A10G GPU): + +```python +from litgpt.api import LLM +from pprint import pprint + +llm = LLM.load( + model="microsoft/phi-2", + distribute=None +) + +llm.distribute(fixed_kv_cache_size=500) + +text, bench_d = llm.benchmark(prompt="What do llamas eat?", top_k=1, stream=True) +print(text) +pprint(bench_d) + + +# Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized +# digestive system that allows them to efficiently extract nutrients from plant material. + +# {'Inference speed in tokens/sec': 15.687777681894985, +# 'Seconds to first token': 0.5756612900004257, +# 'Seconds total': 1.5935972900006163, +# 'Tokens generated': 25, +# 'Total GPU memory allocated in GB': 11.534106624} ``` \ No newline at end of file