Optimize FBGEMM Triton MX4 Dequantize #2837

Closed
wants to merge 3 commits

Commits on Jul 12, 2024

  1. Use better exponent rounding in Triton MX4 quantize kernel (pytorch#2816)
    
    Summary:
    X-link: facebookresearch/FBGEMM#20
    
    Pull Request resolved: pytorch#2816
    
    As noted in [this doc](https://docs.google.com/document/d/156Du0hBRH6umG_i-OrYC574XhpQMUU5SJYG0RTS2tTg/edit#heading=h.akfcp7xpg8cr), using ceiling rounding for the scale calculation does a better job of not truncating mantissa bits. This diff switches Triton's floor rounding to ceiling rounding.
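    
    As a rough illustration of the rounding difference (a sketch, not the kernel code; `FP4_EXP_BIAS` and the helper name are made up here), deriving the shared power-of-two scale exponent for one group might look like:
    
    ```
    import math
    
    FP4_EXP_BIAS = 2  # exponent of the largest e2m1 (fp4) value: 6.0 = 1.5 * 2^2
    
    def group_scale_exponent(group_abs_max: float, rounding: str = "ceil") -> int:
        """Pick the power-of-two scale exponent for one MX4 group (illustrative)."""
        raw = math.log2(group_abs_max)
        # floor can leave the group's largest element above the fp4 max after
        # scaling, so it gets clamped and mantissa bits are truncated; ceil
        # keeps the largest element representable.
        exp = math.ceil(raw) if rounding == "ceil" else math.floor(raw)
        return exp - FP4_EXP_BIAS
    ```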
    
    Note that mx4_test currently doesn't pass, as the CUDA kernel now behaves differently than Triton. Once we rebase this diff onto a similar change to the CUDA kernel, we should see exactly matching outputs again.
    
    Differential Revision: D59527463
    
    Reviewed By: jianyuh
    Josh Fromm authored and facebook-github-bot committed Jul 12, 2024
    Commit: 49d8bc6
  2. Refactor MX4 Kernel to operate on flat tensors (pytorch#2836)

    Summary:
    Pull Request resolved: pytorch#2836
    
    Rather than trying to reshape inputs into 2D matrices with each thread operating on one row, this refactor uses 1D inputs and has each thread operate on an offset into the array.
    
    The main benefit is that it avoids ragged tensors when an input can't be divided into evenly sized rows. This should make us compatible with more shapes.
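    
    A minimal sketch of the 1D indexing scheme (illustrative, not the FBGEMM kernel; the kernel name and its plain copy body are stand-ins): each program instance covers a contiguous chunk of a flat buffer, and a mask guards the ragged tail instead of requiring evenly divisible rows:
    
    ```
    import triton
    import triton.language as tl
    
    @triton.jit
    def flat_kernel(in_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        # each program instance covers one contiguous BLOCK-sized offset range
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        # masking the tail removes the need for evenly sized 2D rows
        mask = offsets < n_elements
        x = tl.load(in_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x, mask=mask)
    ```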
    
    Differential Revision: D59653809
    
    Reviewed By: sryap
    Josh Fromm authored and facebook-github-bot committed Jul 12, 2024
    Commit: d1f21e1
  3. Optimize FBGEMM Triton MX4 Dequantize (pytorch#2837)

    Summary:
    Pull Request resolved: pytorch#2837
    
    We previously had to use Python to unravel values from exponents and feed them to Triton as two separate tensors. This introduced a lot of overhead, as it required large copies.
    
    This diff uses fancy indexing to operate directly on a tensor with mixed elements and exponents. The result is that Triton dequantize is now slightly faster than the CUDA kernel. My hope is that this allows us to standardize on a single implementation.
    
    I think we could probably do something similar during quantize to get a significant speedup as well.
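    
    To give a feel for the index arithmetic (written in eager PyTorch for clarity; in the diff the equivalent indexing happens inside the Triton kernel, so no intermediate copies are made, and the 16-element-bytes-plus-1-exponent-byte group layout is an assumption here):
    
    ```
    import torch
    
    def locate_group_bytes(packed: torch.Tensor, group_size: int = 32):
        """Split a mixed MX4 byte buffer into element and exponent bytes."""
        elem_bytes = group_size // 2  # two fp4 values packed per byte
        stride = elem_bytes + 1       # plus one shared-exponent byte per group
        n_groups = packed.numel() // stride
        base = torch.arange(n_groups, device=packed.device) * stride
        offs = torch.arange(elem_bytes, device=packed.device)
        elems = packed[(base[:, None] + offs).reshape(-1)]  # element bytes
        exps = packed[base + elem_bytes]  # exponent byte trails each group
        return elems, exps
    ```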
    
    ```
    INFO:root:input size: 1073741824 group size: 32
    INFO:root:Start to benchmark ...
    INFO:root:Start to benchmark ...
    input_size=1073741824 MX4 quantized time per iter: 7563us
    input_size=1073741824 MX4 dequantized time per iter: 2756us
    INFO:root:Start to benchmark ...
    INFO:root:Start to benchmark ...
    input_size=1073741824 MX4 triton quantized time per iter: 5110us
    input_size=1073741824 MX4 triton dequantized time per iter: 2417us
    INFO:root:Start to benchmark ...
    INFO:root:Start to benchmark ...
    input_size=1073741824 FP8 quantized time per iter: 6274us
    input_size=1073741824 FP8 dequantized time per iter: 4223us
    ```
    
    Reviewed By: sryap
    
    Differential Revision: D59661776
    jwfromm authored and facebook-github-bot committed Jul 12, 2024
    Commit: c680732