INT4 XPU enabling #1577
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1577
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```
@@ -46,6 +47,18 @@ def get_quantization_functions(
                zero_point_domain=ZeroPointDomain.INT,
            )
        )
    elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_6:
```
Should this gate on TORCH_VERSION_AT_LEAST_2_7 or 2_6?
This depends on pytorch/pytorch#137566. It's really just a draft, so not yet ready for review. I will ping you when it's ready :)
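For context on the flags being debated, torchao gates these branches on boolean version constants. A simplified sketch of how such flags can be defined (the real helper in torchao/utils.py also handles nightly/dev version strings and internal builds, so treat this as an approximation, not the upstream code):

```python
# Simplified sketch of torchao-style version flags; not the exact upstream code.
import torch
from packaging.version import parse

def torch_version_at_least(min_version: str) -> bool:
    # Compare only the release tuple, so "2.7.0.dev..." nightlies count as 2.7.
    return parse(torch.__version__).release >= parse(min_version).release

TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0")
```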
```
@@ -1079,6 +1084,8 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
        layout_list = []
        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
            layout_list.append(Int4CPULayout())
        elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_6:
```
Here as well: 2_6 or 2_7?
```python
    __torch_function__ = torch._C._disabled_torch_function_impl

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
```
btw for this one, we have some unpacking ops for the tensor core tiled layout that we should really be using:

ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu (lines 311 to 312 in cf45336):

```
m.impl("torchao::unpack_tensor_core_tiled_layout", &_unpack_tensor_core_tiled_layout);
m.impl("torchao::dequantize_tensor_core_tiled_layout", &_dequantize_tensor_core_tiled_layout);
```

It might be better to do the same here instead of hacking around with the quantize ops.
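For reference, a minimal sketch of what an unpack-based get_plain could look like, built on the registered op above. The attribute names (packed_weight, scale_and_zero, inner_k_tiles) and the scale/zero splitting are assumptions modeled on the tensor-core-tiled layout, not a final implementation:

```python
# Sketch only: rebuild plain (int_data, scale, zero_point) via the dedicated
# unpack op instead of round-tripping through quantize/dequantize ops.
# packed_weight / scale_and_zero / inner_k_tiles are assumed attribute names.
from typing import Tuple
import torch

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Undo the tensor-core tiling of the int4 payload.
    int_data = torch.ops.torchao.unpack_tensor_core_tiled_layout(
        self.packed_weight, self.inner_k_tiles
    )
    # scale and zero_point are assumed stored interleaved along the last dim
    # of scale_and_zero (tinygemm-style packing); split them back apart.
    scale, zero_point = self.scale_and_zero.unbind(-1)
    return int_data, scale, zero_point
```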
Sure, I will check.
By the way, why is the op added in pytorch/pytorch#137566 instead of in torchao? Any plans to move it to torchao?
@mingfeima @EikanWang can you comment?
The situation for XPU (the Intel GPUs) is different from CPU and CUDA here. I'm not sure whether providing SYCL or oneDNN XPU ops in ao is a feasible solution.
The PR is currently a draft.
It will add two kinds of INT4 support on XPU, following the discussion in #1264: integer zero points and floating-point zero points.
Integer zero points, which are natively supported via oneDNN, are planned to be merged into the PyTorch main repo in pytorch/pytorch#137566.
Floating-point zero points are the default behaviour in this repo; the initial work has been done in the XPU operators (intel/torch-xpu-ops#1130), with more implementations on the way.
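If it helps reviewers, end-user usage would presumably mirror the existing CPU path once this lands. A hedged sketch (Int4XPULayout and the zero_point_domain wiring are what this PR and its dependencies add, so exact names may still change before merge):

```python
# Sketch of INT4 weight-only quantization on XPU; Int4XPULayout is the layout
# class assumed to be introduced by this PR, mirroring Int4CPULayout.
import torch
from torchao.quantization import int4_weight_only, quantize_
from torchao.quantization.quant_primitives import ZeroPointDomain
from torchao.dtypes import Int4XPULayout  # name assumed from this PR

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("xpu", torch.bfloat16)

# Integer zero points: the oneDNN-backed variant (pytorch/pytorch#137566).
# ZeroPointDomain.FLOAT would instead select the floating-zero-point variant
# backed by intel/torch-xpu-ops#1130.
quantize_(
    model,
    int4_weight_only(
        group_size=128,
        layout=Int4XPULayout(),
        zero_point_domain=ZeroPointDomain.INT,
    ),
)
```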