INT4 XPU enabling #1577

Draft · wants to merge 1 commit into main
15 changes: 14 additions & 1 deletion test/dtypes/test_affine_quantized.py
@@ -8,7 +8,7 @@
run_tests,
)

-from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
+from torchao.dtypes import CutlassInt4PackedLayout, Int4XPULayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
@@ -20,6 +20,7 @@
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
TORCH_VERSION_AT_LEAST_2_7,
is_sm_at_least_89,
)

@@ -46,6 +47,18 @@ def get_quantization_functions(
zero_point_domain=ZeroPointDomain.INT,
)
)
elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_6:
Contributor: 2_7 or 2_6?

Contributor (author): This depends on pytorch/pytorch#137566. It's really just a draft, so it's not ready for review yet. I will ping you when it's ready :)

base_functions.append(
int4_weight_only(group_size=32, layout=Int4XPULayout())
)
if int4_zp_int:
base_functions.append(
int4_weight_only(
group_size=32,
layout=Int4XPULayout(),
zero_point_domain=ZeroPointDomain.INT,
)
)
else:
base_functions.append(int4_weight_only(group_size=32))
if device == "cuda":
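For context, this is the configuration the new "xpu" branch exercises. A minimal usage sketch, assuming an XPU-enabled PyTorch build (2.6 or 2.7, pending the version question above) and the Int4XPULayout added by this PR; the toy model and shapes are illustrative only:

import torch

from torchao.dtypes import Int4XPULayout
from torchao.quantization import int4_weight_only, quantize_

# A toy bfloat16 model placed on the XPU device.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, dtype=torch.bfloat16)).to("xpu")

# Weight-only INT4 quantization with group size 32 and the XPU packing
# layout, mirroring the arguments added to get_quantization_functions.
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout()))

x = torch.randn(1, 256, dtype=torch.bfloat16, device="xpu")
y = model(x)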
9 changes: 8 additions & 1 deletion test/integration/test_integration.py
@@ -18,7 +18,7 @@
from torch._inductor.utils import run_and_get_code

import torchao
-from torchao.dtypes import Int4CPULayout, TensorCoreTiledLayout
+from torchao.dtypes import Int4CPULayout, Int4XPULayout, TensorCoreTiledLayout
from torchao.dtypes.utils import is_device
from torchao.quantization import safe_int_mm
from torchao.quantization.autoquant import (
@@ -139,6 +139,11 @@ def _int4wo_api(mod):
mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False
)
unwrap_tensor_subclass(mod)
elif is_device(next(mod.parameters()).device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7:
quantize_(
mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False
)
unwrap_tensor_subclass(mod)
elif TORCH_VERSION_AT_LEAST_2_4:
quantize_(mod, int4_weight_only(), set_inductor_config=False)
if not TORCH_VERSION_AT_LEAST_2_5:
@@ -1079,6 +1084,8 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
layout_list = []
if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
layout_list.append(Int4CPULayout())
elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_6:
Contributor: here as well, 2_6 or 2_7?

layout_list.append(Int4XPULayout())
else:
for inner_k_tiles in [4, 2]:
layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
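Restated outside the test for clarity, the per-device selection above amounts to the following helper. This is an illustrative sketch, not code from the PR; the function name is made up, and only the branch logic comes from the diff:

from torchao.dtypes import Int4CPULayout, Int4XPULayout, TensorCoreTiledLayout
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6


def pick_int4_layouts(device: str):
    # CPU gets its dedicated layout on torch >= 2.6.
    if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
        return [Int4CPULayout()]
    # XPU likewise; the review question above asks whether this should
    # require torch >= 2.7 instead.
    if device == "xpu" and TORCH_VERSION_AT_LEAST_2_6:
        return [Int4XPULayout()]
    # Default (CUDA): tensor-core tiled packing at two tile widths.
    return [TensorCoreTiledLayout(inner_k_tiles=k) for k in (4, 2)]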
9 changes: 6 additions & 3 deletions test/quantization/test_quant_primitives.py
@@ -33,6 +33,7 @@
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
TORCH_VERSION_AT_LEAST_2_7,
is_fbcode,
)

@@ -130,7 +131,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
)

if TORCH_VERSION_AT_LEAST_2_5:
-if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6) and not (
+    is_device(w.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7
+):
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)

return w_int4x8
@@ -739,8 +741,9 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
if TORCH_VERSION_AT_LEAST_2_5:
input_tmp = input
-if not (
-    is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6
+if not (
+    is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6
+) and not (
+    is_device(input.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7
):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
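Both hunks gate the same operation: packing two int4 values into one uint8, which the tests skip on CPU (torch >= 2.6) and, with this PR, on XPU (torch >= 2.7). A self-contained illustration of the packing expression used in the tests, with a round-trip check (shapes are arbitrary):

import torch

# Simulated int4 values stored one-per-int32, as the test fixtures do.
w_int4x8 = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# Same expression as in the hunks: even columns become the high nibble,
# odd columns the low nibble, two values per byte.
packed = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)

# Round trip: split the nibbles apart and interleave them back.
high = (packed >> 4).to(torch.int32)
low = (packed & 0xF).to(torch.int32)
assert torch.equal(
    torch.stack([high, low], dim=-1).reshape(w_int4x8.shape), w_int4x8
)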
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -16,6 +16,7 @@
BlockSparseLayout,
CutlassInt4PackedLayout,
Int4CPULayout,
Int4XPULayout,
MarlinQQQLayout,
MarlinQQQTensor,
MarlinSparseLayout,
@@ -52,4 +53,5 @@
"MarlinQQQLayout",
"Int4CPULayout",
"CutlassInt4PackedLayout",
"Int4XPULayout",
]
4 changes: 4 additions & 0 deletions torchao/dtypes/uintx/__init__.py
@@ -7,6 +7,9 @@
from .int4_cpu_layout import (
Int4CPULayout,
)
from .int4_xpu_layout import (
Int4XPULayout,
)
from .marlin_qqq_tensor import (
MarlinQQQLayout,
MarlinQQQTensor,
@@ -36,4 +39,5 @@
"MarlinQQQTensor",
"to_marlinqqq_quantized_intx",
"CutlassInt4PackedLayout",
"Int4XPULayout"
]
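With both __init__.py changes applied, the new layout should be importable from the public torchao.dtypes namespace as well as from the uintx subpackage. A quick sanity check, assuming the PR is applied:

from torchao.dtypes import Int4XPULayout as Int4XPULayoutPublic
from torchao.dtypes.uintx import Int4XPULayout as Int4XPULayoutUintx

# The public name is re-exported from the uintx subpackage, so the two
# import paths should resolve to the same class object.
assert Int4XPULayoutPublic is Int4XPULayoutUintx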