[Misc][Quark] Upstream Quark format to VLLM (vllm-project#10765)
Signed-off-by: kewang-xlnx <[email protected]>
Signed-off-by: kewang2 <[email protected]>
Co-authored-by: kewang2 <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
3 people authored and abmfy committed Jan 24, 2025
1 parent 25915fb commit c84cf32
Showing 32 changed files with 1,264 additions and 70 deletions.
30 changes: 30 additions & 0 deletions tests/quantization/test_quark.py
@@ -0,0 +1,30 @@
"""Test model set-up and weight loading for quark-quantized models.
Run `pytest tests/quantization/test_quark.py`.
"""

import torch

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
    QuarkLinearMethod, QuarkW8A8Fp8)


def test_quark_fp8(vllm_runner):
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

        if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is torch.float8_e4m3fn
            # On ROCm builds the weight dtype may be the fnuz variant instead:
            # assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
            assert len(qkv_proj.weight_scale.shape) == 0

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
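For reference, a minimal sketch (not part of this diff) of how a Quark-quantized checkpoint such as the test model above could be served through vLLM's offline API; the explicit quantization argument is shown for clarity and assumes the "quark" method name registered by this change:

from vllm import LLM, SamplingParams

# Load the Quark FP8 test checkpoint using the newly registered method name.
llm = LLM(model="amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test",
          quantization="quark")
outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)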
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -553,7 +553,7 @@ def _verify_quantization(self) -> None:
        optimized_quantization_methods = [
            "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
            "awq_marlin", "fbgemm_fp8", "compressed_tensors",
            "compressed-tensors", "experts_int8"
            "compressed-tensors", "experts_int8", "quark"
        ]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/linear.py
@@ -32,7 +32,7 @@
    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
    "HQQMarlinMethod"
    "HQQMarlinMethod", "QuarkLinearMethod"
]


4 changes: 4 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -26,6 +26,7 @@
    "experts_int8",
    "neuron_quant",
    "ipex",
    "quark"
]


@@ -34,6 +35,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
        raise ValueError(f"Invalid quantization method: {quantization}")

    # lazy import to avoid triggering `torch.compile` too early
    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

    from .aqlm import AQLMConfig
    from .awq import AWQConfig
    from .awq_marlin import AWQMarlinConfig
@@ -79,6 +82,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
        "experts_int8": ExpertsInt8Config,
        "neuron_quant": NeuronQuantConfig,
        "ipex": IPEXConfig,
        "quark": QuarkConfig
    }

    return method_to_config[quantization]
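A quick illustration (a sketch, not part of this diff) of how the registry above resolves the newly added method name:

from vllm.model_executor.layers.quantization import get_quantization_config

# "quark" now resolves to QuarkConfig via the method_to_config mapping above.
quark_config_cls = get_quantization_config("quark")
print(quark_config_cls.__name__)  # QuarkConfig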
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/quantization/base_config.py
@@ -133,3 +133,6 @@ def get_quant_method(self, layer: torch.nn.Module,
        method.
        """
        raise NotImplementedError

    def get_cache_scale(self, name: str) -> Optional[str]:
        # Default: no KV-cache scale name remapping; subclasses may override.
        return None
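A rough sketch of how a weight loader could consume this new hook; the helper name below is illustrative and not taken from this diff:

# Illustrative helper only: remap a checkpoint parameter name to vLLM's
# KV-cache scale name, or fall back to the original name when the hook
# returns None.
def resolve_param_name(quant_config, name: str) -> str:
    remapped = quant_config.get_cache_scale(name)
    return remapped if remapped is not None else name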
@@ -412,6 +412,22 @@ def get_scheme(
        self._check_scheme_supported(scheme.get_min_capability())
        return scheme

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM
        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        # If no matches, return None
        return None

    @staticmethod
    def supports_cutlass_24(
            weight_quant: Optional[QuantizationArgs],
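A small self-contained demo (mirroring, not calling, the get_cache_scale method above) of the name mapping it performs; the parameter name is a typical checkpoint name chosen for illustration:

def demo_cache_scale(name: str):
    # Same replacement logic as get_cache_scale above, for illustration only.
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    return None

print(demo_cache_scale("model.layers.0.self_attn.k_proj.output_scale"))
# -> model.layers.0.self_attn.attn.k_scale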
@@ -136,6 +136,10 @@ def triton_scaled_mm(input: torch.Tensor,
    assert N > 0 and K > 0 and M > 0
    assert weight.shape[0] == K
    assert input.dtype == weight.dtype

    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
    assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
        [M, 1])
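The reshape added above normalizes 0-D (per-tensor) and 1-D (per-row) scales into the 2-D layouts the following assertions expect; a small sketch, with M chosen arbitrarily for illustration:

import torch

M = 4
scale_per_tensor = torch.tensor(0.5)   # 0-D per-tensor scale
scale_per_row = torch.rand(M)          # 1-D per-row scales

# Same normalization as above: anything with dim() <= 1 becomes a column.
print(scale_per_tensor.reshape(-1, 1).shape)  # torch.Size([1, 1])
print(scale_per_row.reshape(-1, 1).shape)     # torch.Size([4, 1])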
@@ -133,23 +133,6 @@ def _find_first_match(value: str,
    return None


def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
    """
    Check whether the param name matches the format for k/v cache scales
    in compressed-tensors. If this is the case, return its equivalent
    param name expected by vLLM
    :param name: param name
    :return: matching param name for KV cache scale in vLLM
    """
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    # If no matches, return None
    return None


def _is_equal_or_regex_match(value: str,
                             target: str,
                             check_contains: bool = False) -> bool:
Empty file.