
Commit

Refactor Python API to introduce new distribute method (part of a larger refactor for PTL support) (#1657)
rasbt authored Aug 6, 2024
1 parent 6786c42 commit 9722fdb
Showing 3 changed files with 235 additions and 136 deletions.
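
For orientation before the diff: a minimal sketch of the two workflows this commit's API supports, assembled from the docstrings shown below. Treat it as an illustration under those docstrings' assumptions, not verified usage documentation.

from litgpt.api import LLM

# Default path: load() places the model on a single GPU (or MPS) if available, otherwise on the CPU.
llm = LLM.load("microsoft/phi-2")
print(llm.generate("What do Llamas eat?", top_k=1))

# New deferred path: skip placement at load time, then call the new distribute() method.
llm = LLM.load("microsoft/phi-2", distribute=None)
llm.distribute(accelerator="cuda", devices=1, precision="bf16-true")
print(llm.generate("What do Llamas eat?", top_k=1))
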
247 changes: 156 additions & 91 deletions litgpt/api.py
@@ -31,23 +31,25 @@ class LLM:
def __init__(
self,
model: GPT,
tokenizer: Tokenizer,
prompt_style: PromptStyle,
preprocessor=None,
prompt_style: PromptStyle = None,
devices: Union[int, List[int]] = None,
config: Config = None,
checkpoint_dir: Path = None,
fabric: L.Fabric = None,
generate_strategy: Optional[Literal["sequential"]] = None,
kvcache_initialized: bool = False,
kv_cache_initialized: bool = False,
fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None
) -> None:
self.model = model
self.preprocessor = Preprocessor(tokenizer, device=fabric.device)
self._model = model
self.preprocessor = preprocessor
self.devices = devices
self.prompt_style = prompt_style
self.config = config
self.checkpoint_dir = checkpoint_dir
self.fabric = fabric
self.generate_strategy = generate_strategy
self.kvcache_initialized = kvcache_initialized
self.kv_cache_initialized = kv_cache_initialized
self.fixed_kv_cache_size = fixed_kv_cache_size
self.prev_generated_seq_length = 0

@@ -57,81 +59,46 @@ def __init__(
Example:
from litgpt.api import LLM
llm = LLM.load("microsoft/phi-2", accelerator="cuda", devices=1)
llm = LLM.load("microsoft/phi-2")
text = llm.generate("What do Llamas eat?", top_k=1)
print(text)
"""

@property
def model(self):
if self._model is None:
raise AttributeError("The model is not initialized yet; use the .distribute() method to initialize the model.")
return self._model

@model.setter
def model(self, content):
self._model = content

@classmethod
def load(
cls,
model: str,
accelerator: Literal["cpu", "cuda", "auto"] = "auto",
devices: Union[int, List[int]] = 1,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
precision: Optional[Any] = None,
init: Optional[Literal["pretrained", "random"]] = "pretrained",
tokenizer_dir: Optional[Path] = None,
access_token: Optional[str] = None,
generate_strategy: Optional[Literal["sequential"]] = None,
fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None
distribute: Optional[Literal["auto"]] = "auto"
) -> "LLM":
"""
Loads the LLM from a local directory or model hub.
Arguments
model: A local path to a directory containing the model weights or a valid model name.
You can get a list of valid model names via the `litgpt download list` command line argument.
accelerator: Which device type to load the model on ("cpu", "gpu", "mps", "cuda", or "auto")
devices: The number of devices (1, 2, etc.) or device IDs (e.g., [0, 2] to use the first and third GPU).
quantize: Whether to quantize the model and, if so, which method to use:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
For instance, "32-true", "16-mixed", "16-true", "bf16-mixed", "bf16-true".
For more details, see https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision
init: If "pretrained" (default), downloads the model from the HF Hub if a local model can't be found at the `model`
directory name; otherwise loads the model from the local directory.
If "random", initializes the `model` with random weights.
tokenizer_dir: An optional tokenizer directory if `model` is not a checkpoint directory, or if a user
wants to use a different tokenizer instead.
access_token:
Optional API token to access models with restrictions when using `init="pretrained"`.
generate_strategy: Whether to use a sequential model generation strategy. The "sequential" setting allows running
models that wouldn't fit on a single card by partitioning the transformer blocks across
all devices and running them sequentially. Sequential generation may be slower but allows using larger models.
Note that sequential generation sets `fixed_kv_cache_size="max_model_supported"`. You can set it to a lower integer
value, e.g. `fixed_kv_cache_size=256`, to reduce memory. The `fixed_kv_cache_size` value determines the maximum number
of tokens that can be returned via `llm.generate(...)`.
fixed_kv_cache_size: If set to an integer value or "max_model_supported", the KV cache won't be resized dynamically
during `llm.generate` calls. Use this setting if you plan to compile the model or use `generate_strategy="sequential"`.
Note that the chosen `fixed_kv_cache_size` value determines the maximum number of tokens that can be returned by `llm.generate(...)`.
access_token: Optional API token to access models with restrictions when using `init="pretrained"`.
distribute: If "auto" (default), initializes the model on a single GPU if available and otherwise on the CPU.
To have more control over the model distribution strategy and utilize multiple GPUs, you can set
`llm = LLM.load(..., distribute=None)` and call `llm.distribute(...)` manually.
"""
allowed_accelerators = {"cpu", "gpu", "cuda", "mps", "auto"}
if accelerator not in allowed_accelerators:
raise ValueError(f"Invalid accelerator: {accelerator}. Must be one of {allowed_accelerators}.")

if accelerator == "auto":
if torch.cuda.is_available():
accelerator = "cuda"
elif torch.backends.mps.is_available():
accelerator = "mps"
else:
accelerator = "cpu"

if generate_strategy == "sequential" and accelerator != "cuda":
raise NotImplementedError("generate_strategy='sequential' is only supported for accelerator='cuda'.")

if generate_strategy == "sequential" and init != "pretrained":
raise NotImplementedError("generate_strategy='sequential' is only supported for init='pretrained'.")

num_devices = calculate_number_of_devices(devices)

if generate_strategy is None and num_devices > 1:
raise NotImplementedError(
"Support for multiple devices is currently only implemented for generate_strategy='sequential'."
)

allowed_init = {"pretrained", "random"}
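
A hedged sketch of the deferred-initialization behaviour described by the `model` property and the `load()` docstring above: the model is not materialized until `distribute()` runs, and multi-GPU use goes through `generate_strategy="sequential"`.

from litgpt.api import LLM

llm = LLM.load("microsoft/phi-2", distribute=None)

try:
    _ = llm.model  # raises: the model is not initialized yet
except AttributeError as err:
    print(err)  # the message points the user at .distribute()

# Partition the transformer blocks across two GPUs and run them sequentially.
llm.distribute(devices=2, generate_strategy="sequential")
print(llm.generate("What do Llamas eat?", top_k=1))
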

@@ -153,23 +120,6 @@ def load(
raise ValueError(f"Invalid init option: {init}. Must be one of {allowed_init}")

torch.set_float32_matmul_precision("high")
precision = precision or get_default_supported_precision(training=False)

plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("The combination of quantization and mixed precision is not supported.")
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None

fabric = L.Fabric(
accelerator=accelerator,
devices=1, # Otherwise sequential wouldn't work, see litgpt/generate/sequentially.py
#devices=devices,
precision=precision,
plugins=plugins
)

if tokenizer_dir is not None:
tokenizer_dir = extend_checkpoint_dir(Path(tokenizer_dir))
@@ -185,29 +135,144 @@ def load(
if has_prompt_style(checkpoint_dir)
else PromptStyle.from_config(config)
)
checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
else:
prompt_style = PromptStyle.from_config(config)

kv_cache_initialized = False
if generate_strategy is None:
with fabric.init_module(empty_init=(num_devices > 1)):
if distribute == "auto":
if torch.cuda.is_available():
accelerator = "cuda"
elif torch.backends.mps.is_available():
accelerator = "mps"
else:
accelerator = "cpu"

fabric = L.Fabric(
accelerator=accelerator,
devices=1,
precision=get_default_supported_precision(training=False),
)

with fabric.init_module(empty_init=False):
model = GPT(config)
model.eval()
preprocessor = Preprocessor(tokenizer, device=fabric.device)

if checkpoint_dir is not None:
checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
load_checkpoint(fabric, model, checkpoint_path)

model = fabric.setup_module(model)

else:
preprocessor = Preprocessor(tokenizer, device="cuda" if torch.cuda.is_available() else "cpu")
model = None
fabric = None

return cls(
model=model, preprocessor=preprocessor, prompt_style=prompt_style,
config=config, checkpoint_dir=checkpoint_dir, fabric=fabric, generate_strategy=None,
kv_cache_initialized=False, fixed_kv_cache_size=False
)

def distribute(
self,
accelerator: Literal["cpu", "cuda", "auto"] = "auto",
devices: Union[int, List[int]] = 1,
precision: Optional[Any] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
generate_strategy: Optional[Literal["sequential"]] = None,
fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None
):
"""
Moves the model onto the specified devices for single-GPU or multi-GPU inference.
accelerator: Which device type to load the model on ("cpu", "gpu", "mps", "cuda", or "auto")
devices: The number of devices (1, 2, etc.) or device IDs (e.g., [0, 2] to use the first and third GPU).
quantize: Whether to quantize the model and, if so, which method to use:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
For instance, "32-true", "16-mixed", "16-true", "bf16-mixed", "bf16-true".
For more details, see https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision
generate_strategy: Whether to use a sequential model generation strategy. The "sequential" setting allows running
models that wouldn't fit on a single card by partitioning the transformer blocks across
all devices and running them sequentially. Sequential generation may be slower but allows using larger models.
Note that sequential generation sets `fixed_kv_cache_size="max_model_supported"`. You can set it to a lower integer
value, e.g. `fixed_kv_cache_size=256`, to reduce memory. The `fixed_kv_cache_size` value determines the maximum number
of tokens that can be returned via `llm.generate(...)`.
fixed_kv_cache_size: If set to an integer value or "max_model_supported", the KV cache won't be resized dynamically
during `llm.generate` calls. Use this setting if you plan to compile the model or use `generate_strategy="sequential"`.
Note that the chosen `fixed_kv_cache_size` value determines the maximum number of tokens that can be returned by `llm.generate(...)`.
"""

if self.checkpoint_dir is None:
raise NotImplementedError(
"The LLM was initialized with init='random' but .distribute() "
"currently only supports pretrained weights."
)

allowed_accelerators = {"cpu", "gpu", "cuda", "mps", "auto"}
if accelerator not in allowed_accelerators:
raise ValueError(f"Invalid accelerator: {accelerator}. Must be one of {allowed_accelerators}.")

if accelerator == "auto":
if torch.cuda.is_available():
accelerator = "cuda"
elif torch.backends.mps.is_available():
accelerator = "mps"
else:
accelerator = "cpu"

if generate_strategy == "sequential" and accelerator not in ("cuda", "gpu"):
raise NotImplementedError("generate_strategy='sequential' is only supported for accelerator='cuda'|'gpu.")

#if generate_strategy == "sequential" and init != "pretrained":
# raise NotImplementedError("generate_strategy='sequential' is only supported for init='pretrained'.")

num_devices = calculate_number_of_devices(devices)

if generate_strategy is None and num_devices > 1:
raise NotImplementedError(
"Support for multiple devices is currently only implemented for generate_strategy='sequential'."
)

plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("The combination of quantization and mixed precision is not supported.")
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None

fabric = L.Fabric(
accelerator=accelerator,
devices=1, # Otherwise sequential wouldn't work, see litgpt/generate/sequentially.py
# devices=devices,
precision=precision,
plugins=plugins
)

self.kv_cache_initialized = False
if generate_strategy is None:
with fabric.init_module(empty_init=(num_devices > 1)):
model = GPT(self.config)
model.eval()

if self.checkpoint_dir is not None:
load_checkpoint(fabric, model, self.checkpoint_dir / "lit_model.pth")

model = fabric.setup_module(model)

if fixed_kv_cache_size is not None:
if fixed_kv_cache_size is None or fixed_kv_cache_size == "max_model_supported":
kv_cache_size = model.max_seq_length
else:
kv_cache_size = fixed_kv_cache_size
model.set_kv_cache(batch_size=1, max_seq_length=kv_cache_size, device=fabric.device)
kv_cache_initialized = True
self.kv_cache_initialized = True
self.fixed_kv_cache_size = fixed_kv_cache_size

elif generate_strategy == "sequential":
# cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced
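
Illustrative only, based on the `distribute()` docstring and checks above: quantization requires a non-mixed precision, and `fixed_kv_cache_size` caps how many tokens `generate()` can return. The values below are assumptions made for this sketch, not recommendations.

from litgpt.api import LLM

llm = LLM.load("microsoft/phi-2", distribute=None)
llm.distribute(
    accelerator="cuda",
    devices=1,
    precision="16-true",        # "mixed" precisions are rejected when quantizing
    quantize="bnb.nf4",         # 4-bit bitsandbytes quantization
    fixed_kv_cache_size=256,    # the KV cache is not resized dynamically after this
)
print(llm.generate("What do Llamas eat?", max_new_tokens=100))
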
@@ -224,10 +289,10 @@ def load(
print(f"Using {total_devices} devices", file=sys.stderr)

with fabric.init_tensor(), torch.device("meta"):
model = GPT(config)
model = GPT(self.config)

model.eval()
state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu")
state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu")
model.load_state_dict(state_dict, assign=True)
model = fabric.setup_module(model, move_to_device=False)

@@ -238,17 +303,15 @@ def load(
else:
kv_cache_size = fixed_kv_cache_size
model = sequential(model, fabric.device, kv_cache_size, total_devices)
kv_cache_initialized = True
self.fixed_kv_cache_size = fixed_kv_cache_size
self.kv_cache_initialized = True

else:
raise ValueError(f"Unsupported generate_strategy: {generate_strategy}")

return cls(
model=model, tokenizer=tokenizer, devices=devices,
prompt_style=prompt_style, checkpoint_dir=checkpoint_dir, fabric=fabric,
generate_strategy=generate_strategy, kvcache_initialized=kv_cache_initialized,
fixed_kv_cache_size=fixed_kv_cache_size
)
self.model = model
self.fabric = fabric
self.preprocessor.device = fabric.device

@torch.inference_mode()
def generate(
@@ -289,14 +352,16 @@ def generate(
At the moment, this setting is slower and may use more memory than the non-streaming version.
We plan to resolve this in the future.
"""
assert self.model is not None

prompt = self.prompt_style.apply(prompt)
input_ids = self.preprocessor.encode(prompt)
prompt_length = input_ids.size(0)
max_returned_tokens = prompt_length + max_new_tokens

if not self.kvcache_initialized:
if not self.kv_cache_initialized:
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=self.fabric.device)
self.kvcache_initialized = True
self.kv_cache_initialized = True

# Dynamically grow the kv cache size if necessary
if self.fixed_kv_cache_size is None and self.prev_generated_seq_length < max_returned_tokens:
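
A short, hypothetical helper restating the KV-cache sizing rule from the `generate()` hunk above: with no `fixed_kv_cache_size`, the cache is (re)allocated whenever `prompt_length + max_new_tokens` exceeds the longest sequence generated so far. The function name is made up for illustration and is not part of this commit.

def kv_cache_needs_resize(prev_generated_seq_length: int,
                          prompt_length: int,
                          max_new_tokens: int,
                          fixed_kv_cache_size=None) -> bool:
    # Mirrors the check in LLM.generate(): only a dynamically sized cache grows.
    max_returned_tokens = prompt_length + max_new_tokens
    return fixed_kv_cache_size is None and prev_generated_seq_length < max_returned_tokens

print(kv_cache_needs_resize(prev_generated_seq_length=0, prompt_length=12, max_new_tokens=50))   # True
print(kv_cache_needs_resize(prev_generated_seq_length=64, prompt_length=12, max_new_tokens=50))  # False
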
