diff --git a/tests/models/test_gguf.py b/tests/models/test_gguf.py
index 5971179f01211..196cd88e039a1 100644
--- a/tests/models/test_gguf.py
+++ b/tests/models/test_gguf.py
@@ -7,6 +7,7 @@
 
 import pytest
 from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer
 
 from tests.quantization.utils import is_quant_method_supported
 
@@ -20,7 +21,7 @@
 MODELS = [
     ("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
      hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
-                     filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")),
+                     filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")),
     ("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
      hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF",
                      filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")),
@@ -39,22 +40,36 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("tp_size", [1, 2])
 def test_models(
+    num_gpus_available,
     vllm_runner,
     example_prompts,
     model,
     dtype: str,
    max_tokens: int,
     num_logprobs: int,
+    tp_size: int,
 ) -> None:
+    if num_gpus_available < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+
     original_model, gguf_model = model
 
+    tokenizer = AutoTokenizer.from_pretrained(original_model)
+    messages = [[{
+        'role': 'user',
+        'content': prompt
+    }] for prompt in example_prompts]
+    example_prompts = tokenizer.apply_chat_template(messages,
+                                                    tokenize=False,
+                                                    add_generation_prompt=True)
+
     # Run unquantized model.
     with vllm_runner(model_name=original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
-                     enforce_eager=True,
-                     tensor_parallel_size=1) as original_model:
+                     tensor_parallel_size=tp_size) as original_model:
         original_outputs = original_model.generate_greedy_logprobs(
             example_prompts[:-1], max_tokens, num_logprobs)
 
@@ -63,8 +78,7 @@ def test_models(
     with vllm_runner(model_name=gguf_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
-                     enforce_eager=True,
-                     tensor_parallel_size=1) as gguf_model:
+                     tensor_parallel_size=tp_size) as gguf_model:
         gguf_outputs = gguf_model.generate_greedy_logprobs(
             example_prompts[:-1], max_tokens, num_logprobs)
 
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index b4cc6daa3c41e..3824ed3570aeb 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -507,11 +507,16 @@ def weight_loader(self,
                     loaded_shard_id
 
             if is_gguf_weight:
-                shard_size = loaded_weight.shape[output_dim]
-                shard_offset = loaded_weight.shape[output_dim] * \
-                    loaded_shard_id
+                tp_size = get_tensor_model_parallel_world_size()
+                output_dim = getattr(param, "output_dim", None)
+                shard_shape = list(loaded_weight.shape)
+                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                 param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = loaded_weight.shape
+                param.shard_size[loaded_shard_id] = shard_shape
+
+                input_dim = getattr(param, "input_dim", None)
+                input_size = loaded_weight.shape[input_dim]
+                param_data = param_data.narrow(input_dim, 0, input_size)
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
@@ -863,8 +868,13 @@ def weight_loader(self,
                     param, orig_qkv_offsets, loaded_shard_id)
 
             if is_gguf_weight:
+                tp_size = get_tensor_model_parallel_world_size()
+                output_dim = getattr(param, "output_dim", None)
+                shard_shape = list(loaded_weight.shape)
+                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                 param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = loaded_weight.shape
+                param.shard_size[loaded_shard_id] = shard_shape
+
                 input_dim = getattr(param, "input_dim", None)
                 input_size = loaded_weight.shape[input_dim]
                 param_data = param_data.narrow(input_dim, 0, input_size)
@@ -976,6 +986,7 @@ def __init__(self,
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
 
         # Special case for GGUF
@@ -986,7 +997,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
 
         # Materialize GGUF UninitializedParameter
         if is_gguf_weight and isinstance(param, UninitializedParameter):
-            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
+            weight_shape = list(loaded_weight.shape)
+            if input_dim:
+                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
         if input_dim is not None:
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index a4e0a4d509608..a6a1ed5b0dee5 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -5,7 +5,6 @@
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm import _custom_ops as ops
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -39,9 +38,6 @@ def get_config_filenames(cls) -> List[str]:
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
-        if get_tensor_model_parallel_world_size() > 1:
-            raise ValueError(
-                "GGUF quantization hasn't supported tensor parallelism yet.")
        return cls()
 
     def get_quant_method(self, layer: torch.nn.Module,
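
For reference, the per-rank arithmetic the linear-layer hunks above rely on is just "divide the sharded dimension by tp_size and narrow to the rank's slice". The sketch below is a standalone illustration of that arithmetic with torch.narrow, not part of the diff; the helper name shard_for_rank and the toy byte tensor are hypothetical, and it assumes the sharded dimension divides evenly by the tensor-parallel world size.

# Standalone illustration, not part of the diff. `shard_for_rank` is a
# hypothetical helper; vLLM's weight loaders apply the equivalent narrow()
# to param.data / loaded_weight directly.
import torch


def shard_for_rank(weight: torch.Tensor, dim: int, tp_rank: int,
                   tp_size: int) -> torch.Tensor:
    """Return the slice of `weight` owned by rank `tp_rank` along `dim`."""
    assert weight.shape[dim] % tp_size == 0, "sharded dim must divide evenly"
    shard_size = weight.shape[dim] // tp_size
    return weight.narrow(dim, tp_rank * shard_size, shard_size)


# Toy 8x16 byte tensor standing in for GGUF-quantized data, split across
# two ranks along dim 0 (the input dimension of a row-parallel layer).
weight = torch.arange(8 * 16, dtype=torch.uint8).reshape(8, 16)
rank0 = shard_for_rank(weight, dim=0, tp_rank=0, tp_size=2)
rank1 = shard_for_rank(weight, dim=0, tp_rank=1, tp_size=2)
assert rank0.shape == rank1.shape == (4, 16)

The same division appears in the diff as shard_shape[output_dim] // tp_size for the merged column-parallel and QKV shards and as weight_shape[input_dim] // tp_size when materializing the row-parallel GGUF parameter.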