Skip to content

Commit

Permalink
Fix 8-bit quantization choosing the wrong packer (ModelCloud#199)
Browse files · Browse the repository at this point in the history
  • Loading branch information
Qubitium authored Jul 10, 2024
1 parent 223082a commit 2aee89d
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 22 deletions.
5 changes: 4 additions & 1 deletion gptqmodel/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,12 +547,15 @@ def tmp(_, inp, out):
quantizers=quantizers,
bits=self.quantize_config.bits,
group_size=self.quantize_config.group_size,
backend=BACKEND.AUTO,
# TODO: use triton for packing always? since it can support [2,4,8] bits while exllama only supports 4bits
# triton can support 2, 4, 8bits while exllama packer only supports 4bits
backend=BACKEND.TRITON if not isinstance(self.quantize_config, AutoRoundQuantizeConfig) and self.quantize_config.format in [FORMAT.GPTQ, FORMAT.GPTQ_V2] and self.quantize_config.bits != 4 else BACKEND.AUTO,
desc_act=self.quantize_config.desc_act,
warmup_triton=autotune_warmup_after_quantized,
force_layer_back_to_cpu=force_layer_back_to_cpu,
format=self.quantize_config.format,
)

if device_map:
self.model = remove_hook_from_module(self.model, recurse=True)
self.model = simple_dispatch_model(self.model, device_map)
Expand Down
4 changes: 2 additions & 2 deletions gptqmodel/utils/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
})

format_dict = {
FORMAT.GPTQ: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA],
FORMAT.GPTQ_V2: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA],
FORMAT.GPTQ: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA, BACKEND.TRITON],
FORMAT.GPTQ_V2: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA, BACKEND.TRITON],
FORMAT.MARLIN: [BACKEND.MARLIN],
FORMAT.BITBLAS: [BACKEND.BITBLAS],
FORMAT.QBITS: [BACKEND.QBITS],
Expand Down
2 changes: 1 addition & 1 deletion gptqmodel/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def make_quant(
bits: int,
group_size: int,
backend: BACKEND,
format: str,
format: str | FORMAT,
desc_act: bool = False,
sym: bool = True,
pack: bool = False,
Expand Down
18 changes: 10 additions & 8 deletions tests/test_perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,24 +100,26 @@ def calculate_native_ppl(self, format):

@parameterized.expand(
[
(QUANT_METHOD.GPTQ, FORMAT.GPTQ_V2),
(QUANT_METHOD.GPTQ, FORMAT.GPTQ),
(QUANT_METHOD.GPTQ, FORMAT.MARLIN),
(QUANT_METHOD.GPTQ, FORMAT.BITBLAS),
(QUANT_METHOD.AUTO_ROUND, FORMAT.GPTQ),
(QUANT_METHOD.GPTQ, FORMAT.GPTQ, 8),
(QUANT_METHOD.GPTQ, FORMAT.GPTQ_V2, 8),
(QUANT_METHOD.GPTQ, FORMAT.GPTQ_V2, 4),
(QUANT_METHOD.GPTQ, FORMAT.GPTQ, 4),
(QUANT_METHOD.GPTQ, FORMAT.MARLIN, 4),
(QUANT_METHOD.GPTQ, FORMAT.BITBLAS, 4),
(QUANT_METHOD.AUTO_ROUND, FORMAT.GPTQ, 4),
]
)
def test_quantized_perplexity(self, method: QUANT_METHOD, format: FORMAT):
def test_quantized_perplexity(self, method: QUANT_METHOD, format: FORMAT, bits: int):
if method == QUANT_METHOD.GPTQ:
quantize_config = QuantizeConfig(
bits=4,
bits=bits,
group_size=128,
format=format,
desc_act=False if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else True
)
elif method == QUANT_METHOD.AUTO_ROUND:
quantize_config = AutoRoundQuantizeConfig(
bits=4,
bits=bits,
group_size=128,
format=format,
)
Expand Down
21 changes: 11 additions & 10 deletions tests/test_quant_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,26 @@ def setUpClass(self):

@parameterized.expand(
[
(QUANT_METHOD.GPTQ, BACKEND.QBITS, False, FORMAT.GPTQ),
(QUANT_METHOD.GPTQ, BACKEND.EXLLAMA_V2, True, FORMAT.GPTQ_V2),
(QUANT_METHOD.GPTQ, BACKEND.EXLLAMA_V2, False, FORMAT.GPTQ),
(QUANT_METHOD.GPTQ, BACKEND.MARLIN, True, FORMAT.MARLIN),
(QUANT_METHOD.AUTO_ROUND, BACKEND.EXLLAMA_V2, True, FORMAT.GPTQ),
(QUANT_METHOD.GPTQ, BACKEND.AUTO, False, FORMAT.GPTQ, 8),
(QUANT_METHOD.GPTQ, BACKEND.QBITS, False, FORMAT.GPTQ, 4),
(QUANT_METHOD.GPTQ, BACKEND.EXLLAMA_V2, True, FORMAT.GPTQ_V2, 4),
(QUANT_METHOD.GPTQ, BACKEND.EXLLAMA_V2, False, FORMAT.GPTQ, 4),
(QUANT_METHOD.GPTQ, BACKEND.MARLIN, True, FORMAT.MARLIN, 4),
(QUANT_METHOD.AUTO_ROUND, BACKEND.EXLLAMA_V2, True, FORMAT.GPTQ, 4),
]
)
def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, format: FORMAT):
def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, format: FORMAT, bits: int):
if method == QUANT_METHOD.GPTQ:
quantize_config = QuantizeConfig(
bits=4,
bits=bits,
group_size=128,
desc_act=False if format == FORMAT.MARLIN else True,
sym=sym,
format=format,
)
elif method == QUANT_METHOD.AUTO_ROUND:
quantize_config = AutoRoundQuantizeConfig(
bits=4,
bits=bits,
group_size=128,
sym=sym,
format=format,
Expand Down Expand Up @@ -99,7 +100,7 @@ def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, forma

# test compat: 1) with simple dict type 2) is_marlin_format
compat_quantize_config = {
"bits": 4,
"bits": bits,
"group_size": 128,
"sym": sym,
"desc_act": False if format == FORMAT.MARLIN else True,
Expand All @@ -120,7 +121,7 @@ def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, forma
os.remove(f"{tmpdirname}/{QUANT_CONFIG_FILENAME}")

compat_quantize_config = {
"bits": 4,
"bits": bits,
"group_size": 128,
"sym": sym,
"desc_act": False if format == FORMAT.MARLIN else True,
Expand Down

0 comments on commit 2aee89d

Please sign in to comment.