Refactor the TEQ to align with torch 3.x new API (#1766)
Refactor TEQuantizer

Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Apr 30, 2024
1 parent f67e861 commit 099b7a4
Showing 4 changed files with 257 additions and 87 deletions.
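
At a high level, this commit removes the module-level teq_quantize() helper and splits TEQ into two pieces: TrainableEquivalentTransformation, which keeps the scale-training logic, and a thin TEQuantizer built on the new Quantizer base class with separate prepare/convert stages. A minimal migration sketch, based only on the signatures visible in this diff; float_model, weight_config, absorb_to_layer, example_inputs, and train_fn are placeholders:

# Before (removed below): one-shot helper function.
# q_model = teq_quantize(model, weight_config=weight_config,
#                        absorb_to_layer=absorb_to_layer, dataloader=dataloader)

# After: class-based API. quantize() on the Quantizer base is expected to run
# prepare -> run_fn -> convert, mirroring the new test added in this commit.
from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer

quantizer = TEQuantizer(
    quant_config=weight_config,       # {"op_name": {"bits": ..., "group_size": ..., "scheme": ...}}
    folding=True,
    absorb_to_layer=absorb_to_layer,  # {"absorbing_layer": ["absorbed_linear", ...]}
    example_inputs=example_inputs,
)
q_model = quantizer.quantize(float_model, run_fn=train_fn)  # train_fn tunes the TEQ alphas
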
110 changes: 64 additions & 46 deletions neural_compressor/torch/algorithms/weight_only/teq.py
@@ -16,19 +16,26 @@
# limitations under the License.
#

import copy
from typing import Any

import torch
import transformers

from neural_compressor.torch.algorithms.base_algorithm import Quantizer
from neural_compressor.torch.utils import get_device, logger

from .modules import MulLinear, TEQLinearFakeQuant
from .utility import get_module, quant_tensor, set_module

__all__ = ["teq_quantize", "TEQuantizer"]
__all__ = ["TrainableEquivalentTransformation", "TEQuantizer"]


class TrainableEquivalentTransformation:
"""Weight-only quantization, Trainable Equivalent Transformation (TEQ)."""

class TEQuantizer:
"""Weight-only quantization, Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input."""
_PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"]
_PREPARE_ATTRS_PREFIX = "_prepare_"

def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
"""
@@ -41,16 +48,20 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
self.folding = folding
self.example_inputs = example_inputs
self.device = self._get_device()
self.dtype = self._get_dtype()
self.model.eval()
self.trained_alphas = {}
self.absorb_to_layer = absorb_to_layer
self._post_initialized = False

def _post_init(self):
self.dtype = self._get_dtype()
self.model.to(self.device)
self.model.eval()
self._post_initialized = True

def _get_device(self):
"""Get the model device
:return:Model device."""
device = get_device()
self.model.to(device)
return device

def _get_dtype(self):
Expand All @@ -62,6 +73,8 @@ def add_tuning_scale(self, sqrt_w_init=False):
to the paper for more details
:param sqrt_w_init: use sqrt weight to init."""

if not self._post_initialized:
self._post_init()
# freeze model.
for n, p in self.model.named_parameters():
p.requires_grad = False
@@ -117,6 +130,9 @@ def add_tuning_scale(self, sqrt_w_init=False):
orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
)
set_module(self.model, n, wrapper_module)
# Attach the weight config captured at prepare stage to the model
self.model._weight_config = self.weight_config
self.model._trained_alphas = self.trained_alphas

@torch.no_grad()
def _absorb_scales(self, layer, scale, layer_name=""):
@@ -204,6 +220,8 @@ def _scale_layer_weight(self, layer, scale): ##input channel
@torch.no_grad()
def transform(self):
"""Apply alpha/scale."""
if not self._post_initialized:
self._post_init()
for ln_name, layer_names in self.absorb_to_layer.items():
module = get_module(self.model, ln_name)
scale = self.trained_alphas[ln_name]
@@ -309,43 +327,43 @@ def save(self, save_scale_file="", save_state_dict_file=""):
torch.save(self.model.state_dict(), save_state_dict_file)


def teq_quantize(
model, weight_config={}, absorb_to_layer={}, folding=True, dataloader=None, calib_func=None, example_inputs=None
):
"""Run TEQ weight-only quantization."""
assert isinstance(model, torch.nn.Module), "only support torch module"
logger.info("TEQ quantizing start.")
if example_inputs is None:
if dataloader is None: # pragma: no cover
assert False, "Please provide dataloader or example_inputs for TEQ algorithm."
try:
for idx, (input, label) in enumerate(dataloader):
example_inputs = input
break
except: # pragma: no cover
for idx, input in enumerate(dataloader):
example_inputs = input
break

teq_quantizer = TEQuantizer(model, weight_config, absorb_to_layer, folding, example_inputs)

# 1. wrapper tuning scale to model
teq_quantizer.add_tuning_scale()

# 2. tuning
# custom train function, there calls calib_func
if calib_func: # pragma: no cover
calib_func(teq_quantizer.model)
else:
if dataloader is None: # pragma: no cover
assert False, "Please provide dataloader to train."
teq_quantizer.train(dataloader)

# 3. apply scale to model
teq_quantizer.transform()

# 4. get quantized model
teq_quantizer.quantize()

logger.info("TEQ quantizing done.")
return teq_quantizer.model
class TEQuantizer(Quantizer):

def __init__(self, quant_config, folding, absorb_to_layer, example_inputs):
super().__init__(quant_config=quant_config)
self.folding = folding
self.absorb_to_layer = absorb_to_layer
self.example_inputs = example_inputs
self._quantizer = TrainableEquivalentTransformation(
model=None,
weight_config=quant_config,
absorb_to_layer=absorb_to_layer,
folding=folding,
example_inputs=example_inputs,
)

def prepare(self, model, *args, **kwargs):
"""Prepares a given model for quantization.
Args:
model: A float model to be quantized.
Returns:
A prepared model.
"""
float_model = model
assert isinstance(model, torch.nn.Module), "only support torch module"
self._quantizer.model = float_model
logger.info("TEQ quantizing start.")
self._quantizer.add_tuning_scale()
for attr in self._quantizer._PREPARE_ATTRS:
setattr(float_model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, getattr(self._quantizer, attr))
return float_model

def convert(self, model, *args: Any, **kwargs: Any):
for attr in self._quantizer._PREPARE_ATTRS:
setattr(self._quantizer, attr, getattr(model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, None))
self._quantizer.model = model
self._quantizer.transform()
self._quantizer.quantize()
logger.info("TEQ quantizing done.")
return self._quantizer.model
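
For reference, a minimal sketch of the two-stage flow this class now exposes when prepare and convert are driven separately (train_fn is a placeholder; the _prepare_* attributes copied onto the model in prepare() are how convert() recovers the weight config and trained alphas):

quantizer = TEQuantizer(quant_config=weight_config, folding=True,
                        absorb_to_layer=absorb_to_layer, example_inputs=example_inputs)

prepared = quantizer.prepare(float_model)  # wraps target linears with TEQLinearFakeQuant and
                                           # stashes _prepare_weight_config / _prepare_trained_alphas
train_fn(prepared)                         # user-supplied loop that tunes the per-channel alphas
q_model = quantizer.convert(prepared)      # reads the _prepare_* attrs, absorbs scales, quantizes weights
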
19 changes: 7 additions & 12 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -284,16 +284,17 @@ def awq_quantize_entry(
###################### TEQ Algo Entry ##################################
@register_algo(name=TEQ)
def teq_quantize_entry(
model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], TEQConfig], *args, **kwargs
model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], TEQConfig], mode: Mode, *args, **kwargs
) -> torch.nn.Module:
from neural_compressor.torch.algorithms.weight_only.teq import teq_quantize
from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer

logger.info("Quantize model with the TEQ algorithm.")
weight_config = {}
absorb_to_layer = {}
example_inputs = kwargs.get("example_inputs", None)
assert example_inputs is not None, "Please provide example_inputs for TEQ quantization."
calib_func = kwargs.get("run_fn", None)
run_fn = kwargs.get("run_fn", None)
inplace = kwargs.get("inplace", True)
folding = True
for (op_name, op_type), quant_config in configs_mapping.items():
if quant_config.dtype == "fp32":
@@ -318,16 +319,10 @@ def teq_quantize_entry(
absorb_to_layer = quant_config.absorb_to_layer
folding = quant_config.folding
assert isinstance(model, torch.nn.Module), "only support torch module"

model = teq_quantize(
model,
example_inputs=example_inputs,
folding=folding,
absorb_to_layer=absorb_to_layer,
calib_func=calib_func,
weight_config=weight_config,
quantizer = TEQuantizer(
quant_config=weight_config, folding=folding, absorb_to_layer=absorb_to_layer, example_inputs=example_inputs
)
logger.info("TEQ quantization done.")
model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
return model


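The entry no longer orchestrates the TEQ steps itself; it defers to Quantizer.execute, which dispatches on mode. The base class is not part of this diff, so the following is only an assumed, simplified sketch of that dispatch (argument handling such as inplace is omitted):

# Assumed behavior of Quantizer.execute (not shown in this commit; simplified).
def execute(self, model, mode, run_fn=None, *args, **kwargs):
    if mode == Mode.PREPARE:   # insert trainable scales / fake-quant wrappers only
        return self.prepare(model, *args, **kwargs)
    if mode == Mode.CONVERT:   # finalize a previously prepared model
        return self.convert(model, *args, **kwargs)
    # Mode.QUANTIZE: one-shot prepare -> calibrate/train -> convert
    model = self.prepare(model, *args, **kwargs)
    if run_fn is not None:
        run_fn(model)
    return self.convert(model, *args, **kwargs)
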
143 changes: 143 additions & 0 deletions test/3x/torch/algorithms/weight_only/test_teq_quantizer.py
@@ -0,0 +1,143 @@
import copy
import unittest

import torch
import transformers

from neural_compressor.common import logger
from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer
from neural_compressor.torch.quantization import quantize


def generate_random_corpus(nsamples=32):
meta_data = []
for _ in range(nsamples):
inp = torch.ones([1, 512], dtype=torch.long)
tar = torch.ones([1, 512], dtype=torch.long)
meta_data.append((inp, tar))
return meta_data


def train(
model,
train_steps=100,
lr=1e-3,
warmup_ratio=0.05,
gradient_accumulation_steps=1,
logging_steps=10,
betas=[0.9, 0.9],
weight_decay=0,
lr_scheduler_type="linear",
):
"""Train function."""
trained_alphas_list = [torch.ones([128], requires_grad=True)]
optimizer = torch.optim.Adam(trained_alphas_list, lr=lr, weight_decay=weight_decay, betas=betas)

lr_scheduler = transformers.get_scheduler( # pylint: disable=E1111
name=lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=int(train_steps * warmup_ratio) // gradient_accumulation_steps,
num_training_steps=train_steps // gradient_accumulation_steps,
)

logger.info("start training")
model.train()
global_steps = 0
dataloader = generate_random_corpus()
while global_steps <= train_steps:
for inputs in dataloader:
if isinstance(inputs, torch.Tensor):
input_id = inputs
elif isinstance(inputs, dict):
input_id = inputs["input_ids"]
else:
input_id = inputs[0]
output = model(input_id, labels=input_id)
loss = output[0] / gradient_accumulation_steps
loss.backward()
global_steps += 1

if global_steps % logging_steps == 0:
logger.info("steps: {}, loss: {}".format(global_steps, loss.detach().cpu().item()))

if global_steps % gradient_accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()

if global_steps >= train_steps: # pragma: no cover
break

logger.info("finish training")
model.eval()
return None


class TestTEQWeightOnlyQuant(unittest.TestCase):
@classmethod
def setUpClass(self):
self.gptj = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
torchscript=True,
)
self.gptj.seqlen = 512

def train_func(self):
pass

def test_teq(self):
example_inputs = torch.ones([1, 512], dtype=torch.long)
test_input = torch.ones([1, 512], dtype=torch.long)
model = copy.deepcopy(self.gptj)
out0 = model(test_input)

weight_config = {
# 'op_name': (bit, group_size, scheme)
"transformer.h.0.mlp.fc_in": {"bits": 8, "group_size": -1, "scheme": "sym"},
"transformer.h.0.mlp.fc_out": {"bits": 4, "group_size": 32, "scheme": "asym"},
}
absorb_dict = {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]}

quantizer = TEQuantizer(
quant_config=weight_config, folding=True, absorb_to_layer=absorb_dict, example_inputs=example_inputs
)
model = quantizer.quantize(copy.deepcopy(self.gptj), run_fn=train)
out1 = model(test_input)
self.assertTrue(torch.allclose(out1[0], out0[0], atol=0.03))

quant_config = {
"teq": {
"global": {
"dtype": "fp32",
},
"local": {
"transformer.h.0.mlp.fc_in": {
"dtype": "int",
"bits": 8,
"group_size": -1,
"use_sym": True,
"folding": True,
"absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
},
"transformer.h.0.mlp.fc_out": {
"dtype": "int",
"bits": 4,
"group_size": 32,
"use_sym": False,
"folding": True,
"absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
},
},
}
}
qdq_model = quantize(
model=copy.deepcopy(self.gptj), quant_config=quant_config, run_fn=train, example_inputs=example_inputs
)
self.assertTrue(isinstance(qdq_model, torch.nn.Module))
out2 = qdq_model(test_input)
self.assertTrue(torch.allclose(out1[0], out2[0]))
self.assertTrue(torch.allclose(out2[0], out0[0], atol=0.03))


if __name__ == "__main__":
unittest.main()