diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index 42d4f948cd..d5f2dca5b4 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -365,6 +365,7 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): inner_k_tiles=inner_k_tiles, precision=child.weight.dtype, scales_precision=config.scale_precision, + device=next(child.parameters()).device, ) setattr(module, name, quantized_linear)