diff --git a/auto_round/backend.py b/auto_round/backend.py index 5e6f1f88..fe05cd79 100644 --- a/auto_round/backend.py +++ b/auto_round/backend.py @@ -132,15 +132,15 @@ def check_auto_round_exllamav2_installed(): "auto_gptq:tritonv2"], requirements=["auto-gptq>=0.7.1", "triton<3.0,>=2.0"] ) -# -# BackendInfos['gptq:cuda'] = BackendInfo(device=["cuda"], sym=[True, False], -# packing_format="triton_zp+-1", -# bits=[2, 3, 4, 8], group_size=None, -# priority=1, feature_checks=[feature_multiply_checker_32], -# alias=["auto_round:auto_gptq:cuda,auto_gptq:cuda, auto_round:gptq:cuda"], -# convertable_format=["triton_zp+-1"], -# requirements=["auto-gptq>=0.7.1"] -# ) + +BackendInfos['gptq:cuda'] = BackendInfo(device=["cuda"], sym=[True, False], packing_format="triton_zp+-1", bits=[2, 3, 4, 8], group_size=None, priority=1, feature_checks=[feature_multiply_checker_32], alias=["auto_round:auto_gptq:cuda", "auto_gptq:cuda", "auto_round:gptq:cuda"], convertable_format=["triton_zp+-1"], requirements=["auto-gptq>=0.7.1"] ) BackendInfos['awq:gemm'] = BackendInfo(device=["cuda"], sym=[True, False], ##actrally is gemm packing_format="awq",