[bc-breaking] enable direct configuration in quantize_ #1595

Open · wants to merge 7 commits into base: main
23 changes: 19 additions & 4 deletions test/dtypes/test_affine_quantized.py
@@ -8,13 +8,15 @@
run_tests,
)

from torchao.core.config import AOBaseConfig
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.utils import (
@@ -80,7 +82,8 @@ def test_tensor_core_layout_transpose(self):
t = linear.weight
shape = t.shape
apply_int4_weight_only_quant = int4_weight_only(group_size=32)
ql = apply_int4_weight_only_quant(linear)
quantize_(linear, apply_int4_weight_only_quant)
ql = linear
aqt = ql.weight
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)
@@ -100,7 +103,11 @@
)
def test_weights_only(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
ql = apply_quant(linear)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
@@ -178,8 +185,13 @@ def apply_uint6_weight_only_quant(linear):
)
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_print_quantized_module(self, apply_quant):
print(apply_quant)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
ql = apply_quant(linear)
assert "AffineQuantizedTensor" in str(ql)


@@ -193,7 +205,10 @@ def test_flatten_unflatten(self, device, dtype):
apply_quant_list = get_quantization_functions(False, True, device)
for apply_quant in apply_quant_list:
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
ql = apply_quant(linear)
lp_tensor = ql.weight
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {
7 changes: 4 additions & 3 deletions test/hqq/test_hqq_affine.py
@@ -6,6 +6,7 @@
MappingType,
ZeroPointDomain,
int4_weight_only,
quantize_,
uintx_weight_only,
)
from torchao.utils import (
@@ -51,9 +52,9 @@ def _eval_hqq(dtype):
)
dummy_linear.weight.data = W
if dtype == torch.uint4:
q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(
dummy_linear
).weight
config = int4_weight_only(group_size=max(block_size), use_hqq=True)
quantize_(dummy_linear, config)
q_tensor_hqq = dummy_linear.weight
else:
q_tensor_hqq = uintx_weight_only(
dtype, group_size=max(block_size), use_hqq=True
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_quantize_api(self):
def test_quantize_api_standalone(self):
"""
Test that the following:
25 changes: 25 additions & 0 deletions test/quantization/test_quant_api.py
@@ -40,6 +40,7 @@
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
@@ -761,6 +762,30 @@ def reset_memory():
assert param.is_cuda
self.assertLess(memory_streaming, memory_baseline)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_int4_weight_only_numerics(self):
"""
Simple test of e2e int4_weight_only workflow, comparing numerics
to a bfloat16 baseline.
"""
# set up inputs
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
# TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
# is that expected?
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
m_int4_wo = copy.deepcopy(m_ref)

# quantize
quantize_(m_int4_wo, int4_weight_only())

with torch.no_grad():
y_ref = m_ref(x)
y_int4_wo = m_int4_wo(x)

sqnr = compute_error(y_ref, y_int4_wo)
assert sqnr >= 20, f"SQNR {sqnr} is too low"


class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
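One note on the new `test_int4_weight_only_numerics` test above: `compute_error` reports the signal-to-quantization-noise ratio in decibels, so `assert sqnr >= 20` requires the int4 output to stay within roughly 10% relative L2 error of the bfloat16 baseline. Below is a minimal sketch of the assumed definition; the canonical implementation is `torchao.quantization.utils.compute_error`, which this diff only imports.

```python
import torch


def sqnr_db(y_ref: torch.Tensor, y_quant: torch.Tensor) -> torch.Tensor:
    # 20 * log10(signal_norm / error_norm); 20 dB means the error norm is
    # one tenth of the reference norm (~10% relative L2 error)
    return 20 * torch.log10(
        torch.linalg.norm(y_ref) / torch.linalg.norm(y_ref - y_quant)
    )
```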
Empty file added torchao/core/__init__.py
10 changes: 10 additions & 0 deletions torchao/core/config.py
@@ -0,0 +1,10 @@
import abc


class AOBaseConfig(abc.ABC):
"""
If a workflow config inherits from this then `quantize_` knows
how to apply it to a model.
"""

pass
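To make the intent of this new base class concrete for reviewers: a workflow defines a config that subclasses `AOBaseConfig`, registers a module transform for it via `register_quantize_module_handler` (exported from `torchao.quantization` later in this diff), and users then pass the config instance directly to `quantize_`. The sketch below illustrates the pattern; `MyWorkflowConfig` and `_my_workflow_transform` are hypothetical names used only for illustration and are not part of this PR.

```python
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import register_quantize_module_handler


@dataclass
class MyWorkflowConfig(AOBaseConfig):
    # workflow-specific settings live directly on the config object
    group_size: int = 32


@register_quantize_module_handler(MyWorkflowConfig)
def _my_workflow_transform(
    module: torch.nn.Module, config: MyWorkflowConfig
) -> torch.nn.Module:
    # quantize_ calls this handler for every module matched by its filter
    # when it is given a MyWorkflowConfig instance; return the (possibly
    # swapped) module here
    return module


model = torch.nn.Sequential(torch.nn.Linear(128, 128))
quantize_(model, MyWorkflowConfig(group_size=64))
```

This is the same shape as the `IntXQuantizationAwareTrainingConfig` / `_intx_quantization_aware_training_transform` pair added to `torchao/quantization/qat/api.py` below.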
4 changes: 4 additions & 0 deletions torchao/quantization/__init__.py
@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from torchao.kernel import (
int_scaled_matmul,
safe_int_mm,
@@ -84,6 +85,7 @@
swap_linear_with_smooth_fq_linear,
)
from .subclass import * # noqa: F403
from .transform_module import register_quantize_module_handler
from .unified import Quantizer, TwoStepQuantizer
from .utils import (
compute_error,
@@ -142,6 +144,8 @@
# operators/kernels
"safe_int_mm",
"int_scaled_matmul",
# registration of module transforms for quantize_
"register_quantize_module_handler",
# dataclasses and types
"MappingType",
"ZeroPointDomain",
114 changes: 68 additions & 46 deletions torchao/quantization/qat/api.py
@@ -5,10 +5,11 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union
from typing import Any, List, Optional, Union

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.granularity import (
Granularity,
PerAxis,
@@ -22,6 +23,9 @@
TorchAODType,
ZeroPointDomain,
)
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.quantization.unified import TwoStepQuantizer


@@ -239,12 +243,26 @@ def __setattr__(self, name: str, value: Any):
super().__setattr__(name, value)


def intx_quantization_aware_training(
activation_config: Optional[FakeQuantizeConfig] = None,
weight_config: Optional[FakeQuantizeConfig] = None,
) -> Callable:
@dataclass
class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
activation_config: Optional[FakeQuantizeConfig] = None
weight_config: Optional[FakeQuantizeConfig] = None


# for BC
intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig


@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
def _intx_quantization_aware_training_transform(
module: torch.nn.Module,
config: IntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
"""
Return a function that applies fake quantization to a `torch.nn.Module`.
THIS IS NOT A PUBLIC API - any usage of this outside of torchao
can break at any time.
Apply fake quantization to a `torch.nn.Module`.
to be used with :func:`~torchao.quantization.quant_api.quantize_`.
Example usage::
@@ -267,37 +285,32 @@ def intx_quantization_aware_training(
`torch.nn.Embedding` with an activation config, then we will raise
ValueError as these are not supported.
"""

def _insert_fake_quantize(mod: torch.nn.Module):
"""
Swap the given module with its corresponding fake quantized version.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

if isinstance(mod, torch.nn.Linear):
return FakeQuantizedLinear.from_linear(
mod,
activation_config,
weight_config,
)
elif isinstance(mod, torch.nn.Embedding):
if activation_config is not None:
raise ValueError(
"Activation fake quantization is not supported for embedding"
)
return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
else:
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

mod = module
activation_config = config.activation_config
weight_config = config.weight_config

if isinstance(mod, torch.nn.Linear):
return FakeQuantizedLinear.from_linear(
mod,
activation_config,
weight_config,
)
elif isinstance(mod, torch.nn.Embedding):
if activation_config is not None:
raise ValueError(
"Module of type '%s' does not have QAT support" % type(mod)
"Activation fake quantization is not supported for embedding"
)
return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
else:
raise ValueError("Module of type '%s' does not have QAT support" % type(mod))

return _insert_fake_quantize


def from_intx_quantization_aware_training() -> Callable:
class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
"""
Return a function that converts a model with fake quantized modules,
Object that knows how to convert a model with fake quantized modules,
such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
back to model with the original, corresponding modules without
@@ -313,22 +326,31 @@ def from_intx_quantization_aware_training() -> Callable:
)
"""

def _remove_fake_quantize(mod: torch.nn.Module):
"""
If the given module is a fake quantized module, return the original
corresponding version of the module without fake quantization.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear
pass


# for BC
from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig

if isinstance(mod, FakeQuantizedLinear):
return mod.to_linear()
elif isinstance(mod, FakeQuantizedEmbedding):
return mod.to_embedding()
else:
return mod

return _remove_fake_quantize
@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
def _from_intx_quantization_aware_training_transform(
mod: torch.nn.Module,
config: FromIntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
"""
If the given module is a fake quantized module, return the original
corresponding version of the module without fake quantization.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

if isinstance(mod, FakeQuantizedLinear):
return mod.to_linear()
elif isinstance(mod, FakeQuantizedEmbedding):
return mod.to_embedding()
else:
return mod


class ComposableQATQuantizer(TwoStepQuantizer):
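For reference, here is roughly how the prepare/convert QAT flow reads with the new config classes from this file, mirroring the existing docstring example. It assumes `FakeQuantizeConfig` keeps its current constructor; the dtype and `group_size` values are illustrative only.

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(128, 128))

# prepare: swap nn.Linear modules for their fake-quantized counterparts
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# ... QAT fine-tuning loop runs here ...

# convert: swap the fake-quantized modules back to the plain versions
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
```

Because the old names are kept as aliases (`intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig` and likewise for the `from_` variant), existing call sites such as `quantize_(model, intx_quantization_aware_training(...))` continue to work.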