diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 5909de644ac..1dfc27c1237 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -1242,6 +1242,8 @@ def _combine_capability(self, bf16_ops, q_capability):
                 q_capability["opwise"][bf16_op] = [bf16_config, fp32_config]
                 if bf16_op[1] not in q_capability["optypewise"]:
                     q_capability["optypewise"][bf16_op[1]] = [bf16_config, fp32_config]
+            if bf16_op[1] in q_capability["optypewise"] and bf16_config not in q_capability["optypewise"][bf16_op[1]]:
+                q_capability["optypewise"][bf16_op[1]].append(bf16_config)
         return q_capability
 
     def get_fused_list(self, model):
diff --git a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1x.py b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1x.py
index b13c6ff5a76..be1accfaaf6 100644
--- a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1x.py
+++ b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1x.py
@@ -1103,16 +1103,18 @@ def test_bf16_capability(self):
         os.environ["FORCE_BF16"] = "1"
         q_capability = self.adaptor._get_quantizable_ops(model_origin)
         del os.environ["FORCE_BF16"]
-
-        self.assertEqual([elem["weight"]["dtype"] for elem in q_capability["optypewise"]["Conv2d"]], [["int8"], "fp32"])
         self.assertEqual(
-            [elem["activation"]["dtype"] for elem in q_capability["optypewise"]["Conv2d"]], [["uint8"], "fp32"]
+            [elem["weight"]["dtype"] for elem in q_capability["optypewise"]["Conv2d"]], [["int8"], "fp32", "bf16"]
+        )
+        self.assertEqual(
+            [elem["activation"]["dtype"] for elem in q_capability["optypewise"]["Conv2d"]], [["uint8"], "fp32", "bf16"]
         )
         self.assertEqual(
-            [elem["weight"]["dtype"] for elem in q_capability["opwise"][("conv", "Conv2d")]], [["int8"], "fp32"]
+            [elem["weight"]["dtype"] for elem in q_capability["opwise"][("conv", "Conv2d")]], [["int8"], "fp32", "bf16"]
         )
         self.assertEqual(
-            [elem["activation"]["dtype"] for elem in q_capability["opwise"][("conv", "Conv2d")]], [["uint8"], "fp32"]
+            [elem["activation"]["dtype"] for elem in q_capability["opwise"][("conv", "Conv2d")]],
+            [["uint8"], "fp32", "bf16"],
         )
         self.assertEqual(
             [elem["weight"]["dtype"] for elem in q_capability["opwise"][("linear", "Linear")]],
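
For context, the two added lines keep the "optypewise" table in sync with the "opwise" one: previously, an op type that was already quantizable only gained the bf16 config in its per-op entry, so optypewise["Conv2d"] stayed at [int8, fp32] and the old assertions reflected that. Below is a minimal, standalone sketch of the added check, not the adaptor's exact code; the config shapes are inferred from the assertions in the test hunk (the adaptor's real int8 capability entries carry more fields than shown here).

    # Config shapes mirror the dtype assertions in the test above (assumed).
    bf16_config = {"activation": {"dtype": "bf16"}, "weight": {"dtype": "bf16"}}
    fp32_config = {"activation": {"dtype": "fp32"}, "weight": {"dtype": "fp32"}}
    int8_config = {"activation": {"dtype": ["uint8"]}, "weight": {"dtype": ["int8"]}}

    # State before the patch for a Conv2d that FORCE_BF16 marked as a bf16 op:
    # the opwise entry already picked up bf16, the optypewise entry did not.
    q_capability = {
        "opwise": {("conv", "Conv2d"): [int8_config, fp32_config, bf16_config]},
        "optypewise": {"Conv2d": [int8_config, fp32_config]},
    }

    bf16_op = ("conv", "Conv2d")
    # The added lines: append bf16 to the optype entry when it is missing.
    if bf16_op[1] in q_capability["optypewise"] and bf16_config not in q_capability["optypewise"][bf16_op[1]]:
        q_capability["optypewise"][bf16_op[1]].append(bf16_config)

    print([c["weight"]["dtype"] for c in q_capability["optypewise"]["Conv2d"]])
    # [['int8'], 'fp32', 'bf16'] -- matching the updated optypewise assertion

The membership test on bf16_config also makes the append idempotent, so running _combine_capability over several bf16 ops of the same type adds the bf16 entry at most once.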