
Unable to perform dynamic quantization on MultiheadAttention #1416

Closed
nagbhat25 opened this issue Nov 25, 2023 · 8 comments

@nagbhat25

nagbhat25 commented Nov 25, 2023

Hello

I am trying to perform dynamic quantization on the GeoLayoutLM model, which internally uses the torch.nn.MultiheadAttention layer. When I try to quantize this model using dynamic quantization, I get an error in torch/nn/modules/activation.py. I think it is mostly because MultiheadAttention uses NonDynamicallyQuantizableLinear internally.
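The class mentioned here can be inspected directly on a standalone attention layer. The sketch below assumes PyTorch 1.13 or later; the embedding size and head count are arbitrary illustrative values, not taken from GeoLayoutLM.

```python
import torch.nn as nn
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear

# Arbitrary sizes, chosen for illustration only.
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)

# The output projection is a NonDynamicallyQuantizableLinear, a thin subclass of nn.Linear.
print(type(mha.out_proj).__name__)                             # NonDynamicallyQuantizableLinear
print(issubclass(NonDynamicallyQuantizableLinear, nn.Linear))  # True

# The q/k/v projections are not submodules: they are packed into one parameter
# of shape (3 * embed_dim, embed_dim), so there is no Linear module to swap for them.
print(mha.in_proj_weight.shape)                                # torch.Size([48, 16])
```

Because MultiheadAttention.forward passes self.out_proj.weight and self.out_proj.bias straight into F.multi_head_attention_forward (visible in the tracebacks in this thread), replacing out_proj with a module that does not expose those two attributes breaks the forward pass.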

I would like to know if there is any way to get around this, or whether it is not supported at all. Is there a way to skip a layer during quantization? (My knowledge in this area is very limited.)

Model link: GeoLayoutLM

Any help would be appreciated. Thanks

@Kaihui-intel
Contributor

Hi @nagbhat25,

Thanks for raising this issue.
What is the model and dataset you are using?
Can you provide scripts or code and the commands you are using for us to reproduce?

@nagbhat25
Author

nagbhat25 commented Dec 2, 2023

Hi @Kaihui-intel, thanks for taking a look at this.

The model I used is the GeoLayoutLMVIEModel class defined here: Link to model

I am just trying to run dynamic quantization (which I think can run without any data). Here is the sample snippet I use:

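from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor import quantization

# gllm_model: an instance of GeoLayoutLMVIEModel, built and loaded beforehand.
# Dynamic quantization needs no calibration data, so no dataloader is passed.
config = PostTrainingQuantConfig(device='cpu', approach='dynamic', domain='auto')
q_model = quantization.fit(gllm_model, config)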

The quantization goes through, and I get the log below:

|***********Mixed Precision Statistics***********|
2023-12-02 06:03:05 [INFO] +---------------------------------+-------+------+
2023-12-02 06:03:05 [INFO] |             Op Type             | Total | INT8 |
2023-12-02 06:03:05 [INFO] +---------------------------------+-------+------+
2023-12-02 06:03:05 [INFO] |            Embedding            |   9   |  9   |
2023-12-02 06:03:05 [INFO] |              Linear             |  194  | 194  |
2023-12-02 06:03:05 [INFO] | NonDynamicallyQuantizableLinear |   6   |  6   |
2023-12-02 06:03:05 [INFO] +---------------------------------+-------+------+

However, when I run inference by calling the model, I get the error below (possibly from the torch module). It looks like there is no support for the NonDynamicallyQuantizableLinear layer.

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/transformer.py:280, in TransformerEncoder.forward(self, src, mask, src_key_padding_mask)
    277         src_key_padding_mask_for_layers = None
    279 for mod in self.layers:
--> 280     output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)
    282 if convert_to_nested:
    283     output = output.to_padded_tensor(0.)

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/transformer.py:538, in TransformerEncoderLayer.forward(self, src, src_mask, src_key_padding_mask)
    536     x = x + self._ff_block(self.norm2(x))
    537 else:
--> 538     x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
    539     x = self.norm2(x + self._ff_block(x))
    541 return x

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/transformer.py:546, in TransformerEncoderLayer._sa_block(self, x, attn_mask, key_padding_mask)
    544 def _sa_block(self, x: Tensor,
    545               attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
--> 546     x = self.self_attn(x, x, x,
    547                        attn_mask=attn_mask,
    548                        key_padding_mask=key_padding_mask,
    549                        need_weights=False)[0]
    550     return self.dropout1(x)

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/activation.py:1171, in MultiheadAttention.forward(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights)
   1156     attn_output, attn_output_weights = F.multi_head_attention_forward(
   1157         query, key, value, self.embed_dim, self.num_heads,
   1158         self.in_proj_weight, self.in_proj_bias,
   (...)
   1164         q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
   1165         v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights)
   1166 else:
   1167     attn_output, attn_output_weights = F.multi_head_attention_forward(
   1168         query, key, value, self.embed_dim, self.num_heads,
   1169         self.in_proj_weight, self.in_proj_bias,
   1170         self.bias_k, self.bias_v, self.add_zero_attn,
-> 1171         self.dropout, self.out_proj.weight, self.out_proj.bias,
   1172         training=self.training,
   1173         key_padding_mask=key_padding_mask, need_weights=need_weights,
   1174         attn_mask=attn_mask, average_attn_weights=average_attn_weights)
   1175 if self.batch_first and is_batched:
   1176     return attn_output.transpose(1, 0), attn_output_weights

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1269, in Module.__getattr__(self, name)
   1267     if name in modules:
   1268         return modules[name]
-> 1269 raise AttributeError("'{}' object has no attribute '{}'".format(
   1270     type(self).__name__, name))

AttributeError: 'GraphModule' object has no attribute 'bias'

If there is a way to avoid quantizing this layer altogether, that might solve this.

@Kaihui-intel
Contributor

I have verified that it's reproducible, and I am working on it.
In the meantime, there is a way to avoid quantizing this layer:

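    from neural_compressor.config import PostTrainingQuantConfig

    # Keep every NonDynamicallyQuantizableLinear (the out_proj of torch.nn.MultiheadAttention) in fp32.
    op_type_dict = {
        'NonDynamicallyQuantizableLinear':{
            "weight": {
                "dtype": ["fp32"]
            },
            "activation": {
                "dtype": ["fp32"]
            }
        }
    }
    config = PostTrainingQuantConfig(device='cpu', approach='dynamic', domain='auto', op_type_dict=op_type_dict)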
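For comparison, plain PyTorch eager-mode dynamic quantization (not Intel Neural Compressor's FX-based path) can skip the attention projection as well. In the PyTorch releases this sketch assumes (roughly 1.13 and later), passing a set of classes as qconfig_spec matches modules by their exact type, so a spec of {nn.Linear} quantizes ordinary Linear layers and leaves out_proj, a NonDynamicallyQuantizableLinear, in fp32. TinyBlock is a hypothetical toy module standing in for a real encoder layer.

```python
import torch
import torch.nn as nn
import torch.ao.nn.quantized.dynamic as nnqd
from torch.ao.quantization import quantize_dynamic
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear


class TinyBlock(nn.Module):
    """Hypothetical toy block: self-attention followed by a feed-forward Linear."""

    def __init__(self, dim: int = 16, heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        return self.ffn(attn_out)


model = TinyBlock().eval()

# Only modules whose type is exactly nn.Linear receive a dynamic qconfig;
# the attention's out_proj keeps qconfig=None and is left untouched.
qmodel = quantize_dynamic(model, qconfig_spec={nn.Linear}, dtype=torch.qint8)

print(isinstance(qmodel.ffn, nnqd.Linear))      # True
print(type(qmodel.attn.out_proj).__name__)      # NonDynamicallyQuantizableLinear

with torch.no_grad():
    out = qmodel(torch.randn(2, 5, 16))
print(out.shape)                                # torch.Size([2, 5, 16])
```

quantize_dynamic works on a deep copy by default (inplace=False), so the original fp32 model stays available for accuracy and latency comparisons.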

@nagbhat25
Author

Hi @Kaihui-intel ,

I already tried this, but no luck. The quantization succeeds and leaves the configured layers' weights untouched, but inference still fails. Here are the reports:

Dynamic Quantization report:

2023-12-05 15:10:59 [INFO]  Found 24 blocks
2023-12-05 15:10:59 [INFO] Attention Blocks: 24
2023-12-05 15:10:59 [INFO] FFN Blocks: 24
2023-12-05 15:10:59 [INFO] Pass query framework capability elapsed time: 2588.14 ms
2023-12-05 15:10:59 [INFO] Do not evaluate the baseline and quantize the model with default configuration.
2023-12-05 15:10:59 [INFO] Quantize the model with default config.
2023-12-05 15:10:59 [INFO] Fx trace of the entire model failed, We will conduct auto quantization
2023-12-05 15:11:24 [INFO] |***************Mixed Precision Statistics**************|
2023-12-05 15:11:24 [INFO] +---------------------------------+-------+------+------+
2023-12-05 15:11:24 [INFO] |             Op Type             | Total | INT8 | FP32 |
2023-12-05 15:11:24 [INFO] +---------------------------------+-------+------+------+
2023-12-05 15:11:24 [INFO] |            Embedding            |   9   |  9   |  0   |
2023-12-05 15:11:24 [INFO] |              Linear             |  194  | 194  |  0   |
2023-12-05 15:11:24 [INFO] | NonDynamicallyQuantizableLinear |   6   |  0   |  6   |
2023-12-05 15:11:24 [INFO] +---------------------------------+-------+------+------+
2023-12-05 15:11:24 [INFO] Pass quantize model elapsed time: 25082.02 ms
2023-12-05 15:11:24 [INFO] Save tuning history to /app/services/docintel/docintel_v25_1/docintel_models/document_ai/geolayoutlm/nc_workspace/2023-12-05_15-10-53/./history.snapshot.
2023-12-05 15:11:24 [INFO] [Strategy] Found the model meets accuracy requirements, ending the tuning process.
2023-12-05 15:11:24 [INFO] Specified timeout or max trials is reached! Found a quantized model which meet accuracy goal. Exit.

Error during inference with the op_type_dict change:

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/transformer.py:387, in TransformerEncoder.forward(self, src, mask, src_key_padding_mask, is_causal)
    384 is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
    386 for mod in self.layers:
--> 387     output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
    389 if convert_to_nested:
    390     output = output.to_padded_tensor(0., src.size())

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/transformer.py:707, in TransformerEncoderLayer.forward(self, src, src_mask, src_key_padding_mask, is_causal)
    705     x = x + self._ff_block(self.norm2(x))
    706 else:
--> 707     x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
    708     x = self.norm2(x + self._ff_block(x))
    710 return x

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/transformer.py:715, in TransformerEncoderLayer._sa_block(self, x, attn_mask, key_padding_mask, is_causal)
    713 def _sa_block(self, x: Tensor,
    714               attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
--> 715     x = self.self_attn(x, x, x,
    716                        attn_mask=attn_mask,
    717                        key_padding_mask=key_padding_mask,
    718                        need_weights=False, is_causal=is_causal)[0]
    719     return self.dropout1(x)

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/activation.py:1245, in MultiheadAttention.forward(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)
   1227     attn_output, attn_output_weights = F.multi_head_attention_forward(
   1228         query, key, value, self.embed_dim, self.num_heads,
   1229         self.in_proj_weight, self.in_proj_bias,
   (...)
   1238         average_attn_weights=average_attn_weights,
   1239         is_causal=is_causal)
   1240 else:
   1241     attn_output, attn_output_weights = F.multi_head_attention_forward(
   1242         query, key, value, self.embed_dim, self.num_heads,
   1243         self.in_proj_weight, self.in_proj_bias,
   1244         self.bias_k, self.bias_v, self.add_zero_attn,
-> 1245         self.dropout, self.out_proj.weight, self.out_proj.bias,
   1246         training=self.training,
   1247         key_padding_mask=key_padding_mask,
   1248         need_weights=need_weights,
   1249         attn_mask=attn_mask,
   1250         average_attn_weights=average_attn_weights,
   1251         is_causal=is_causal)
   1252 if self.batch_first and is_batched:
   1253     return attn_output.transpose(1, 0), attn_output_weights

File /opt/app_venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1695, in Module.__getattr__(self, name)
   1693     if name in modules:
   1694         return modules[name]
-> 1695 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'GraphModule' object has no attribute 'bias'

Thanks for looking into this!

@Kaihui-intel
Contributor

Hello @nagbhat25,
We use PyTorch FX to quantize NonDynamicallyQuantizableLinear. However, the resulting GraphModule did not have the bias attribute set, and MultiheadAttention.forward reads self.out_proj.bias directly (as the tracebacks above show). We have now resolved this issue with the fix "Add bias for fx_model".
Now you can use the initial code for quantization and inference.
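The failure mode described above can be reproduced without Neural Compressor. In this sketch, WrappedLinear is a hypothetical stand-in for a converted module that keeps the real layer as a child and only exposes forward(), similar in spirit to a traced GraphModule. It is not the actual change made in the library; it only shows why a missing attribute surfaces as an AttributeError inside MultiheadAttention.forward, and that exposing the expected attributes is enough for the attention layer to work again.

```python
import torch
import torch.nn as nn


class WrappedLinear(nn.Module):
    """Hypothetical wrapper: holds the real Linear as a child, exposes no weight/bias."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.inner = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x)


torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True).eval()
x = torch.randn(2, 5, 16)
with torch.no_grad():
    reference, _ = mha(x, x, x, need_weights=False)

# Swap out_proj for the wrapper. Its forward() is never called, because
# MultiheadAttention.forward reads out_proj.weight / out_proj.bias directly.
original = mha.out_proj
mha.out_proj = WrappedLinear(original)
error_message = ""
try:
    with torch.no_grad():
        mha(x, x, x, need_weights=False)
except AttributeError as err:
    error_message = str(err)
print(error_message)  # e.g. 'WrappedLinear' object has no attribute 'weight'

# Registering the expected parameters on the wrapper makes the layer usable again.
mha.out_proj.weight = original.weight
mha.out_proj.bias = original.bias
with torch.no_grad():
    repaired, _ = mha(x, x, x, need_weights=False)
print(torch.allclose(reference, repaired))  # True
```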

@nagbhat25
Author

nagbhat25 commented Dec 7, 2023

Thanks @Kaihui-intel. I tested the fix using your branch, and inference works perfectly fine now.

The model size is reduced by roughly 60%. However, one thing I notice is that despite converting to int8, there is no gain in inference time. When I compare the inference times against those of the fp32 model, they are almost the same (in some cases the quantized model is even slightly slower). Is there any reason why this would happen, or any flags that I can tweak?

Here is the current quantization code:

from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor import quantization

# op_type_dict keeps NonDynamicallyQuantizableLinear in fp32, as suggested earlier in this thread.
config = PostTrainingQuantConfig(device='cpu', approach='dynamic', domain='auto', op_type_dict=op_type_dict)
q_model = quantization.fit(model, config)

Thanks a lot for looking into this.
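One way to check a size reduction like the one reported here is to serialize the state_dict of both models into memory and compare byte counts. Dynamic quantization stores Linear weights as int8 (1 byte per value instead of 4), so weight-dominated models shrink substantially; the exact reduction depends on how much of the model consists of quantized weights. This is a minimal sketch on a hypothetical stack of Linear layers, using PyTorch's own quantize_dynamic rather than Neural Compressor; serialized_size_bytes is a helper name introduced only for illustration.

```python
import io

import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic


def serialized_size_bytes(model: nn.Module) -> int:
    """Size of the model's state_dict when saved with torch.save into memory."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes


# Hypothetical stand-in model where Linear weights dominate the parameter count.
fp32_model = nn.Sequential(
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024),
).eval()
int8_model = quantize_dynamic(fp32_model, qconfig_spec={nn.Linear}, dtype=torch.qint8)

fp32_bytes = serialized_size_bytes(fp32_model)
int8_bytes = serialized_size_bytes(int8_model)
print(f"fp32: {fp32_bytes / 1e6:.1f} MB, int8: {int8_bytes / 1e6:.1f} MB, "
      f"reduction: {1 - int8_bytes / fp32_bytes:.0%}")
```

A smaller file does not by itself imply faster inference; latency depends on how fast the int8 kernels run on the target CPU, which is what the profiling in the next reply looks at.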

@Kaihui-intel
Contributor

Kaihui-intel commented Dec 8, 2023

Thank you for your feedback.

  • For a smaller model size, you can remove "op_type_dict", just as in your initial code:
    config = PostTrainingQuantConfig(device='cpu', approach='dynamic', domain='auto')
  • For inference time, we can use the torch profiler to observe the time spent in each operator of the model. In this model, quantization mainly affects the linear ops. Profile output of the int8 model:
                                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
--------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   quantized::linear_dynamic        30.81%        2.349s        30.87%        2.353s      12.257ms           192  
                                aten::einsum         8.36%     637.247ms        18.36%        1.400s      53.829ms            26  
                                 aten::copy_        15.74%        1.200s        15.74%        1.200s       4.083ms           294  
                                 aten::clone         0.02%       1.399ms        14.36%        1.095s       6.187ms           177  
                                aten::linear         0.00%     194.000us        10.45%     796.888ms      53.126ms            15  
                               aten::reshape         0.01%     764.000us        10.04%     765.596ms       4.350ms           176  
                                aten::matmul         1.35%     102.992ms         7.88%     600.824ms      10.729ms            56  
                                   aten::cat         6.11%     465.909ms         6.11%     466.173ms      29.136ms            16  
                            aten::layer_norm         0.00%     380.000us         5.77%     439.950ms       5.714ms            77  
                     aten::native_layer_norm         4.24%     322.890ms         5.77%     439.570ms       5.709ms            77  
                                   aten::add         5.70%     434.274ms         5.70%     434.343ms       2.767ms           157  
              quantized::linear_relu_dynamic         4.62%     352.166ms         4.62%     352.231ms     117.410ms             3  
                                   aten::bmm         4.53%     345.292ms         4.53%     345.293ms       4.263ms            81  
                            aten::contiguous         0.01%     766.000us         4.35%     331.321ms       4.872ms            68  

Profile output of the fp32 model:

--------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
--------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::linear         0.04%       3.021ms        38.46%        2.750s      13.097ms           210

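From the above, there is not much difference in inference time between the two operators: in the int8 model, quantized::linear_dynamic takes 2.353 s of CPU time over 192 calls, while in the fp32 model, aten::linear takes 2.750 s over 210 calls. There is a known PyTorch issue here about the performance of quantized::linear_dynamic.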
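A comparison like the one above can be produced with torch.profiler by profiling the fp32 and dynamically quantized models on the same input and comparing the rows for aten::linear and quantized::linear_dynamic. This is a minimal sketch on a hypothetical two-layer feed-forward model (layer sizes are illustrative assumptions); profile_cpu is a helper name introduced here, and the real fp32 and quantized models would be substituted in practice.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic
from torch.profiler import ProfilerActivity, profile


def profile_cpu(model: nn.Module, x: torch.Tensor, runs: int = 10):
    """Profile several forward passes on CPU and return per-op averages."""
    with torch.no_grad():
        model(x)  # warm-up pass, kept outside the profiled region
        with profile(activities=[ProfilerActivity.CPU]) as prof:
            for _ in range(runs):
                model(x)
    return prof.key_averages()


# Hypothetical stand-in model; sizes are illustrative only.
fp32_model = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768)).eval()
int8_model = quantize_dynamic(fp32_model, qconfig_spec={nn.Linear}, dtype=torch.qint8)
x = torch.randn(64, 768)

fp32_stats = profile_cpu(fp32_model, x)
int8_stats = profile_cpu(int8_model, x)

# Same table layout as the profiles above, sorted by time spent inside each op itself.
print(fp32_stats.table(sort_by="self_cpu_time_total", row_limit=5))
print(int8_stats.table(sort_by="self_cpu_time_total", row_limit=5))

# Op names recorded in each profile (aten::linear vs quantized::linear_dynamic).
fp32_ops = {evt.key for evt in fp32_stats}
int8_ops = {evt.key for evt in int8_stats}
```

Dynamic quantization only replaces the Linear layers, and each quantized::linear_dynamic call also quantizes its input activations at runtime, so the end-to-end speedup is bounded by how much faster that kernel is than the fp32 one on the given CPU; all other ops run in fp32 as before.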

@nagbhat25
Author

Thanks @Kaihui-intel. I did some profiling of different quantization approaches, and this seems to be the main issue. I hope a solution for this arrives in future releases.
