Commit

New HIGGS quantization interfaces, JIT kernel compilation support. (#36148)

* new flute

* new higgs working

* small adjustments

* progress and quality

* small updates

* style

---------

Co-authored-by: Andrey Panferov <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Mohamed Mekkouri <[email protected]>
4 people authored Feb 14, 2025
1 parent 15ec971 commit 5f726f8
Showing 5 changed files with 55 additions and 81 deletions.
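For context, a minimal usage sketch of the HIGGS path this commit touches; it is not part of the diff. The model id matches the one used in the updated test below, everything else follows the public transformers API (HiggsConfig, from_pretrained) and may need adjusting to your setup.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

# In-flight HIGGS quantization of a full-precision checkpoint; tune_metadata is
# filled in automatically by the quantizer while the FLUTE kernels are tuned.
quantization_config = HiggsConfig(bits=4, p=2, group_size=256, hadamard_size=512)
model = AutoModelForCausalLM.from_pretrained(
    "unsloth/Llama-3.2-1B",
    quantization_config=quantization_config,
    device_map="cuda",
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B")
inputs = tokenizer("A quick brown fox jumps over the", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=2)[0], skip_special_tokens=True))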
24 changes: 11 additions & 13 deletions src/transformers/integrations/higgs.py
@@ -28,15 +28,12 @@


if is_flute_available():
import flute.utils
from flute.integrations.higgs import prepare_data_transposed
from flute.tune import TuneMetaData, qgemm_v2

if is_hadamard_available():
from fast_hadamard_transform import hadamard_transform

if is_flute_available():
import flute.utils
from flute.integrations.higgs import prepare_data_transposed


def pad_to_block(tensor, dims, had_block_size, value=0):
pad_dims = [0 for _ in range(2 * len(tensor.shape))]
@@ -464,14 +461,14 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256

# Quantize
codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
for i in range(0, weight.shape[0], 64):
codes[i : i + 64] = torch.argmax(2 * weight[i : i + 64] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
for i in range(0, weight.shape[0], 16):
codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
del weight

codes = codes.reshape(codes.shape[0], -1)
scales = scales / sqrt(hadamard_size)

weight, scales, tables, tables2 = prepare_data_transposed(
weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
codes,
torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
grid.to(dtype),
@@ -480,13 +477,15 @@
vector_size=p,
dtype=dtype,
device=device,
check_correctness=False,
)

return {
"weight": weight,
"scales": scales,
"tables": tables,
"tables2": tables2.view(dtype=torch.float16),
"tune_metadata": tune_metadata,
}
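The argmax expression in the quantization loop above is a standard nearest-codeword search: for each weight vector w, the index minimizing ||w - g||^2 over codewords g equals the index maximizing 2*w·g - ||g||^2, because ||w||^2 does not depend on g. A small self-contained check of that identity (shapes and the random codebook are illustrative, not taken from this file):

import torch

grid = torch.randn(256, 2)                    # hypothetical codebook of 2-d codewords (p=2, 8-bit codes)
weight = torch.randn(64, 128, 2)              # hypothetical weight groups reshaped into 2-d vectors
grid_norm_2 = grid.norm(dim=-1) ** 2          # ||g||^2 per codeword

# Same form as in quantize_with_higgs: maximize 2*w@g - ||g||^2 instead of minimizing ||w - g||^2.
codes = torch.argmax(2 * weight @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)

# Brute-force reference: explicit squared distances to every codeword.
dists = ((weight.unsqueeze(-2) - grid) ** 2).sum(dim=-1)    # (64, 128, 256)
reference = dists.argmin(dim=-1)
print((codes.long() == reference).float().mean())           # ~1.0 (up to floating-point ties)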


@@ -508,7 +507,6 @@ def __init__(
self.num_bits = num_bits
self.group_size = group_size
self.hadamard_size = hadamard_size
self.num_sms_packed = nn.Parameter(torch.tensor(-1, dtype=torch.int32, device=device), requires_grad=False)

assert in_features % group_size == 0
assert num_bits in [2, 3, 4]
@@ -531,23 +529,23 @@ def __init__(
self.register_parameter("bias", None)

self.workspace = None # must be set externally to be reused among layers
self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent

def forward(self, x):
x = pad_to_block(x, [-1], self.hadamard_size)

if self.workspace is None:
raise Exception("Workspace must be set before calling forward")

return flute.qgemm_hadamard(
return qgemm_v2(
x,
self.weight,
self.scales,
self.tables,
self.tables2.view(dtype=torch.float32),
self.workspace,
self.num_bits,
self.group_size,
self.hadamard_size,
self.tune_metadata,
hadamard_size=self.hadamard_size,
)
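For reference, a simplified stand-in for the pad_to_block call at the top of forward(): it pads the last dimension of x up to the next multiple of hadamard_size so the kernel's Hadamard transform sees complete blocks. This is a sketch of the expected behavior, not the exact helper defined earlier in this file.

import torch
import torch.nn.functional as F

def pad_last_dim_to_block(x: torch.Tensor, block_size: int, value: float = 0.0) -> torch.Tensor:
    # Pad only the last dimension; other dims are left untouched.
    remainder = x.shape[-1] % block_size
    if remainder == 0:
        return x
    return F.pad(x, (0, block_size - remainder), value=value)

x = torch.randn(2, 5, 700)
print(pad_last_dim_to_block(x, 512).shape)   # torch.Size([2, 5, 1024])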


98 changes: 34 additions & 64 deletions src/transformers/quantizers/quantizer_higgs.py
@@ -13,6 +13,7 @@
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ..utils.logging import tqdm
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name

@@ -30,20 +31,6 @@
logger = logging.get_logger(__name__)


def get_num_sms_from_device(device):
target_device_cc = torch.cuda.get_device_capability(device=device)
if target_device_cc == (8, 6):
return 84
elif target_device_cc == (8, 0):
return 108
elif target_device_cc == (8, 9):
return 128
else:
raise NotImplementedError(
f"Device capability {target_device_cc} not supported for FLUTE (yet?) to verify your device capability check out https://developer.nvidia.com/cuda-gpus"
)


class HiggsHfQuantizer(HfQuantizer):
"""
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
@@ -115,26 +102,24 @@ def create_quantized_param(
self.quantization_config.group_size,
self.quantization_config.hadamard_size,
)

del param_value

module, tensor_name = get_module_from_name(model, param_name)
module, _ = get_module_from_name(model, param_name)
module_name = ".".join(param_name.split(".")[:-1])
for key, value in flute_dict.items():
if key in module._parameters:
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
elif key in module._buffers:
module._buffers[key] = torch.nn.Buffer(value)
elif key == "tune_metadata":
module.tune_metadata = value
self.quantization_config.tune_metadata[module_name] = value.to_dict()
else:
raise ValueError(f"Unexpected key {key} in module {module}")

if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name)

module.num_sms_packed = torch.nn.Parameter(
torch.tensor(get_num_sms_from_device(target_device), device=target_device, dtype=torch.int32),
requires_grad=False,
)

def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
@@ -149,57 +134,42 @@ def _process_model_before_weight_loading(
model.config.quantization_config = self.quantization_config

def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
import flute.utils
from flute.tune import TuneMetaData, maybe_tune_and_repack
from flute.utils import make_workspace_streamk

from ..integrations import HiggsLinear

flute_workspaces = {}
for name, module in model.named_modules():
if isinstance(module, HiggsLinear):
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
if module.weight.device not in flute_workspaces:
flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk(
device=module.weight.device
)
module.workspace = flute_workspaces[module.weight.device]

# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
if module.num_sms_packed.item() != get_num_sms_from_device(module.weight.device):
new_device = module.weight.device
new_num_sms = get_num_sms_from_device(new_device)
module.weight.data = flute.utils.pack(
flute.utils.unpack(
weight=module.weight.data,
scales=module.scales.data,
workspace=module.workspace,
num_bits=module.num_bits,
group_size=module.group_size,
num_sms_packed=module.num_sms_packed.item(),
).T.contiguous(),
module.num_bits,
module.group_size,
)
module.num_sms_packed = torch.nn.Parameter(
torch.tensor(new_num_sms, device=new_device, dtype=torch.int32),
requires_grad=False,
)
flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
if module.weight.device not in flute_workspaces:
flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
module.workspace = flute_workspaces[module.weight.device]

# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
module.weight.data, module.tune_metadata = maybe_tune_and_repack(
weight=module.weight.data,
scales=module.scales.data,
metadata=module.tune_metadata,
)
self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
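The repacking above exists because FLUTE's packed weight layout depends on the GPU the kernel was tuned for, in particular its number of streaming multiprocessors (SMs); the removed get_num_sms_from_device helper hard-coded SM counts per compute capability, whereas the tune metadata now carries this information through maybe_tune_and_repack. As a side note (not part of the diff), PyTorch can report the SM count directly:

import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    # e.g. an A100 (compute capability 8.0) reports 108 multiprocessors,
    # matching the value the removed helper hard-coded for (8, 0).
    print(props.name, torch.cuda.get_device_capability(), props.multi_processor_count)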

def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import HiggsLinear

not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, HiggsLinear):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
and not missing.endswith(".weight")
and not missing.endswith(".bias")
):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}

def should_update(key: str) -> bool:
if key.endswith(".weight") or key.endswith(".bias"):
return False
full_key = f"{prefix}.{key}"
return any(name in key or name in full_key for name in higgs_names)

return [key for key in missing_keys if not should_update(key)]
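A worked illustration of the refactored filter, with hypothetical module and key names: keys that belong to a HiggsLinear module but are not .weight/.bias (the extra HIGGS tensors materialized at load time) are dropped from missing_keys, everything else is kept.

higgs_names = {"model.layers.0.mlp.up_proj"}   # hypothetical quantized module
prefix = "model"

def should_update(key: str) -> bool:
    if key.endswith(".weight") or key.endswith(".bias"):
        return False
    full_key = f"{prefix}.{key}"
    return any(name in key or name in full_key for name in higgs_names)

missing_keys = [
    "layers.0.mlp.up_proj.scales",       # dropped: HIGGS tensor created at load time
    "layers.0.mlp.up_proj.weight",       # kept: .weight/.bias are never filtered here
    "layers.0.input_layernorm.weight",   # kept: not a HiggsLinear module
]
print([k for k in missing_keys if not should_update(k)])
# ['layers.0.mlp.up_proj.weight', 'layers.0.input_layernorm.weight']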

@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
2 changes: 1 addition & 1 deletion src/transformers/utils/import_utils.py
@@ -639,7 +639,7 @@ def is_flax_available():

def is_flute_available():
try:
return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.3.0"
return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.4.1"
except importlib.metadata.PackageNotFoundError:
return False
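One caveat worth noting (unchanged by this diff): the check above compares version strings lexicographically, which happens to order these particular versions correctly but would misorder e.g. "0.10.0" against "0.4.1". A sketch of a stricter variant using packaging, in case that ever matters:

import importlib.metadata
import importlib.util

from packaging import version

def is_flute_available_strict(min_version: str = "0.4.1") -> bool:
    try:
        if importlib.util.find_spec("flute") is None:
            return False
        return version.parse(importlib.metadata.version("flute-kernel")) >= version.parse(min_version)
    except importlib.metadata.PackageNotFoundError:
        return False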

6 changes: 6 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -1404,6 +1404,8 @@ class HiggsConfig(QuantizationConfigMixin):
Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value. Decreasing this below 512 will reduce the quality of the quantization.
group_size (int, *optional*, defaults to 256):
Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance. Default is 256. Must be a divisor of hadamard_size.
tune_metadata ('dict', *optional*, defaults to {}):
Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results. Default is an empty dictionary. Is set automatically during tuning.
"""

def __init__(
@@ -1413,16 +1415,20 @@ def __init__(
modules_to_not_convert: Optional[List[str]] = None,
hadamard_size: int = 512,
group_size: int = 256,
tune_metadata: Optional[Dict[str, Any]] = None,
**kwargs,
):
if modules_to_not_convert is None:
modules_to_not_convert = ["lm_head"]
if tune_metadata is None:
tune_metadata = {}
self.quant_method = QuantizationMethod.HIGGS
self.bits = bits
self.p = p
self.modules_to_not_convert = modules_to_not_convert
self.hadamard_size = hadamard_size
self.group_size = group_size
self.tune_metadata = tune_metadata

self.post_init()
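Because tune_metadata lives on the config, the per-module tuning results are serialized together with the rest of the quantization settings. A small sketch of that round trip; the module key and metadata dict below are made-up placeholders:

from transformers import HiggsConfig

config = HiggsConfig(bits=4, group_size=256, hadamard_size=512)
print(config.tune_metadata)          # {} until the quantizer fills it during loading/tuning

# During create_quantized_param / _process_model_after_weight_loading the quantizer
# stores one TuneMetaData dict per quantized module, keyed by module name:
config.tune_metadata["model.layers.0.mlp.up_proj"] = {"illustrative": "metadata"}
print(config.to_dict()["tune_metadata"])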

6 changes: 3 additions & 3 deletions tests/quantization/higgs/test_higgs.py
@@ -65,12 +65,12 @@ def test_from_dict(self):
@require_accelerate
# @require_read_token
class HiggsTest(unittest.TestCase):
model_name = "meta-llama/Meta-Llama-3.1-8B"
model_name = "unsloth/Llama-3.2-1B"

input_text = "A quick brown fox jumps over the"
input_text = "Font test: A quick brown fox jumps over the"
max_new_tokens = 2

EXPECTED_OUTPUT = "A quick brown fox jumps over the lazy dog"
EXPECTED_OUTPUT = "Font test: A quick brown fox jumps over the lazy dog"

device_map = "cuda"
