Commit
Optionally return benchmark info in Python API (#1660)
rasbt authored Aug 7, 2024
1 parent 4558463 commit b4fd601
Showing 3 changed files with 85 additions and 9 deletions.
50 changes: 42 additions & 8 deletions litgpt/api.py
@@ -3,6 +3,7 @@
 # This file implements the LitGPT Python API
 from pathlib import Path
 import sys
+import time
 from typing import Any, List, Literal, Optional, Union
 
 import torch
@@ -228,9 +229,6 @@ def distribute(
         if generate_strategy == "sequential" and accelerator not in ("cuda", "gpu"):
             raise NotImplementedError("generate_strategy='sequential' is only supported for accelerator='cuda'|'gpu.")
 
-        #if generate_strategy == "sequential" and init != "pretrained":
-        #    raise NotImplementedError("generate_strategy='sequential' is only supported for init='pretrained'.")
-
         num_devices = calculate_number_of_devices(devices)
 
         if generate_strategy is None and num_devices > 1:
@@ -353,9 +351,7 @@ def generate(
         We plan to resolve this in the future.
         """
         assert self.model is not None
-
-        prompt = self.prompt_style.apply(prompt)
-        input_ids = self.preprocessor.encode(prompt)
+        input_ids = self._text_to_token_ids(prompt)
         prompt_length = input_ids.size(0)
         max_returned_tokens = prompt_length + max_new_tokens
 
@@ -389,7 +385,7 @@ def iterator():
                 yield from outputs
             else:
                 for output in outputs:
-                    yield self.preprocessor.tokenizer.decode(output)
+                    yield self.preprocessor.decode(output)
             return
 
         if stream:
@@ -411,7 +407,45 @@ def iterator():
         elif return_as_token_ids:
             return outputs
         else:
-            return self.preprocessor.tokenizer.decode(outputs)
+            return self.preprocessor.decode(outputs)
+
+    def _text_to_token_ids(self, prompt):
+        """Utility method to convert a prompt text to token IDs"""
+        prompt = self.prompt_style.apply(prompt)
+        input_ids = self.preprocessor.encode(prompt)
+        return input_ids
+
+    def benchmark(self, **kwargs):
+        """
+        A wrapper around the .generate() method to calculate runtime performance.
+
+        Arguments:
+        kwargs: Keyword arguments that are passed to the .generate() method.
+        """
+        benchmark_dict = {}
+
+        time_to_first_token = None
+        t0 = time.perf_counter()
+        outputs = self.generate(**kwargs)
+
+        if kwargs.get("stream", False):
+            gen_outputs = []
+            for e in outputs:
+                if time_to_first_token is None:
+                    t1 = time.perf_counter()
+                    time_to_first_token = t1 - t0
+                gen_outputs.append(e)
+            outputs = "".join(gen_outputs)
+        else:
+            outputs = self.generate(**kwargs, )
+        benchmark_dict["Seconds total"] = time.perf_counter() - t0
+        benchmark_dict["Seconds to first token"] = time_to_first_token
+        benchmark_dict["Tokens generated"] = self.preprocessor.encode(outputs).size(0) - self._text_to_token_ids(kwargs.get("prompt")).size(0)
+        benchmark_dict["Inference speed in tokens/sec"] = benchmark_dict["Tokens generated"] / benchmark_dict["Seconds total"]
+        if self.fabric.device.type == "cuda":
+            benchmark_dict["Total GPU memory allocated in GB"] = torch.cuda.max_memory_allocated() / 1e9
+
+        return outputs, benchmark_dict
 
 
 class Preprocessor:
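
The new `benchmark()` method above measures time to first token by timestamping the streaming generator as it yields. The following standalone sketch (not part of this commit) illustrates the same timing pattern in isolation; the `timed_stream` helper and the dummy token iterator are illustrative only.

```python
import time


def timed_stream(token_iterator):
    # Wrap a streaming generator, recording the time until the first chunk
    # arrives and the total time until the stream is exhausted -- the same
    # pattern benchmark() applies around a streaming generate() call.
    t0 = time.perf_counter()
    time_to_first_token = None
    chunks = []
    for chunk in token_iterator:
        if time_to_first_token is None:
            time_to_first_token = time.perf_counter() - t0
        chunks.append(chunk)
    total = time.perf_counter() - t0
    return "".join(chunks), {
        "Seconds to first token": time_to_first_token,
        "Seconds total": total,
    }


# A dummy iterator stands in for a streaming model output
text, timings = timed_stream(iter(["Llamas ", "eat ", "grass."]))
print(text)
print(timings)
```
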
13 changes: 12 additions & 1 deletion tests/test_api.py
@@ -145,7 +145,6 @@ def test_more_than_1_device_for_sequential_gpu(tmp_path):
     llm.distribute(devices=2, generate_strategy="sequential")
     assert isinstance(llm.generate("What do llamas eat?"), str)
 
-
     with pytest.raises(NotImplementedError, match="Support for multiple devices is currently only implemented for generate_strategy='sequential'."):
         llm.distribute(devices=2)
 
@@ -199,3 +198,15 @@ def test_invalid_accelerator(tmp_path):
     )
     with pytest.raises(ValueError, match="Invalid accelerator"):
         llm.distribute(accelerator="invalid")
+
+
+def test_returned_benchmark_dir(tmp_path):
+    llm = LLM.load(
+        model="EleutherAI/pythia-14m",
+    )
+
+    text, bench_d = llm.benchmark(prompt="hello world")
+    assert isinstance(bench_d["Inference speed in tokens/sec"], float)
+
+    text, bench_d = llm.benchmark(prompt="hello world", stream=True)
+    assert isinstance(bench_d["Inference speed in tokens/sec"], float)
31 changes: 31 additions & 0 deletions tutorials/python-api.md
@@ -122,4 +122,35 @@ print(text)
 
 ```
 Llamas are herbivores and their diet consists mainly of grasses, plants, and leaves.
 ```
+
+&nbsp;
+## Speed and resource estimates
+
+Use the `.benchmark()` method to compare the computational performance of different settings. The `.benchmark()` method takes the same arguments as the `.generate()` method. For example, we can estimate the speed and GPU memory consumption as follows (the resulting numbers were obtained on an A10G GPU):
+
+```python
+from litgpt.api import LLM
+from pprint import pprint
+
+llm = LLM.load(
+    model="microsoft/phi-2",
+    distribute=None
+)
+
+llm.distribute(fixed_kv_cache_size=500)
+
+text, bench_d = llm.benchmark(prompt="What do llamas eat?", top_k=1, stream=True)
+print(text)
+pprint(bench_d)
+
+
+# Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized
+# digestive system that allows them to efficiently extract nutrients from plant material.
+
+# {'Inference speed in tokens/sec': 15.687777681894985,
+#  'Seconds to first token': 0.5756612900004257,
+#  'Seconds total': 1.5935972900006163,
+#  'Tokens generated': 25,
+#  'Total GPU memory allocated in GB': 11.534106624}
+```
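
Since `.benchmark()` returns the metrics as a dictionary, comparing settings amounts to repeating the call while varying a single argument. Below is a minimal sketch (not part of the tutorial file in this commit), assuming the `llm` object from the snippet above is still loaded, that contrasts streaming and non-streaming generation under otherwise identical settings:

```python
# Compare streaming vs. non-streaming generation with the same prompt and sampling settings
for stream in (False, True):
    text, bench_d = llm.benchmark(prompt="What do llamas eat?", top_k=1, stream=stream)
    print(
        f"stream={stream}: "
        f"{bench_d['Inference speed in tokens/sec']:.1f} tokens/sec, "
        f"{bench_d['Seconds total']:.2f} s total"
    )
```

Note that `Seconds to first token` is only populated for streaming calls, so the comparison above reports only the keys present in both runs.
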
