[FEA] De-quantization and int4 support for mixed dtypes GEMM #1122
As for scaling, there are many different scaling algorithms, and the scaling metadata can have different formats. Which one should be supported first? As for int4, it is not hard to implement since we have int8 now; the int8 PR essentially only touched one file under include/gemm/warp/.
Hi @alexsamardzic, excited to hear there is further use of CUTLASS in PyTorch. We are working on a mixed input GEMM implementation for Hopper and hope to have scaling at some point (very similar to FT). We don't have plans to add scaling to the Ampere implementation by @manishucsd. We are happy to support you if you are willing to contribute this functionality to CUTLASS!
For scaling, indeed there are a number of variations that may be useful, but being able to multiply, element-wise, each row-vector of the product (i.e. of the accumulator) with a given vector of scaling factors would be a good start. We're at the moment mostly interested in Ampere, so if you guys could give me some hints, I can try to add this type of scaling to CUTLASS. Please note that we're also interested in having this kind of functionality exposed through `cutlass_library`. Any further hints on int4 support are welcome too.
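For concreteness, here is a minimal host-side reference of the scaling semantics described above, i.e. per-column scale factors applied element-wise to each row of the accumulator; the sizes, layouts, and function name are only illustrative:

```cpp
#include <cstdint>
#include <vector>

// Reference de-quantizing GEMM: C[i][j] = (sum_k A[i][k] * B[k][j]) * scale[j],
// with A in floating point (float here for simplicity), B in int8, and one scale per output column.
std::vector<float> gemm_dequant(const std::vector<float>& A,      // M x K, row-major
                                const std::vector<int8_t>& B,     // K x N, column-major
                                const std::vector<float>& scale,  // length N
                                int M, int N, int K) {
  std::vector<float> C(M * N, 0.0f);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[i * K + k] * static_cast<float>(B[j * K + k]);  // B is column-major
      }
      C[i * N + j] = acc * scale[j];  // each row-vector of the product scaled element-wise
    }
  }
  return C;
}
```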
How much do you know about CUTLASS now? After you watch our Ampere GTC talk given by Andrew Kerr several years ago, we can have a 1:1 about scaling and int4.
You can reuse the scaling logic from FT almost entirely. There is a warp_dequantizer you can use that will load data from smem to registers to match the input layout of hmma on Volta / Turing / Ampere. Note that FT scales inside the mainloop (before the mma), as scaling in the epilogue degraded the model output a lot. I think it is because the range of the activations is quite different if we don't scale.
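To make the mainloop-versus-epilogue placement concrete, here is a small plain-C++ sketch of the two orderings (names are illustrative). In exact arithmetic they agree for per-column scales; the quality difference mentioned above presumably shows up only once the un-scaled int8-range products are accumulated in limited precision:

```cpp
#include <cstdint>

// (1) FT-style: de-quantize B inside the mainloop, i.e. before each multiply-accumulate.
float dot_scale_in_mainloop(const float* a, const int8_t* b, float scale, int K) {
  float acc = 0.0f;
  for (int k = 0; k < K; ++k)
    acc += a[k] * (static_cast<float>(b[k]) * scale);  // convert + scale, then accumulate
  return acc;
}

// (2) Epilogue-style: accumulate the raw int8 values, scale the finished accumulator once.
float dot_scale_in_epilogue(const float* a, const int8_t* b, float scale, int K) {
  float acc = 0.0f;
  for (int k = 0; k < K; ++k)
    acc += a[k] * static_cast<float>(b[k]);  // un-scaled products
  return acc * scale;                        // scaling deferred to the epilogue
}
```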
Thanks for the further hints! I do know low-level stuff about tensor operations, memory access optimization, etc., but not much about CUTLASS internals (I have a PR merged, but that was mainly following what @hwu36 told me to do). So let me see if I can understand what the FT guys did, and try to come up with a PR here.
Sounds great. I wrote the code in FT, so feel free to ask any questions about it here. As Haicheng mentioned, we can also have a meeting if it is helpful.
Catching up on the thread here... Thanks @alexsamardzic for your interest and efforts on mixed-input work. Enumerating the requests on this thread:
(a) Scaling support for mixed-input
(c) Support mixed-input with int4
My immediate focus, in the coming weeks, would be to start targeting mixed-input on H100, but I am interested in all of the above. I am happy to help with the design and any questions that you may have on mixed-input work in general, and especially on PR #1084.
Thanks @manishucsd. PyTorch uses `cutlass_library`.
A couple of questions/comments:
Correct.
I don't think the interleaved format will change anything here. We just need to load in a vector. If it is an interleaved B, the scale data is already pre-processed too. @rhenry-nv?
From a performance perspective, fusion in the epilogue is always preferred; fusion in the mainloop hurts performance in most cases. The CUTLASS epilogue broadcast fusion needs to load an additional vector. You can take a look at https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu (@apuaaChen, pay attention to #1120). This example uses the newly introduced EVT. You can also use the old way to do broadcast fusion, as in https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu
The scale data layout is the same regardless of interleaving, so you don't need to do anything special there.
@alexsamardzic, I see that you compare cuBLAS vs. NVIDIA/CUTLASS vs. OpenAI/Triton vs. NVIDIA/FasterTransformer. Can you please confirm the following?
A) You are running the following layouts and data types with all four providers.
B) I see that you are running a bunch of matmul shapes. Do you autotune across various ThreadBlockShape, NumStages, and SplitK for NVIDIA/CUTLASS, OpenAI/Triton, and NVIDIA/FasterTransformer?
C) Can you please share your profiling numbers on the matmul shape of
A) Datatypes and layouts are as you mentioned, except that cuBLAS is running
B) Nothing gets auto-tuned in my benchmarking, aside from the Triton code generated by the Torch compiler. Namely, there are two aspects of supporting mixed-datatypes MM in PyTorch. The first one, which I'm working on at the moment, is for eager mode of execution, i.e. when a Python script is executed line-by-line. The second one would be for the case when the same script is pre-compiled by the Torch compiler. The Torch compiler primarily generates Triton code, but recently a CUTLASS back-end was added, so it's able to generate CUTLASS code too (when it encounters an operation in the Python script that is supported by CUTLASS), utilizing `cutlass_library`.
C) For the given shapes, the first benchmark gives NVIDIA/FasterTransformer a speedup over cuBLAS of about 0.9, while the Triton code generated by the Torch compiler has the same speedup (but that's misleading, as the Torch compiler may decide to insert a call to the available pre-compiled, eager-mode code instead of generating Triton kernel(s), which it does in this case; this is why "Triton performance" is reported the same as NVIDIA/FasterTransformer). The second benchmark gives NVIDIA/CUTLASS a speedup of about 0.75 (the exact time reported on A100 for NVIDIA/CUTLASS is 1.165 ms, but note that here it includes passing arguments from Python down to C++), so it's an example of why I asked about NVIDIA/CUTLASS vs. NVIDIA/FasterTransformer performance. (For reference, I used warp shape 64x64x64 here.)
Thanks for the pointers; I tried the "old" way and it is indeed easy to apply scaling. Am I correct that the scaling matrix has to be in the same layout as the output matrix? Is this the case for EVT too? If so, would the change needed to support, say, applying scaling factors given in column-major layout to the output in row-major layout (i.e. to have the vector of scaling factors row-broadcasted in this case) be similar to what I did with #951? Edit: OK, I think I was able to make this work for the EVT case.
In the meantime, I've updated the mixed-dtypes PyTorch PR with support for de-quantization in an EVT epilogue (as well as support for adding bias and activation functions). If I get the EVT stuff right, one could opt, for example, for an F32 accumulator in the case of F16 inputs, and then keep F32 as the datatype for the epilogue operations, where the inputs to these operations get upcast to F32 if needed and are eventually downcast to F16 for the final result. This way, the precision of de-quantization seems quite satisfactory in my experiments so far; of course, there is some performance penalty for the upcasting/downcasting, but it is still better than doing de-quantization during the MMA. Let me ask again: are EVT epilogues supported for sparse GEMM, and if not, are there any plans to add support?
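A plain-C++ sketch of the per-element epilogue arithmetic described above (the function name and the choice of ReLU are only illustrative, and the actual EVT graph in the PR may differ): all intermediate math stays in F32, and only the final value is converted to F16.

```cpp
#include "cutlass/numeric_types.h"  // cutlass::half_t

// One output element: de-quantize, add bias, apply an activation, all in F32,
// then downcast the final result to F16.
cutlass::half_t epilogue_element(float accum,             // F32 accumulator
                                 float scale,             // de-quantization factor
                                 cutlass::half_t bias) {  // bias stored as F16
  float x = accum * scale;                 // de-quantization in F32
  x += static_cast<float>(bias);           // bias upcast F16 -> F32
  x = x > 0.0f ? x : 0.0f;                 // example activation (ReLU)
  return static_cast<cutlass::half_t>(x);  // final downcast F32 -> F16
}
```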
Conceptually, the epilogues for sparse and dense are the same. You need to do some plumbing to connect sparse GEMM and the EVT epilogue in just the same way as for the dense one.
Also, fusion in the epilogue is always preferred to mainloop fusion for performance's sake.
Thanks! Also worth mentioning: after some profiling and parameter tuning (primarily instruction/warp/threadblock shapes), the performance of the CUTLASS version of the code (mixed-dtypes MM + de-quantization in the epilogue) seems to be about on par with the FT version.
Thanks for this analysis. Can you please share the new results of this profiling (a few rows of CSV)?
The shapes used in both the NVIDIA/FT and NVIDIA/CUTLASS cases are: instruction shape 16x8x16, warp shape 64x64x32, and threadblock shape 128x128x64. The benchmark measures the latency of the linear operator. Here is a screenshot of the benchmark results (it's color coded, so it may be easier to spot the best performers). Once again, please note that passing parameters from Python to C++ is included in the above latency numbers. Also, please ignore the "Triton" column for our purposes; it's actually PyTorch compiler output, which is not always compiled to pure Triton code. I can calculate GFLOPs if needed.
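In case it helps, here is a small sketch of the FLOP accounting usually applied to such comparisons, assuming the conventional 2*M*N*K multiply-adds for the linear operator plus M*N multiplies for de-quantization; the shape and latency below are made up, purely to show the arithmetic:

```cpp
#include <cstdio>

// GFLOP/s for C(MxN) = A(MxK) x B(KxN) followed by per-element de-quantization.
double gflops(double m, double n, double k, double latency_ms) {
  double flops = 2.0 * m * n * k + m * n;        // multiply-adds + scaling
  return flops / (latency_ms * 1.0e-3) / 1.0e9;  // convert ms to s, FLOPs to GFLOPs
}

int main() {
  std::printf("%.1f GFLOP/s\n", gflops(4096, 4096, 4096, 2.0));  // hypothetical numbers
  return 0;
}
```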
I think we might be able to optimize @manishucsd's version in CUTLASS to improve the performance with the canonical layout. The FT version does two transformations of the weights:
We can achieve the benefit from 2) without interleaving the columns. However, we will require a separate mainloop (or a generalization of the current one) in CUTLASS 2. The core idea would be to have a GMEM to SMEM. This adds some complications to the mainloop as well, but I think it should be doable. @manishucsd / @alexsamardzic, is this something you would be interested in collaborating on?
@rhenry-nv, certainly, let us chalk out all that needs to change in CUTLASS 2.0 to support 2). Here is my list (thinking out loud):
Is your CUTLASS 3.3 version along similar lines?
The one in v3.3 doesn't have this optimization, but I am hoping to have the bandwidth to add it in v3.4. I will try to schedule a meeting so we can design it together and use consistent APIs in 2.x and 3.x for specifying the thread block shape. I am also certain there is more, but we can brainstorm when we meet. Does that sound good?
Yup! Sounds good!
Please let me know if you guys have any more specific instructions on what to do/change to:
I appreciate the hints already provided above, and I honestly tried my best in the meantime to figure out how to do that, but to no avail.
I'm working on
You likely need to change the k dimension of the threadblock and warp shapes. You also need to change the alignment of the int4 operand.
I've updated my 4-bit mixed-dtypes branch: alexsamardzic@f88a889 (I'm force-pushing, so this commit contains all of my changes). The problem I mentioned above is still there. I hope this may be the only remaining issue before having something that works, but at the moment I have no idea how to fix it.
Can you paste your device-level template instantiation code?
It's:
You need to change all k dimensions of the different shapes to be 128 instead of 64 and then make it work. See this int4 unit test:
That doesn't help; the problem is that the k-dimension of
Don't change the InstructionShape; it is still 16x8x16. The warp shape k is 128.
For the fp16 one, you may still need to use 64, but advance the warp iterator one more time and do the warp load one more time to get the next 64 elements in the k dimension.
I'm sorry if I wasn't clear enough in trying to explain the problem. I re-tried your suggestions, and changing the mentioned values doesn't help to fix the reported build error.
You need 2 hmma.fp16 instructions to handle 32 elements in k.
Hi @alexsamardzic, I talked with both @manishucsd and @rhenry-nv, and I can elaborate on what I said above in more detail. First, let us take a look at A:f16 x B:s8 and how A is stored; then at the case of A:fp16 x B:s4, where A is still stored the same way.
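To put the shape advice from this exchange into code form, here is a sketch of the tile shapes and alignments it implies, using cutlass::gemm::GemmShape; the M/N extents are carried over from the earlier int8 configuration, and the kAlignment names are only illustrative:

```cpp
#include "cutlass/gemm/gemm.h"  // cutlass::gemm::GemmShape

// With B as int4, a 128-bit access holds 32 elements, so the k extent of the
// threadblock/warp tiles goes to 128, while the tensor-core instruction for the
// f16 multiplicand stays 16x8x16 (two HMMAs cover 32 elements in k).
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>;
using WarpShape        = cutlass::gemm::GemmShape< 64,  64, 128>;
using InstructionShape = cutlass::gemm::GemmShape< 16,   8,  16>;

// Alignment (elements per 128-bit vectorized access) differs per operand:
static int const kAlignmentA = 8;   // 8  x f16  = 128 bits
static int const kAlignmentB = 32;  // 32 x int4 = 128 bits
```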
Hello, may I ask a question?
Hi @alexsamardzic, I have some cycles to look into this.
@manishucsd: The branch is here. I applied the above suggestions by @hwu36, and am now trying to make it work just from the context of the new
I just updated my branch; it seems to work properly now from the context of the new
My changes are based on @hwu36's idea of skipping
But at least it all seems to work together. On the other hand, I can see that some related changes were made in main in the meantime (I noticed the ones in the
Thanks @alexsamardzic for the progress on this. I will start looking into it soon and keep you posted. What exact changes are you referring to in the
Yes, I meant this change.
In the meantime, I've addressed 1. and 2. from the list above; the corresponding commits are pushed to my branch. @manishucsd: please let me know if you haven't looked into my branch yet, so that I can squash the commits and you don't have to look at obsolete stuff.
Feel free to squash the commits.
@apuaaChen: My changes to make sparse GEMM work with EVT are here; would you mind taking a look?
Closing the issue, as the discussion continued in the above-mentioned PRs.
(I've asked here, but I guess since the PR is closed nobody is looking there any more, so I'm opening this feature request.)
My question is: are there any immediate plans to add de-quantization (i.e. applying scaling factors along with the mixed-datatypes GEMM calculation), and also to add int4 support (two of them packed into an int8 value)? Rationale: the int8 matrix that is now supported by this PR typically comes from quantization, so it would be beneficial to have de-quantization supported. Also, quantization is typically used to save memory, so supporting int4 would be another step forward in this direction.
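As a minimal illustration of the packing mentioned here (assuming the first value sits in the low nibble; the function names are made up), two signed int4 values stored in one int8 byte can be packed and unpacked like this:

```cpp
#include <cstdint>
#include <utility>

// Pack two signed 4-bit values (each in [-8, 7]) into one byte, low nibble first.
uint8_t pack_s4x2(int lo, int hi) {
  return static_cast<uint8_t>((lo & 0x0F) | ((hi & 0x0F) << 4));
}

// Unpack them again, sign-extending each nibble back to a full int.
std::pair<int, int> unpack_s4x2(uint8_t packed) {
  int lo = ((packed & 0x0F) ^ 0x08) - 0x08;         // sign-extend low nibble
  int hi = (((packed >> 4) & 0x0F) ^ 0x08) - 0x08;  // sign-extend high nibble
  return { lo, hi };
}
```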
I'm asking because I'm working on adding support for mixed-dtypes GEMM in PyTorch. My PRs, based on the mixed-dtypes CUTLASS extensions from the FasterTransformer project, are here and here. The problem with these extensions is that they require reordering the elements of the integer matrix, and also that they don't provide support for mixed-dtypes GEMM in `cutlass_library`, which CUTLASS upstream now does. So, if instructed, I'm willing to help add these features (de-quantization would be a priority for me) into CUTLASS.