Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Support tensor parallelism for GGUF quantization #7520

Merged
merged 9 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions tests/models/test_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from tests.quantization.utils import is_quant_method_supported

Expand All @@ -20,7 +21,7 @@
MODELS = [
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")),
filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")),
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF",
filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")),
Expand All @@ -39,22 +40,36 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [1, 2])
def test_models(
num_gpus_available,
vllm_runner,
example_prompts,
model,
dtype: str,
max_tokens: int,
num_logprobs: int,
tp_size: int,
) -> None:
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

original_model, gguf_model = model

tokenizer = AutoTokenizer.from_pretrained(original_model)
messages = [[{
'role': 'user',
'content': prompt
}] for prompt in example_prompts]
example_prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

# Run unquantized model.
with vllm_runner(model_name=original_model,
dtype=dtype,
max_model_len=MAX_MODEL_LEN,
enforce_eager=True,
tensor_parallel_size=1) as original_model:
tensor_parallel_size=tp_size) as original_model:

original_outputs = original_model.generate_greedy_logprobs(
example_prompts[:-1], max_tokens, num_logprobs)
Expand All @@ -63,8 +78,7 @@ def test_models(
with vllm_runner(model_name=gguf_model,
dtype=dtype,
max_model_len=MAX_MODEL_LEN,
enforce_eager=True,
tensor_parallel_size=1) as gguf_model:
tensor_parallel_size=tp_size) as gguf_model:
gguf_outputs = gguf_model.generate_greedy_logprobs(
example_prompts[:-1], max_tokens, num_logprobs)

Expand Down
26 changes: 20 additions & 6 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,11 +507,16 @@ def weight_loader(self,
loaded_shard_id

if is_gguf_weight:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = loaded_weight.shape
param.shard_size[loaded_shard_id] = shard_shape

input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)

param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
Expand Down Expand Up @@ -863,8 +868,13 @@ def weight_loader(self,
param, orig_qkv_offsets, loaded_shard_id)

if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = loaded_weight.shape
param.shard_size[loaded_shard_id] = shard_shape

input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)
Expand Down Expand Up @@ -976,6 +986,7 @@ def __init__(self,

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)

# Special case for GGUF
Expand All @@ -986,7 +997,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):

# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
weight_shape = list(loaded_weight.shape)
if input_dim:
weight_shape[input_dim] = weight_shape[input_dim] // tp_size
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

param_data = param.data
if input_dim is not None:
Expand Down
4 changes: 0 additions & 4 deletions vllm/model_executor/layers/quantization/gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
Expand Down Expand Up @@ -39,9 +38,6 @@ def get_config_filenames(cls) -> List[str]:

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
if get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"GGUF quantization hasn't supported tensor parallelism yet.")
return cls()

def get_quant_method(self, layer: torch.nn.Module,
Expand Down
Loading