[CORE] add marlin inference kernel #310

Merged · 8 commits · Jul 30, 2024
Changes from all commits
1 change: 0 additions & 1 deletion examples/quantization/basic_usage.py
@@ -1,5 +1,4 @@
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.quantization import FORMAT
from transformers import AutoTokenizer

pretrained_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
3 changes: 1 addition & 2 deletions examples/quantization/basic_usage_wikitext2.py
@@ -1,8 +1,7 @@
import torch
import torch.nn as nn
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig
from transformers import AutoTokenizer
from datasets import load_dataset

pretrained_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quantized_model_id = "TinyLlama-1.1B-Chat-v1.0-4bit-128g"
13 changes: 6 additions & 7 deletions gptqmodel/models/base.py
@@ -3,7 +3,6 @@
import logging
import os
import re
import shutil
from os.path import basename, isfile, join
from typing import Dict, List, Optional, Union

@@ -18,7 +17,7 @@
from transformers.modeling_utils import no_init_weights, shard_checkpoint
from transformers.utils.generic import ContextManagers

from ..nn_modules.qlinear.qlinear_qbits import qbits_dtype, QBitsQuantLinear
from ..nn_modules.qlinear.qlinear_qbits import QBitsQuantLinear, qbits_dtype
from ..quantization import GPTQ, QuantizeConfig
from ..quantization.config import (FORMAT, FORMAT_FIELD_JSON, META_FIELD_QUANTIZER, META_QUANTIZER_GPTQMODEL,
MIN_VERSION_WITH_V2, QUANTIZE_BLACK_LIST, AutoRoundQuantizeConfig)
@@ -29,10 +28,10 @@
from ..utils.marlin import (_validate_marlin_compatibility,
_validate_marlin_device_support, prepare_model_for_marlin_load)
from ..utils.model import (auto_dtype_from_config, check_to_quantized, convert_gptq_v1_to_v2_format,
convert_gptq_v2_to_v1_format, find_layers, get_checkpoints, get_device,
convert_gptq_v2_to_v1_format, copy_py_files, find_layers, get_checkpoints, get_device,
get_module_by_name_prefix, get_module_by_name_suffix, get_moe_layer_modules,
gptqmodel_post_init, make_quant, move_to, nested_move_to, pack_model,
simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes, copy_py_files)
simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
from ..version import __version__
from ._const import CPU, CUDA_0, DEVICE, SUPPORTED_MODELS

@@ -1177,8 +1176,8 @@ def skip(*args, **kwargs):
layers,
quantize_config.bits,
quantize_config.group_size,
backend=backend.AUTO if backend == BACKEND.MARLIN or backend == BACKEND.BITBLAS else backend,
format=FORMAT.GPTQ_V2,
backend=backend.AUTO if (backend == BACKEND.MARLIN and quantize_config.format == FORMAT.MARLIN) or backend == BACKEND.BITBLAS else backend,
format=quantize_config.format,
desc_act=quantize_config.desc_act,
)
if preload_qlinear_kernel == QBitsQuantLinear:
@@ -1247,7 +1246,7 @@ def skip(*args, **kwargs):
load_checkpoint_in_model = True
quantize_config.runtime_format = FORMAT.GPTQ_V2

if backend == BACKEND.MARLIN:
if backend == BACKEND.MARLIN and quantize_config.format == FORMAT.MARLIN:
if is_sharded:
raise ValueError(
"The loading of sharded checkpoints with Marlin is currently not supported."
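
The backend-selection change above rewrites a requested MARLIN backend to AUTO only when the checkpoint is already serialized as FORMAT.MARLIN; GPTQ-format checkpoints now keep BACKEND.MARLIN so the new inference kernel can repack them at load time. A minimal sketch of that condition, using stand-in enums rather than gptqmodel's real BACKEND/FORMAT classes:

# Sketch only: Backend/Format are stand-ins for gptqmodel's BACKEND and FORMAT enums.
from enum import Enum, auto

class Backend(Enum):
    AUTO = auto()
    MARLIN = auto()
    BITBLAS = auto()

class Format(Enum):
    GPTQ = auto()
    GPTQ_V2 = auto()
    MARLIN = auto()

def select_backend(requested: Backend, checkpoint_format: Format) -> Backend:
    # Mirrors the condition added in base.py: only legacy MARLIN-format checkpoints
    # (and BitBLAS) fall back to AUTO; GPTQ-format checkpoints keep the MARLIN backend.
    if (requested is Backend.MARLIN and checkpoint_format is Format.MARLIN) or requested is Backend.BITBLAS:
        return Backend.AUTO
    return requested

assert select_backend(Backend.MARLIN, Format.GPTQ) is Backend.MARLIN   # new path: Marlin inference kernel
assert select_backend(Backend.MARLIN, Format.MARLIN) is Backend.AUTO   # legacy Marlin-format checkpoints
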
312 changes: 312 additions & 0 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin_inference.py
@@ -0,0 +1,312 @@
# License: GPTQModel/licenses/LICENSE.apache
# Adapted from vllm at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/gptq_marlin.py

from typing import Any, Dict, List, Optional, Tuple

import gptqmodel_marlin_cuda_inference
import torch
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from torch.nn.parameter import Parameter

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

def set_weight_attrs(
weight: torch.Tensor,
weight_attrs: Optional[Dict[str, Any]],
):
"""Set attributes on a weight tensor.

    Existing attributes are never overwritten (an assertion guards against it).

Args:
weight: The weight tensor.
weight_attrs: A dictionary of attributes to set on the weight tensor.
"""
if weight_attrs is None:
return
for key, value in weight_attrs.items():
assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value)

def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (not act_order) or (act_order and not is_row_parallel)

def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
is_row_parallel: bool) -> bool:
# Need to repeat scales on every rank if act_ordering or
# channelwise and RowParallelLinear
is_channelwise = group_size == -1
return act_order or (is_channelwise and is_row_parallel)

def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //
GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL

return torch.zeros(max_workspace_size,
dtype=torch.int,
device=device,
requires_grad=False)

def marlin_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices

def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)

# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(layer: torch.nn.Module, name: str,
new_t: torch.Tensor) -> None:
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t

def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int) -> torch.Tensor:

scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()

return s

def get_scale_perms():
scale_perm: List[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single

def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )

output = gptqmodel_marlin_cuda_inference.gptq_marlin_gemm(reshaped_x,
weight,
weight_scale,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
num_bits,
reshaped_x.shape[0],
output_size_per_partition,
input_size_per_partition,
is_k_full,
False)

if bias is not None:
output.add_(bias) # In-place add

return output.reshape(out_shape)

class MarlinInferenceQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4, 8]
SUPPORTED_GROUP_SIZE = [-1, 32, 64, 128]
SUPPORTED_DESC_ACT = [True, False]
SUPPORTED_SYM = [True]

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
bias: bool, **kwargs):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)

self.pack_factor = 32 // bits # packed into int32

if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False

# Normalize group_size
if group_size != -1:
group_size = group_size
else:
group_size = infeatures

self.bits = bits
self.group_size = group_size
self.desc_act = desc_act

# Determine sharding
if marlin_repeat_scales_on_all_ranks(desc_act,
group_size,
is_row_parallel=False):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each GPU in TP>1 case.
scales_and_zp_input_dim = None
scales_and_zp_size = infeatures // group_size
else:
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
scales_and_zp_input_dim = 0
scales_and_zp_size = infeatures // group_size

# Quantized weights
qweight = Parameter(
torch.empty(
infeatures // self.pack_factor,
outfeatures,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.pack_factor,
},
)

# Activation order
g_idx = Parameter(
torch.empty(
infeatures,
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(
g_idx,
{
"input_dim": 0,
"ignore_warning": True
},
)

# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
outfeatures,
dtype=torch.float16,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)

# Quantized zero-points
qzeros = Parameter(
torch.empty(
scales_and_zp_size,
outfeatures // self.pack_factor,
dtype=torch.int32,
# device="meta",
),
requires_grad=False,
)
set_weight_attrs(
qzeros,
{
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.pack_factor,
},
)

self.register_parameter("qweight", qweight)
self.register_parameter("g_idx", g_idx)
self.register_parameter("scales", scales)
self.register_parameter("qzeros", qzeros)
self.infeatures = infeatures
self.outfeatures = outfeatures
self.is_k_full = marlin_is_k_full(desc_act, is_row_parallel=False)

if bias:
self.register_buffer("bias", torch.zeros((self.outfeatures), dtype=torch.half))
else:
self.bias = None

def post_init(self):
device = self.qweight.device
self.validate_device(device.type)

# Allocate marlin workspace
self.workspace = marlin_make_workspace(
self.outfeatures, device)

# Handle sorting for activation reordering if needed.
if self.desc_act:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(self.g_idx)
self.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(self, "g_idx", g_idx)
else:
self.g_idx = marlin_make_empty_g_idx(device)
self.g_idx_sort_indices = marlin_make_empty_g_idx(device)

# No zero-point
self.zp = marlin_make_empty_g_idx(device)

# Repack weights from autogptq format to marlin format.
marlin_qweight = gptqmodel_marlin_cuda_inference.gptq_marlin_repack(
self.qweight,
self.g_idx_sort_indices,
self.infeatures,
self.outfeatures,
self.bits)
replace_tensor(self, "qweight", marlin_qweight)

# Permute scales from autogptq format to marlin format.
marlin_scales = marlin_permute_scales(
self.scales,
size_k=self.infeatures,
size_n=self.outfeatures,
group_size=self.group_size)
replace_tensor(self, "scales", marlin_scales)

def forward(self, A: torch.Tensor):
if A.dtype != torch.float16:
A = A.half()

return apply_gptq_marlin_linear(
input=A,
weight=self.qweight,
weight_scale=self.scales,
weight_zp=self.zp,
g_idx=self.g_idx,
g_idx_sort_indices=self.g_idx_sort_indices,
workspace=self.workspace,
num_bits=self.bits,
output_size_per_partition=self.outfeatures,
input_size_per_partition=self.infeatures,
is_k_full=self.is_k_full,
bias=self.bias)
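
For orientation, a rough sketch of the lifecycle this layer expects: construct it with matching shapes, populate qweight/scales/g_idx/qzeros from a GPTQ checkpoint, call post_init() once the tensors sit on a CUDA device (this repacks the weights into Marlin layout and allocates the workspace), then call forward(). The class, constructor arguments, and method names come from the diff above; the shapes and random placeholder tensors are illustrative only, and running this requires the compiled gptqmodel_marlin_cuda_inference extension plus a supported GPU:

import torch
from gptqmodel.nn_modules.qlinear.qlinear_marlin_inference import MarlinInferenceQuantLinear

# Illustrative shapes; a real layer gets these from the quantized checkpoint.
infeatures, outfeatures, bits, group_size = 4096, 4096, 4, 128
layer = MarlinInferenceQuantLinear(bits=bits, group_size=group_size, desc_act=False, sym=True,
                                   infeatures=infeatures, outfeatures=outfeatures, bias=False).cuda()

# Placeholder data standing in for real GPTQ-packed weights/scales/zeros.
layer.qweight.data = torch.randint(0, 2**31 - 1, tuple(layer.qweight.shape), dtype=torch.int32, device="cuda")
layer.scales.data = torch.rand(tuple(layer.scales.shape), dtype=torch.float16, device="cuda")
layer.qzeros.data = torch.full(tuple(layer.qzeros.shape), 0x77777777, dtype=torch.int32, device="cuda")
layer.g_idx.data = torch.arange(infeatures, dtype=torch.int32, device="cuda") // group_size  # only used when desc_act=True

layer.post_init()  # repack to Marlin layout and allocate the workspace
x = torch.randn(8, infeatures, dtype=torch.float16, device="cuda")
y = layer.forward(x)  # shape (8, outfeatures), float16; values are meaningless with placeholder weights
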
2 changes: 1 addition & 1 deletion gptqmodel/quantization/config.py
@@ -1,11 +1,11 @@
import json
import logging
from dataclasses import dataclass, field, fields
from importlib.metadata import version as pkg_version
from os.path import isdir, join
from typing import Any, Dict, Optional, Tuple

from packaging import version
from importlib.metadata import version as pkg_version
from transformers.utils.hub import cached_file

logger = logging.getLogger(__name__)