Skip to content

Commit

Permalink
fix UT bug
Browse files Browse the repository at this point in the history
Signed-off-by: xin3he <[email protected]>
  • Loading branch information
xin3he committed Jul 8, 2024
1 parent 8201841 commit 3ace9e6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 2 additions & 0 deletions neural_compressor/adaptor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,8 @@ def _combine_capability(self, bf16_ops, q_capability):
q_capability["opwise"][bf16_op] = [bf16_config, fp32_config]
if bf16_op[1] not in q_capability["optypewise"]:
q_capability["optypewise"][bf16_op[1]] = [bf16_config, fp32_config]
if bf16_op[1] in q_capability["optypewise"] and bf16_config not in q_capability["optypewise"][bf16_op[1]]:
q_capability["optypewise"][bf16_op[1]].append(bf16_config)
return q_capability

def get_fused_list(self, model):
Expand Down
12 changes: 7 additions & 5 deletions test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,16 +1103,18 @@ def test_bf16_capability(self):
os.environ["FORCE_BF16"] = "1"
q_capability = self.adaptor._get_quantizable_ops(model_origin)
del os.environ["FORCE_BF16"]

self.assertEqual([elem["weight"]["dtype"] for elem in q_capability["optypewise"]["Conv2d"]], [["int8"], "fp32"])
self.assertEqual(
[elem["activation"]["dtype"] for elem in q_capability["optypewise"]["Conv2d"]], [["uint8"], "fp32"]
[elem["weight"]["dtype"] for elem in q_capability["optypewise"]["Conv2d"]], [["int8"], "fp32", "bf16"]
)
self.assertEqual(
[elem["activation"]["dtype"] for elem in q_capability["optypewise"]["Conv2d"]], [["uint8"], "fp32", "bf16"]
)
self.assertEqual(
[elem["weight"]["dtype"] for elem in q_capability["opwise"][("conv", "Conv2d")]], [["int8"], "fp32"]
[elem["weight"]["dtype"] for elem in q_capability["opwise"][("conv", "Conv2d")]], [["int8"], "fp32", "bf16"]
)
self.assertEqual(
[elem["activation"]["dtype"] for elem in q_capability["opwise"][("conv", "Conv2d")]], [["uint8"], "fp32"]
[elem["activation"]["dtype"] for elem in q_capability["opwise"][("conv", "Conv2d")]],
[["uint8"], "fp32", "bf16"],
)
self.assertEqual(
[elem["weight"]["dtype"] for elem in q_capability["opwise"][("linear", "Linear")]],
Expand Down

0 comments on commit 3ace9e6

Please sign in to comment.