Add export support for TEQ (#1910)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Jul 11, 2024
1 parent 16a7b11 commit 4a45093
Showing 1 changed file with 63 additions and 11 deletions.
74 changes: 63 additions & 11 deletions neural_compressor/torch/algorithms/weight_only/teq.py
@@ -21,7 +21,7 @@
 import torch

 from neural_compressor.torch.algorithms.base_algorithm import Quantizer
-from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger
+from neural_compressor.torch.utils import get_accelerator, get_model_device, is_transformers_imported, logger

 from .modules import MulLinear, TEQLinearFakeQuant
 from .utility import get_module, quant_tensor, set_module
@@ -265,18 +265,70 @@ def transform(self):
                 set_module(self.model, n, m.orig_layer)

     @torch.no_grad()
-    def quantize(self):
+    def quantize(self, **kwargs):
         """quantization."""

-        for n, m in self.model.named_modules():
-            if self.weight_config.get(n) is None:  # pragma: no cover
-                logger.info(f"quantize layer {n} not in weight config, skip.")
+        use_optimum_format = kwargs.get("use_optimum_format", True)
+        device = get_accelerator().current_device_name()
+        model_device = get_model_device(self.model)  # remember the original device so layers can be moved back
+        model = self.model
+        for name, m in model.named_modules():
+            if self.weight_config.get(name) is None:  # pragma: no cover
+                logger.info(f"quantize layer {name} not in weight config, skip.")
                 continue
-            num_bits = self.weight_config[n]["bits"]
-            group_size = self.weight_config[n]["group_size"]
-            scheme = self.weight_config[n]["scheme"]
+            num_bits = self.weight_config[name]["bits"]
+            group_size = self.weight_config[name]["group_size"]
+            scheme = self.weight_config[name]["scheme"]
+            group_dim = self.weight_config[name].get("group_dim", 1)
+            # Transpose the weight when exactly one of the two holds:
+            # group_dim is 0, or the module is a `transformers.Conv1D`.
+            if is_transformers_imported():
+                transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
+            else:  # pragma: no cover
+                transpose = group_dim == 0
+            if transpose:  # pragma: no cover
+                weight = m.weight.detach().T.contiguous()
+            else:
+                weight = m.weight.detach()
-            if isinstance(m, torch.nn.Linear):  # pragma: no cover
-                quant_tensor(m.weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme)
+            int_weight, scale, zp = quant_tensor(
+                weight.data,
+                num_bits=num_bits,
+                group_size=group_size,
+                scheme=scheme,
+                return_int=True,
+            )
+            int_weight = int_weight.t_().contiguous() if transpose else int_weight
+            scale = scale.t_().contiguous() if transpose else scale
+            zp = zp.t_().contiguous() if transpose and zp is not None else zp
+            if isinstance(m, torch.nn.Linear):
+                in_features = m.in_features
+                out_features = m.out_features
+            elif is_transformers_imported() and isinstance(m, transformers.Conv1D):
+                in_features = m.weight.shape[0]
+                out_features = m.weight.shape[1]
+                int_weight = int_weight.t_().contiguous()
+                scale = scale.t_().contiguous()
+                zp = zp.t_().contiguous() if zp is not None else zp
+            from .modules import WeightOnlyLinear
+
+            new_module = WeightOnlyLinear(
+                in_features,
+                out_features,
+                bits=num_bits,
+                group_size=group_size,
+                zp=zp is not None,
+                bias=m.bias is not None,
+                use_optimum_format=use_optimum_format,
+                device=device,
+            )
+            new_module.pack(int_weight, scale, zp, m.bias)
+            if name == "":
+                return new_module
+            else:
+                set_module(model, name, new_module)
+            # Move modules back to the model device layer by layer.
+            m.to(model_device)
+            new_module.to(model_device)
+        self.model = model

     def save(self, save_scale_file="", save_state_dict_file=""):
         """
@@ -328,6 +380,6 @@ def convert(self, model, *args: Any, **kwargs: Any):
             setattr(self._quantizer, attr, getattr(model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, None))
         self._quantizer.model = model
         self._quantizer.transform()
-        self._quantizer.quantize()
+        self._quantizer.quantize(**kwargs)
         logger.info("TEQ quantizing done.")
         return self._quantizer.model
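For reference, the quant_tensor(..., return_int=True) call in the new quantize() path is the step that turns a float weight into an integer tensor plus per-group scale and zero-point, which WeightOnlyLinear.pack then packs into the exported layer. Below is a minimal, self-contained sketch of that idea (asymmetric per-group quantization along the input dimension); it illustrates the technique only and is not Neural Compressor's actual quant_tensor implementation, and the helper name fake_group_quant is made up for this example.

import torch


def fake_group_quant(weight: torch.Tensor, num_bits: int = 4, group_size: int = 32):
    """Illustrative asymmetric per-group quantization of a 2-D weight.

    Returns (int_weight, scale, zp) with one scale/zero-point per
    (output row, column group), mirroring the tensors packed above.
    """
    out_features, in_features = weight.shape
    assert in_features % group_size == 0, "pad the weight or pick a divisible group_size"
    grouped = weight.reshape(out_features, in_features // group_size, group_size)

    w_min = grouped.amin(dim=-1, keepdim=True)
    w_max = grouped.amax(dim=-1, keepdim=True)
    max_int = 2**num_bits - 1

    scale = (w_max - w_min).clamp(min=1e-5) / max_int  # one scale per group
    zp = torch.round(-w_min / scale)  # one zero-point per group
    int_weight = torch.clamp(torch.round(grouped / scale) + zp, 0, max_int)

    return (
        int_weight.reshape(out_features, in_features).to(torch.int32),
        scale.squeeze(-1),
        zp.squeeze(-1).to(torch.int32),
    )


# Example: an 8x64 weight quantized to 4 bits with groups of 32 columns.
w = torch.randn(8, 64)
int_w, s, z = fake_group_quant(w)
print(int_w.shape, s.shape, z.shape)  # torch.Size([8, 64]) torch.Size([8, 2]) torch.Size([8, 2])

Because scale and zp are laid out per output row and per column group, the commit has to transpose them whenever the weight itself was transposed (the group_dim == 0 and transformers.Conv1D branches above).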

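With this change in place, a TEQ-quantized model comes back with its target Linear layers replaced by packed WeightOnlyLinear modules (Optimum-compatible layout when use_optimum_format is left at its default of True) instead of only fake-quantized weights. A hedged sketch of driving this through the 3.x PyTorch entry points follows; TEQConfig, prepare and convert live in neural_compressor.torch.quantization, but the config fields, the example model, and the calibration step shown here are illustrative assumptions rather than the project's documented recipe.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.quantization import TEQConfig, convert, prepare

# Any small causal LM works for a smoke test; the model choice is an assumption.
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
example_inputs = tokenizer("Hello, TEQ export.", return_tensors="pt")["input_ids"]

# Assumed field names; check TEQConfig for the exact knobs (bits, group_size, ...).
quant_config = TEQConfig(bits=4, group_size=128)

prepared = prepare(model, quant_config, example_inputs=example_inputs)
# ... feed calibration data through `prepared` here so TEQ can learn its scales ...
q_model = convert(prepared)

# The quantized model should now contain packed WeightOnlyLinear layers
# produced by the quantize() path added in this commit.
print({type(m).__name__ for m in q_model.modules()})

The Optimum-style layout is what makes the packed result consumable by tooling that already understands GPTQ-like checkpoints, which appears to be the point of the export support.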