
Tianxing/fa int8 #669

Open · wants to merge 27 commits into main_perf

Conversation

@Chi-Chu319 commented Nov 29, 2024

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

  • Select one of the following.

    • I have added tests.
      • `/test` for lit tests
      • `/unittest` for C++ tests
      • `/python/test` for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

Delayed the SM and PK_Scale multiplications to obtain higher fp precision
and to avoid overshooting the int8 range during quantization. By default,
p scaling is not enabled; providing a valid p_scale and p_descale turns it
on. In practice, p_scale is trained by the model.
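
As a rough illustration of those numerics, here is a minimal single-head NumPy sketch (not the actual Triton kernel; the function name and the scalar per-head descale arguments are assumptions for illustration). The point is that `sm_scale` and the q/k descales are applied in one delayed multiply after the int32 `QK^T` accumulate, and `p_scale`/`p_descale` only come into play when provided:

```python
import numpy as np

def int8_attention_reference(q_i8, k_i8, v_i8, q_descale, k_descale, v_descale,
                             sm_scale, p_scale=None, p_descale=None):
    # QK^T accumulates in int32; sm_scale and the q/k descales are folded
    # into a single delayed multiply, so the matmul itself stays in int8/int32.
    qk = q_i8.astype(np.int32) @ k_i8.astype(np.int32).T
    scores = qk.astype(np.float32) * (q_descale * k_descale * sm_scale)

    # Numerically stable softmax.
    scores -= scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)

    if p_scale is not None and p_descale is not None:
        # p scaling enabled: quantize P so the PV matmul also runs in int8;
        # p_descale undoes the scaling after the int32 accumulate.
        p_i8 = np.clip(np.rint(p * p_scale), -128, 127).astype(np.int8)
        o = p_i8.astype(np.int32) @ v_i8.astype(np.int32)
        return o.astype(np.float32) * (p_descale * v_descale)

    # p scaling disabled (the default): dequantize V and run PV in fp32.
    return p @ (v_i8.astype(np.float32) * v_descale)
```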

Added an int8 test.

Benchmark result

| Index | BATCH | HQ | HK | N_CTX_Q | N_CTX_K | Normal TFLOPS | Int8 TFLOPS | Performance Boost (%) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 16 | 16.0 | 16.0 | 1024 | 1024 | 294.049318 | 375.476279 | 27.68% |
| 1 | 8 | 16.0 | 16.0 | 2048 | 2048 | 360.917489 | 405.979504 | 12.49% |
| 2 | 4 | 16.0 | 16.0 | 4096 | 4096 | 379.645407 | 415.134727 | 9.35% |
| 3 | 2 | 16.0 | 16.0 | 8192 | 8192 | 390.533443 | 421.385511 | 7.90% |
| 4 | 1 | 16.0 | 16.0 | 16384 | 16384 | 380.201436 | 413.302398 | 8.70% |
| 5 | 2 | 48.0 | 48.0 | 1024 | 1024 | 243.064325 | 263.150690 | 8.27% |
| 6 | 2 | 48.0 | 48.0 | 2048 | 1024 | 307.502534 | 340.843791 | 10.84% |
| 7 | 2 | 48.0 | 48.0 | 4096 | 8192 | 363.350867 | 398.494640 | 9.67% |
| 8 | 2 | 48.0 | 48.0 | 8192 | 4096 | 379.879379 | 423.701316 | 11.54% |
| 9 | 2 | 48.0 | 48.0 | 16384 | 8192 | 389.565260 | 433.401303 | 11.26% |
| 10 | 8 | 16.0 | 16.0 | 1989 | 15344 | 276.943421 | 385.945650 | 39.37% |
| 11 | 4 | 16.0 | 16.0 | 4097 | 163 | 144.495573 | 105.871690 | -26.74% |
| 12 | 2 | 16.0 | 16.0 | 8122 | 2159 | 270.456464 | 404.755389 | 49.65% |
| 13 | 1 | 16.0 | 16.0 | 16281 | 7 | 7.912893 | 4.540140 | -42.62% |
| 14 | 2 | 48.0 | 48.0 | 1021 | 1020 | 219.030497 | 263.565263 | 20.33% |
| 15 | 2 | 48.0 | 48.0 | 2001 | 2048 | 326.161712 | 372.730572 | 14.28% |
| 16 | 2 | 48.0 | 48.0 | 3996 | 9639 | 272.999622 | 382.946864 | 40.29% |
| 17 | 2 | 48.0 | 48.0 | 8181 | 1021 | 266.868416 | 398.298176 | 49.27% |

The quantization is per channel, i.e. one scale per head (see the sketch below).
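
For clarity, a small NumPy sketch of what per-head symmetric int8 quantization looks like (`quantize_per_head` is a hypothetical helper, not a function from this PR; a tensor layout of `(seqlen, num_heads, head_dim)` is assumed):

```python
import numpy as np

def quantize_per_head(x):
    # Hypothetical helper: symmetric per-channel (per-head) int8 quantization
    # for a tensor of shape (seqlen, num_heads, head_dim).
    absmax = np.abs(x).max(axis=(0, 2))        # one absmax per head
    scale = 127.0 / np.maximum(absmax, 1e-12)  # guard against all-zero heads
    x_i8 = np.clip(np.rint(x * scale[None, :, None]), -128, 127).astype(np.int8)
    descale = 1.0 / scale                      # multiply back to dequantize
    return x_i8, descale
```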
