
add separate quantization primitives for float8 #1597

Merged

12 commits merged into main on Jan 25, 2025

Conversation

@danielvegamyhre (Contributor) commented Jan 22, 2025

Context

Currently, AQT has the method from_hp_to_floatx for float8 quantization, and from_hp_to_fpx for low-precision floating point data types like fp6 (technically it can support fp1 through fp7).

from_hp_to_floatx re-uses from_hp_to_intx, which in turn uses the generic affine quantization primitives.

Overall, the float8 path is currently confusing for developers, due both to the naming ("floatx") and to the use of generic functions that take many parameters unrelated to float8 quantization.

Summary of changes

The goal of this PR stack is to refactor this into a clean separation of concerns, with a simpler internal API surface for float8 inference quantization.

Specifically:

  • Separate quantization primitives for float8 <------------------- (this PR)
  • Integrate those new quant primitives into AQT

Note: I will add float8 static quantization in a separate set of PRs.
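For orientation, here is a rough sketch of what dedicated tensorwise float8 primitives could look like. The names other than choose_qparams_affine_float8 (which appears in the diff excerpt below), the exact signatures, the default e4m3 dtype, and the eps guard are illustrative assumptions, not necessarily the code this PR adds:

import torch

def choose_qparams_affine_float8(
    tensor: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn
) -> torch.Tensor:
    # Tensorwise symmetric scale: map the absolute maximum of the input onto
    # the largest representable float8 value (the clamp guards against division by zero).
    amax = torch.clamp(tensor.abs().max(), min=1e-12)
    return (amax / torch.finfo(float8_dtype).max).to(torch.float32)

def quantize_affine_float8(
    tensor: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn
) -> torch.Tensor:
    # Divide by the scale, clamp to the representable range, then cast down to float8.
    finfo = torch.finfo(float8_dtype)
    return (tensor.float() / scale).clamp(finfo.min, finfo.max).to(float8_dtype)

def dequantize_affine_float8(
    tensor: torch.Tensor, scale: torch.Tensor, output_dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    # Multiply by the scale to recover an approximation of the original values.
    return (tensor.float() * scale).to(output_dtype)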

@danielvegamyhre (Contributor Author) commented Jan 22, 2025

@pytorch-bot (bot) commented Jan 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1597

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2838d50 with merge base 860da26:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Jan 22, 2025
@danielvegamyhre added the topic: improvement and quantize labels and removed the quantize label Jan 22, 2025
@danielvegamyhre requested review from jainapurva and jerryzh168 and removed the request for jainapurva January 22, 2025 19:36
@jainapurva requested review from vkuzo and drisspg January 23, 2025 22:57
@@ -1300,3 +1303,67 @@ def dequantize_affine_floatx(
tensor = tensor * scale.float().view(-1, 1)
tensor = tensor.to(dtype=output_dtype)
return tensor


def choose_qparams_affine_float8(
Contributor

@drisspg Do these look good?

@danielvegamyhre (Contributor Author) commented Jan 24, 2025
Alternatively, I could make these functions wrappers around the generic primitives that pass in the appropriate params for float8, as that may be more maintainable, although it adds a layer of indirection and hides how the scale is actually computed. Any thoughts?

For example:

import torch
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

# Assumed to match the eps used for scales in float8 training (1e-12 in torchao.float8).
float8_eps = 1e-12

def choose_qparams_float8(input: torch.Tensor, float8_dtype: torch.dtype):
    scale, _ = choose_qparams_affine(
        input,
        MappingType.SYMMETRIC,
        input.shape,     # only tensorwise scaling is supported at the moment
        float8_dtype,
        eps=float8_eps,  # use same EPS as float8 training
        scale_dtype=torch.float32,
        quant_min=torch.finfo(float8_dtype).min,
        quant_max=torch.finfo(float8_dtype).max,
    )
    return scale
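As a usage sketch built on that wrapper (the input tensor and dtype here are only for illustration): with the returned scale, the quantize step divides by the scale, clamps, and casts, and dequantize multiplies back, matching the convention of the generic primitives.

x = torch.randn(64, 64)
scale = choose_qparams_float8(x, torch.float8_e4m3fn)
finfo = torch.finfo(torch.float8_e4m3fn)
x_fp8 = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)  # quantize
x_hp = x_fp8.to(torch.float32) * scale                                   # dequantize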

@danielvegamyhre merged commit 47f96f1 into main Jan 25, 2025
41 checks passed
Labels: CLA Signed, quantize, topic: improvement
4 participants