Refactor the TEQ to align with the new torch 3.x API #1766

Merged · 8 commits · Apr 30, 2024

Changes from all commits
110 changes: 64 additions & 46 deletions neural_compressor/torch/algorithms/weight_only/teq.py
@@ -16,19 +16,26 @@
# limitations under the License.
#

import copy
from typing import Any

import torch
import transformers

from neural_compressor.torch.algorithms.base_algorithm import Quantizer
from neural_compressor.torch.utils import get_device, logger

from .modules import MulLinear, TEQLinearFakeQuant
from .utility import get_module, quant_tensor, set_module

__all__ = ["teq_quantize", "TEQuantizer"]
__all__ = ["TrainableEquivalentTransformation", "TEQuantizer"]


class TrainableEquivalentTransformation:
"""Weight-only quantization, Trainable Equivalent Transformation (TEQ)."""

class TEQuantizer:
"""Weight-only quantization, Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input."""
_PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"]
_PREPARE_ATTRS_PREFIX = "_prepare_"

def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
"""
@@ -41,16 +48,20 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
self.folding = folding
self.example_inputs = example_inputs
self.device = self._get_device()
self.dtype = self._get_dtype()
self.model.eval()
self.trained_alphas = {}
self.absorb_to_layer = absorb_to_layer
self._post_initialized = False

def _post_init(self):
self.dtype = self._get_dtype()
self.model.to(self.device)
self.model.eval()
self._post_initialized = True

def _get_device(self):
"""Get the model device
:return:Model device."""
device = get_device()
self.model.to(device)
return device

def _get_dtype(self):
@@ -62,6 +73,8 @@ def add_tuning_scale(self, sqrt_w_init=False):
to the paper for more details
:param sqrt_w_init: use sqrt weight to init."""

if not self._post_initialized:
self._post_init()
# freeze model.
for n, p in self.model.named_parameters():
p.requires_grad = False
@@ -117,6 +130,9 @@ def add_tuning_scale(self, sqrt_w_init=False):
orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
)
set_module(self.model, n, wrapper_module)
# Attach the weight config captured at prepare stage to the model
self.model._weight_config = self.weight_config
self.model._trained_alphas = self.trained_alphas

@torch.no_grad()
def _absorb_scales(self, layer, scale, layer_name=""):
@@ -204,6 +220,8 @@ def _scale_layer_weight(self, layer, scale): ##input channel
@torch.no_grad()
def transform(self):
"""Apply alpha/scale."""
if not self._post_initialized:
self._post_init()
for ln_name, layer_names in self.absorb_to_layer.items():
module = get_module(self.model, ln_name)
scale = self.trained_alphas[ln_name]
@@ -309,43 +327,43 @@ def save(self, save_scale_file="", save_state_dict_file=""):
torch.save(self.model.state_dict(), save_state_dict_file)


def teq_quantize(
model, weight_config={}, absorb_to_layer={}, folding=True, dataloader=None, calib_func=None, example_inputs=None
):
"""Run TEQ weight-only quantization."""
assert isinstance(model, torch.nn.Module), "only support torch module"
logger.info("TEQ quantizing start.")
if example_inputs is None:
if dataloader is None: # pragma: no cover
assert False, "Please provide dataloader or example_inputs for TEQ algorithm."
try:
for idx, (input, label) in enumerate(dataloader):
example_inputs = input
break
except: # pragma: no cover
for idx, input in enumerate(dataloader):
example_inputs = input
break

teq_quantizer = TEQuantizer(model, weight_config, absorb_to_layer, folding, example_inputs)

# 1. wrapper tuning scale to model
teq_quantizer.add_tuning_scale()

# 2. tuning
# custom train function, there calls calib_func
if calib_func: # pragma: no cover
calib_func(teq_quantizer.model)
else:
if dataloader is None: # pragma: no cover
assert False, "Please provide dataloader to train."
teq_quantizer.train(dataloader)

# 3. apply scale to model
teq_quantizer.transform()

# 4. get quantized model
teq_quantizer.quantize()

logger.info("TEQ quantizing done.")
return teq_quantizer.model
class TEQuantizer(Quantizer):

def __init__(self, quant_config, folding, absorb_to_layer, example_inputs):
super().__init__(quant_config=quant_config)
self.folding = folding
self.absorb_to_layer = absorb_to_layer
self.example_inputs = example_inputs
self._quantizer = TrainableEquivalentTransformation(
model=None,
weight_config=quant_config,
absorb_to_layer=absorb_to_layer,
folding=folding,
example_inputs=example_inputs,
)

def prepare(self, model, *args, **kwargs):
"""Prepares a given model for quantization.

Args:
model: A float model to be quantized.
Returns:
A prepared model.
"""
float_model = model
assert isinstance(model, torch.nn.Module), "only support torch module"
self._quantizer.model = float_model
logger.info("TEQ quantizing start.")
self._quantizer.add_tuning_scale()
for attr in self._quantizer._PREPARE_ATTRS:
setattr(float_model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, getattr(self._quantizer, attr))
return float_model

def convert(self, model, *args: Any, **kwargs: Any):
for attr in self._quantizer._PREPARE_ATTRS:
setattr(self._quantizer, attr, getattr(model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, None))
self._quantizer.model = model
self._quantizer.transform()
self._quantizer.quantize()
logger.info("TEQ quantizing done.")
return self._quantizer.model
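
The refactored class above splits TEQ into a prepare step (wrap the target layers and attach trainable scales) and a convert step (absorb the learned scales and quantize the weights). Below is a minimal sketch of driving TEQuantizer directly, reusing the tiny model and config from the unit test added by this PR; the no-op train_fn is only a stand-in for a real tuning loop such as the test's train() helper.

import torch
import transformers

from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer

# Tiny causal LM from the unit test below; any torch.nn.Module with matching layer names works.
model = transformers.AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
example_inputs = torch.ones([1, 512], dtype=torch.long)

# Per-op weight config and absorb mapping, mirroring the unit test.
weight_config = {
    "transformer.h.0.mlp.fc_in": {"bits": 8, "group_size": -1, "scheme": "sym"},
    "transformer.h.0.mlp.fc_out": {"bits": 4, "group_size": 32, "scheme": "asym"},
}
absorb_dict = {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]}

quantizer = TEQuantizer(
    quant_config=weight_config, folding=True, absorb_to_layer=absorb_dict, example_inputs=example_inputs
)

# prepare() wraps the target layers with TEQLinearFakeQuant and stores the
# weight config and trainable alphas on the model under the "_prepare_" prefix.
prepared_model = quantizer.prepare(model)


def train_fn(m):
    # Placeholder tuning loop for the attached alphas (see train() in the test file).
    pass


train_fn(prepared_model)

# convert() reads the prepared attributes back, folds the learned scales into
# the weights (transform) and replaces the weights with quantized values (quantize).
quantized_model = quantizer.convert(prepared_model)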
19 changes: 7 additions & 12 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -284,16 +284,17 @@ def awq_quantize_entry(
###################### TEQ Algo Entry ##################################
@register_algo(name=TEQ)
def teq_quantize_entry(
model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], TEQConfig], *args, **kwargs
model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], TEQConfig], mode: Mode, *args, **kwargs
) -> torch.nn.Module:
from neural_compressor.torch.algorithms.weight_only.teq import teq_quantize
from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer

logger.info("Quantize model with the TEQ algorithm.")
weight_config = {}
absorb_to_layer = {}
example_inputs = kwargs.get("example_inputs", None)
assert example_inputs is not None, "Please provide example_inputs for TEQ quantization."
calib_func = kwargs.get("run_fn", None)
run_fn = kwargs.get("run_fn", None)
inplace = kwargs.get("inplace", True)
folding = True
for (op_name, op_type), quant_config in configs_mapping.items():
if quant_config.dtype == "fp32":
@@ -318,16 +319,10 @@ def teq_quantize_entry(
absorb_to_layer = quant_config.absorb_to_layer
folding = quant_config.folding
assert isinstance(model, torch.nn.Module), "only support torch module"

model = teq_quantize(
model,
example_inputs=example_inputs,
folding=folding,
absorb_to_layer=absorb_to_layer,
calib_func=calib_func,
weight_config=weight_config,
quantizer = TEQuantizer(
quant_config=weight_config, folding=folding, absorb_to_layer=absorb_to_layer, example_inputs=example_inputs
)
logger.info("TEQ quantization done.")
model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
return model
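
With this change the entry no longer calls a standalone teq_quantize() helper; it builds a TEQuantizer and lets Quantizer.execute route to prepare/convert according to the mode passed in by the 3.x quantization front end. The following is a minimal sketch of reaching this entry through the public quantize() API, reusing the dict-style config exercised in the unit test added by this PR; the no-op run_fn again stands in for a real tuning loop.

import torch
import transformers

from neural_compressor.torch.quantization import quantize

model = transformers.AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
example_inputs = torch.ones([1, 512], dtype=torch.long)

# Layer-wise TEQ settings; "local" entries override the fp32 "global" default.
quant_config = {
    "teq": {
        "global": {"dtype": "fp32"},
        "local": {
            "transformer.h.0.mlp.fc_in": {
                "dtype": "int",
                "bits": 8,
                "group_size": -1,
                "use_sym": True,
                "folding": True,
                "absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
            },
        },
    }
}


def run_fn(m):
    # Stand-in for a user-supplied tuning loop (see train() in the test file).
    pass


# quantize() resolves quant_config into configs_mapping and dispatches to
# teq_quantize_entry, which drives TEQuantizer.execute with the current mode.
q_model = quantize(model=model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)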


143 changes: 143 additions & 0 deletions test/3x/torch/algorithms/weight_only/test_teq_quantizer.py
@@ -0,0 +1,143 @@
import copy
import unittest

import torch
import transformers

from neural_compressor.common import logger
from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer
from neural_compressor.torch.quantization import quantize


def generate_random_corpus(nsamples=32):
meta_data = []
for _ in range(nsamples):
inp = torch.ones([1, 512], dtype=torch.long)
tar = torch.ones([1, 512], dtype=torch.long)
meta_data.append((inp, tar))
return meta_data


def train(
model,
train_steps=100,
lr=1e-3,
warmup_ratio=0.05,
gradient_accumulation_steps=1,
logging_steps=10,
betas=[0.9, 0.9],
weight_decay=0,
lr_scheduler_type="linear",
):
"""Train function."""
trained_alphas_list = [torch.ones([128], requires_grad=True)]
optimizer = torch.optim.Adam(trained_alphas_list, lr=lr, weight_decay=weight_decay, betas=betas)

lr_scheduler = transformers.get_scheduler( # pylint: disable=E1111
name=lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=int(train_steps * warmup_ratio) // gradient_accumulation_steps,
num_training_steps=train_steps // gradient_accumulation_steps,
)

logger.info("start training")
model.train()
global_steps = 0
dataloader = generate_random_corpus()
while global_steps <= train_steps:
for inputs in dataloader:
if isinstance(inputs, torch.Tensor):
input_id = inputs
elif isinstance(inputs, dict):
input_id = inputs["input_ids"]
else:
input_id = inputs[0]
output = model(input_id, labels=input_id)
loss = output[0] / gradient_accumulation_steps
loss.backward()
global_steps += 1

if global_steps % logging_steps == 0:
logger.info("steps: {}, loss: {}".format(global_steps, loss.detach().cpu().item()))

if global_steps % gradient_accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()

if global_steps >= train_steps: # pragma: no cover
break

logger.info("finish training")
model.eval()
return None


class TestTEQWeightOnlyQuant(unittest.TestCase):
@classmethod
def setUpClass(self):
self.gptj = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
torchscript=True,
)
self.gptj.seqlen = 512

def train_func(self):
pass

def test_teq(self):
example_inputs = torch.ones([1, 512], dtype=torch.long)
test_input = torch.ones([1, 512], dtype=torch.long)
model = copy.deepcopy(self.gptj)
out0 = model(test_input)

weight_config = {
# 'op_name': {'bits': ..., 'group_size': ..., 'scheme': ...}
"transformer.h.0.mlp.fc_in": {"bits": 8, "group_size": -1, "scheme": "sym"},
"transformer.h.0.mlp.fc_out": {"bits": 4, "group_size": 32, "scheme": "asym"},
}
absorb_dict = {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]}

quantizer = TEQuantizer(
quant_config=weight_config, folding=True, absorb_to_layer=absorb_dict, example_inputs=example_inputs
)
model = quantizer.quantize(copy.deepcopy(self.gptj), run_fn=train)
out1 = model(test_input)
self.assertTrue(torch.allclose(out1[0], out0[0], atol=0.03))

quant_config = {
"teq": {
"global": {
"dtype": "fp32",
},
"local": {
"transformer.h.0.mlp.fc_in": {
"dtype": "int",
"bits": 8,
"group_size": -1,
"use_sym": True,
"folding": True,
"absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
},
"transformer.h.0.mlp.fc_out": {
"dtype": "int",
"bits": 4,
"group_size": 32,
"use_sym": False,
"folding": True,
"absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
},
},
}
}
qdq_model = quantize(
model=copy.deepcopy(self.gptj), quant_config=quant_config, run_fn=train, example_inputs=example_inputs
)
self.assertTrue(isinstance(qdq_model, torch.nn.Module))
out2 = qdq_model(test_input)
self.assertTrue(torch.allclose(out1[0], out2[0]))
self.assertTrue(torch.allclose(out2[0], out0[0], atol=0.03))


if __name__ == "__main__":
unittest.main()