diff --git a/neural_compressor/torch/algorithms/weight_only/awq.py b/neural_compressor/torch/algorithms/weight_only/awq.py
index 0b24d075512..a24081be39e 100644
--- a/neural_compressor/torch/algorithms/weight_only/awq.py
+++ b/neural_compressor/torch/algorithms/weight_only/awq.py
@@ -15,11 +15,10 @@
 # Copied from neural_compressor/adaptor/torch_utils/awq.py
 
 import copy
-from functools import partial
 
 import torch
 
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils import get_device, logger
 
 from .modules import MulLinear
 from .utility import (
@@ -33,6 +32,8 @@
     set_module,
 )
 
+__all__ = ["awq_quantize"]
+
 
 def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}):
     """Get absorbed layer per block.
@@ -122,10 +123,13 @@ def __init__(
         use_full_range=False,
         weight_config={},
     ):
+        self.example_inputs = example_inputs
+        self.model = model
         if example_inputs is None:
             assert dataloader is not None, "datalaoder or example_inputs is required."
             self.example_inputs = get_example_input(dataloader)
+        self._move_model_and_data_to_device()
 
         # Step 1: get hidden states and kwargs of first block.
         self.total_block_args, self.total_block_kwargs = get_hidden_states(
             model, dataloader=dataloader, n_samples=n_samples, calib_func=calib_func
@@ -139,7 +143,12 @@ def __init__(
         self.scheme = scheme
         self.use_full_range = use_full_range
         self.weight_config = weight_config
-        self.model = model
+
+    def _move_model_and_data_to_device(self):
+        # Move the model and example_inputs to the target device
+        device = get_device()
+        self.model.to(device)
+        self.example_inputs = self.example_inputs.to(device)
 
     def quantize(self, use_auto_scale=True, use_mse_search=True, folding=False, return_int=False):
         """Execute AWQ quantization.
diff --git a/neural_compressor/torch/algorithms/weight_only/teq.py b/neural_compressor/torch/algorithms/weight_only/teq.py
index 34ed8ff30e9..c8ad90c0570 100644
--- a/neural_compressor/torch/algorithms/weight_only/teq.py
+++ b/neural_compressor/torch/algorithms/weight_only/teq.py
@@ -19,11 +19,13 @@
 import torch
 import transformers
 
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils import get_device, logger
 
 from .modules import MulLinear, TEQLinearFakeQuant
 from .utility import get_module, quant_tensor, set_module
 
+__all__ = ["teq_quantize", "TEQuantizer"]
+
 
 class TEQuantizer:
     """Weight-only quantization, Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input."""
@@ -38,7 +40,8 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
         self.weight_config = weight_config
         self.folding = folding
         self.example_inputs = example_inputs
-        self.device, self.dtype = self._get_device()
+        self.device = self._get_device()
+        self.dtype = self._get_dtype()
         self.model.eval()
         self.trained_alphas = {}
         self.absorb_to_layer = absorb_to_layer
@@ -46,8 +49,13 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
     def _get_device(self):
         """Get the model device
        :return:Model device."""
+        device = get_device()
+        self.model.to(device)
+        return device
+
+    def _get_dtype(self):
         for _, p in self.model.named_parameters():
-            return p.data.device, p.data.dtype
+            return p.data.dtype
 
     def add_tuning_scale(self, sqrt_w_init=False):
         """The main entry of smooth quant
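
Both hunks above introduce the same device-placement pattern: resolve the target device once with `get_device()`, move the model in place, and reassign the (copied) example inputs. Below is a minimal standalone sketch of that pattern; the helper name `move_to_target_device` is illustrative and not part of the patch, and it assumes `get_device()` returns a value accepted by `Module.to` / `Tensor.to` (as the patched code does).

import torch

from neural_compressor.torch.utils import get_device


def move_to_target_device(model: torch.nn.Module, example_inputs: torch.Tensor):
    device = get_device()  # resolved by the auto-accelerator logic in the next file
    model.to(device)  # nn.Module.to moves parameters in place
    example_inputs = example_inputs.to(device)  # Tensor.to returns a copy, so reassign
    return model, example_inputs
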
diff --git a/neural_compressor/torch/utils/auto_accelerator.py b/neural_compressor/torch/utils/auto_accelerator.py
index 24c121440bb..57af493b738 100644
--- a/neural_compressor/torch/utils/auto_accelerator.py
+++ b/neural_compressor/torch/utils/auto_accelerator.py
@@ -19,7 +19,6 @@
 
 # NOTICE: The design adapted from:
 # https://github.com/microsoft/DeepSpeed/blob/master/accelerator/abstract_accelerator.py.
-# TODO: move it into torch/utils
 
 # To keep it simply, only add the APIs we need.
 
@@ -204,12 +203,19 @@ def empty_cache(self):
 
 
 def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
+    # Force the CPU on a node that has both CPU and GPU: `FORCE_DEVICE=cpu` python main.py ...
+    # `FORCE_DEVICE` is case-insensitive.
     # The environment variable `FORCE_DEVICE` has higher priority than the `device_name`.
     # TODO: refine the docs and logic later
+    # 1. Get the device setting from the environment variable `FORCE_DEVICE`.
     FORCE_DEVICE = os.environ.get("FORCE_DEVICE", None)
+    if FORCE_DEVICE:
+        FORCE_DEVICE = FORCE_DEVICE.lower()
+    # 2. If `FORCE_DEVICE` is set and that accelerator is available, use it.
     if FORCE_DEVICE and accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE) is not None:
         logger.warning("Force use %s accelerator.", FORCE_DEVICE)
         return accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE)()
+    # 3. If `device_name` is set and that accelerator is available, use it.
     if device_name != "auto":
         if accelerator_registry.get_accelerator_cls_by_name(device_name) is not None:
             accelerator_cls = accelerator_registry.get_accelerator_cls_by_name(device_name)
@@ -217,6 +223,7 @@ def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
             return accelerator_cls()
         else:
             logger.warning("The device name %s is not supported, use auto detect instead.", device_name)
+    # 4. Select the accelerator by priority.
     for accelerator_cls in accelerator_registry.get_sorted_accelerators():
         if accelerator_cls.is_available():
             logger.warning("Auto detect accelerator: %s.", accelerator_cls.__name__)
@@ -227,4 +234,6 @@ def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
 # Force use cpu accelerator even if cuda is available.
 # FORCE_DEVICE = "cpu" python ...
 # or
+# FORCE_DEVICE = "CPU" python ...
+# or
 # CUDA_VISIBLE_DEVICES="" python ...
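
For reference, a minimal usage sketch of the selection priority implemented above. It relies only on the `auto_detect_accelerator(device_name="auto")` entry point shown in this diff and assumes a node where CUDA is available and a CPU accelerator is registered; the numbered comments mirror steps 1-4 added in the hunk.

import os

from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

# 1/2. FORCE_DEVICE beats an explicit device_name and is matched case-insensitively,
#      so this returns the CPU accelerator even though "cuda" was requested.
os.environ["FORCE_DEVICE"] = "CPU"
acc = auto_detect_accelerator("cuda")

# 3. Without FORCE_DEVICE, a registered device_name is honored.
del os.environ["FORCE_DEVICE"]
acc = auto_detect_accelerator("cpu")

# 4. With neither, the highest-priority available accelerator is picked
#    (CUDA on a GPU node, CPU otherwise).
acc = auto_detect_accelerator()
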
diff --git a/test/3x/torch/quantization/weight_only/test_woq_on_cuda.py b/test/3x/torch/quantization/weight_only/test_woq_on_cuda.py
index c64f748c754..910b8682186 100644
--- a/test/3x/torch/quantization/weight_only/test_woq_on_cuda.py
+++ b/test/3x/torch/quantization/weight_only/test_woq_on_cuda.py
@@ -7,7 +7,22 @@
 
 from neural_compressor.common.utils import logger
 from neural_compressor.torch.algorithms.weight_only.gptq import move_input_to_device
-from neural_compressor.torch.quantization import GPTQConfig, get_default_rtn_config, quantize
+from neural_compressor.torch.quantization import (
+    AWQConfig,
+    GPTQConfig,
+    get_default_awq_config,
+    get_default_rtn_config,
+    get_default_teq_config,
+    quantize,
+)
+
+
+def get_gpt_j():
+    tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
+        "hf-internal-testing/tiny-random-GPTJForCausalLM",
+        torchscript=True,
+    )
+    return tiny_gptj
 
 
 class GPTQDataloaderPreprocessor:
@@ -213,9 +228,7 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
                 pass
             return
 
-        user_model = transformers.AutoModelForCausalLM.from_pretrained(
-            "hf-internal-testing/tiny-random-GPTJForCausalLM",
-        )
+        user_model = get_gpt_j()
 
         user_model = quantize(
             model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration
@@ -228,9 +241,7 @@ class TestRTNQuant:
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
     def test_rtn(self):
-        self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
-            "hf-internal-testing/tiny-random-GPTJForCausalLM",
-        )
+        self.tiny_gptj = get_gpt_j()
         self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
         model = self.tiny_gptj
         # record label for comparison
@@ -239,3 +250,129 @@ def test_rtn(self):
         quant_config = get_default_rtn_config()
         q_model = quantize(model, quant_config)
         assert "cuda" in str(q_model.device), f"Expect qmodel device is cuda, got {q_model.device}"
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+class TestAWQOnCuda:
+
+    def test_awq(self):
+        self.lm_input = torch.ones([1, 10], dtype=torch.long)
+        self.gptj = get_gpt_j()
+        example_inputs = torch.ones([1, 10], dtype=torch.long)
+
+        def calib_func(model):
+            for i in range(2):
+                model(self.lm_input.to(model.device))
+
+        quant_config = get_default_awq_config()
+        logger.info("Test quantization with config %s", quant_config)
+        q_model = quantize(
+            model=self.gptj, quant_config=quant_config, example_inputs=self.lm_input, run_fn=calib_func, inplace=False
+        )
+        out2 = q_model(example_inputs.to(q_model.device))
+        assert "cuda" in str(q_model.device), f"Expect qmodel device is cuda, got {q_model.device}"
+        assert "cuda" in str(out2[0].device), f"Expect out2 device is cuda, got {out2[0].device}"
+
+
+def generate_random_corpus(nsamples=32):
+    meta_data = []
+    for _ in range(nsamples):
+        inp = torch.ones([1, 512], dtype=torch.long)
+        tar = torch.ones([1, 512], dtype=torch.long)
+        meta_data.append((inp, tar))
+    return meta_data
+
+
+def train(
+    model,
+    train_steps=1000,
+    lr=1e-3,
+    warmup_ratio=0.05,
+    gradient_accumulation_steps=1,
+    logging_steps=10,
+    betas=[0.9, 0.9],
+    weight_decay=0,
+    lr_scheduler_type="linear",
+):
+    """Train function."""
+    trained_alphas_list = [torch.ones([128], requires_grad=True)]
+    optimizer = torch.optim.Adam(trained_alphas_list, lr=lr, weight_decay=weight_decay, betas=betas)
+
+    lr_scheduler = transformers.get_scheduler(  # pylint: disable=E1111
+        name=lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=int(train_steps * warmup_ratio) // gradient_accumulation_steps,
+        num_training_steps=train_steps // gradient_accumulation_steps,
+    )
+
+    logger.info("start training")
+    model.train()
+    global_steps = 0
+    dataloader = generate_random_corpus()
+    while global_steps <= train_steps:
+        for inputs in dataloader:
+            if isinstance(inputs, torch.Tensor):
+                input_id = inputs
+            elif isinstance(inputs, dict):
+                input_id = inputs["input_ids"]
+            else:
+                input_id = inputs[0]
+            output = model(input_id.to(model.device), labels=input_id.to(model.device))
+            loss = output[0] / gradient_accumulation_steps
+            loss.backward()
+            global_steps += 1
+
+            if global_steps % logging_steps == 0:
+                logger.info("steps: {}, loss: {}".format(global_steps, loss.detach().cpu().item()))
+
+            if global_steps % gradient_accumulation_steps == 0:
+                optimizer.step()
+                optimizer.zero_grad()
+                lr_scheduler.step()
+
+            if global_steps >= train_steps:  # pragma: no cover
+                break
+
+    logger.info("finish training")
+    model.eval()
+    return None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+class TestTEQOnCuda:
+
+    def test_teq(self):
+        quant_config = {
+            "teq": {
+                "global": {
+                    "dtype": "fp32",
+                },
+                "local": {
+                    "transformer.h.0.mlp.fc_in": {
+                        "dtype": "int",
+                        "bits": 8,
+                        "group_size": -1,
+                        "use_sym": True,
+                        "folding": True,
+                        "absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
+                    },
+                    "transformer.h.0.mlp.fc_out": {
+                        "dtype": "int",
+                        "bits": 4,
+                        "group_size": 32,
+                        "use_sym": False,
+                        "folding": True,
+                        "absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
+                    },
+                },
+            }
+        }
+        example_inputs = torch.ones([1, 512], dtype=torch.long)
+        test_input = torch.ones([1, 512], dtype=torch.long)
+        model = get_gpt_j()
+
+        qdq_model = quantize(model=model, quant_config=quant_config, run_fn=train, example_inputs=example_inputs)
+        assert isinstance(qdq_model, torch.nn.Module), "Expect qdq_model is a torch module"
+        out2 = qdq_model(test_input.to(qdq_model.device))
+        assert "cuda" in str(qdq_model.device), f"Expect qmodel device is cuda, got {qdq_model.device}"
+        assert "cuda" in str(out2[0].device), f"Expect out2 device is cuda, got {out2[0].device}"
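
Finally, a standalone usage sketch distilled from the AWQ test above, showing the `run_fn` contract both new tests rely on: `quantize()` invokes `run_fn(model, *run_args)` for calibration (AWQ) or training (TEQ), with the model already placed on the auto-detected device. Everything here mirrors the test code; no API beyond what the diff exercises is assumed, and the names `fp32_model`, `dummy_input`, and `calib_fn` are illustrative.

import torch
import transformers

from neural_compressor.torch.quantization import get_default_awq_config, quantize

fp32_model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", torchscript=True
)
dummy_input = torch.ones([1, 10], dtype=torch.long)


def calib_fn(model):
    # quantize() calls this with only the (possibly device-moved) model, so keep
    # calibration data in the enclosing scope and follow the model's device.
    for _ in range(2):
        model(dummy_input.to(model.device))


q_model = quantize(
    model=fp32_model,
    quant_config=get_default_awq_config(),
    example_inputs=dummy_input,
    run_fn=calib_fn,
)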