[FEA] De-quantization and int4 support for mixed dtypes GEMM #1122

Closed
alexsamardzic opened this issue Oct 3, 2023 · 55 comments
Labels: feature request (New feature or request), inactive-30d
Milestone: Backlog

Comments

@alexsamardzic
Contributor

(I've asked about this here, but since the PR is closed I guess nobody is looking there anymore, so I'm opening this feature request.)

My question is: are there any immediate plans to add de-quantization (i.e. applying scaling factors as part of the mixed datatypes GEMM calculation), and also to add int4 support (two int4 values packed into one int8 value)? Rationale: the int8 matrix that is now supported by this PR typically comes from quantization, so it would be beneficial to have de-quantization supported as well. Also, quantization is typically used to save memory, so int4 support would be another step forward in this direction.

I'm asking because I'm working on adding support for mixed dtypes GEMM to PyTorch. My PRs, based on the mixed dtypes CUTLASS extensions from the FasterTransformer project, are here and here. The problem with these extensions is that they require reordering the elements of the integer matrix, and also that they don't provide support for mixed dtypes GEMM in cutlass_library, which CUTLASS upstream now does. So, given some guidance, I'm willing to help add these features (de-quantization would be the priority for me) to CUTLASS.

@hwu36
Collaborator

hwu36 commented Oct 3, 2023

@manishucsd @rhenry-nv

As to scaling, there are many different scaling algorithms, and the scaling metadata can come in different formats. Which one should be supported first?

As to int4, it is not hard to implement since we have int8 now. The int8 PR essentially touched only one file under include/gemm/warp/. In its transform function, the simplest approach would be to upcast int4 to int8 first and then call the rest of the int8->fp16 code. Note that the int4->fp16 conversion will slow down the mainloop a lot, so it is not going to help much in the compute-bound case.
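As a rough host-side illustration of the "upcast int4 to int8 first" step, here is a plain C++ sketch (not CUTLASS code) that splits a byte holding two signed 4-bit values, sign-extends each to int8, and then converts to float as a stand-in for fp16, since standard C++ has no portable half type. The nibble order (element 0 in the low nibble) is an assumption, and `sign_extend_s4`, `unpack_s4x2` and `s4x2_to_float` are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Sign-extend a 4-bit two's-complement value (0..15 in the low nibble)
// to an int8 in the range -8..7.
inline std::int8_t sign_extend_s4(std::uint8_t nibble) {
  int v = nibble & 0x0F;
  return static_cast<std::int8_t>(v >= 8 ? v - 16 : v);
}

// Hypothetical helper: split one byte into its two s4 values, widened to int8.
// Assumption: element 0 is in the low nibble, element 1 in the high nibble.
inline std::array<std::int8_t, 2> unpack_s4x2(std::uint8_t packed) {
  return {sign_extend_s4(packed & 0x0F), sign_extend_s4(packed >> 4)};
}

// Second stage: int8 -> floating point. float stands in for fp16 here.
inline std::array<float, 2> s4x2_to_float(std::uint8_t packed) {
  const std::array<std::int8_t, 2> s8 = unpack_s4x2(packed);
  return {static_cast<float>(s8[0]), static_cast<float>(s8[1])};
}
```

For example, `unpack_s4x2(0xF7)` yields `{7, -1}`. On the GPU this conversion would run for every B fragment inside the mainloop, which is the extra work the comment warns about.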

@mnicely mnicely added this to the Backlog milestone Oct 3, 2023
@mnicely
Collaborator

mnicely commented Oct 3, 2023

Hi @alexsamardzic, excited to hear there is further use of CUTLASS in PyTorch. We are working on a mixed input GEMM implementation for Hopper and hope to have scaling at some point (very similar to FT). We don't have plans to add scaling to the Ampere implementation by @manishucsd. We are happy to support you if you are willing to contribute this functionality to CUTLASS!

@alexsamardzic
Contributor Author

For scaling, there are indeed a number of variations that may be useful, but being able to multiply each row vector of the product (i.e. the accumulator), element-wise, by a given vector of scaling factors would be a good start. To clarify: when multiplying an m by k matrix by a k by n matrix, the vector of scaling factors would have n elements. Just as a bias vector is added element-wise to each row of the result, here each row would be multiplied element-wise by the vector of scaling factors. Ideally, however, all of the existing epilogues, like adding bias and/or applying an activation function, should still be applicable afterwards.
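To pin down the semantics described above, here is a naive host-side reference in plain C++. Row-major storage, float arithmetic and the name `gemm_scale_bias` are assumptions made for the sketch. D is m x n, and scale and bias each have n entries (one per output column), broadcast over all rows.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference semantics: D[i][j] = (sum_p A[i][p] * B[p][j]) * scale[j] + bias[j].
// A is m x k, B is k x n, D is m x n, all row-major; scale and bias have n entries.
std::vector<float> gemm_scale_bias(const std::vector<float>& A,
                                   const std::vector<float>& B,
                                   const std::vector<float>& scale,
                                   const std::vector<float>& bias,
                                   std::size_t m, std::size_t n, std::size_t k) {
  std::vector<float> D(m * n);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;  // the "accumulator" for output element (i, j)
      for (std::size_t p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      // Epilogue: the same scale[j] and bias[j] are applied to every row i.
      D[i * n + j] = acc * scale[j] + bias[j];
    }
  }
  return D;
}
```

With A = [[1, 2], [3, 4]], B the 2x2 identity, scale = {2, 0.5} and bias = {1, -1}, the result is [[3, 0], [7, 1]].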

At the moment we're mostly interested in Ampere, so if you could give me some hints, I can try to add this type of scaling to CUTLASS. Please note that we're also interested in having this kind of functionality exposed through cutlass_library.

Any further hints on int4 support are welcome too.

@hwu36
Collaborator

hwu36 commented Oct 3, 2023

how much do you know about cutlass now?

After you've watched our Ampere GTC talk given by Andrew Kerr several years ago, we can have a 1:1 about scaling and int4.

@rhenry-nv

You can reuse the scaling logic from FT almost entirely. There is a warp_dequantizer you can use that will load data from smem to registers to match the input layout of hmma on Volta / Turing / Ampere.

Note that FT scales inside the mainloop (before the MMA), as scaling in the epilogue degraded the model output a lot. I think this is because the range of the activations is quite different if we don't scale.
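For a per-output-column scale vector $s$ (one entry per column $j$ of B, as discussed in this thread), scaling B inside the mainloop and scaling the accumulator in the epilogue are algebraically identical, because $s_j$ does not depend on the summation index $p$. A short derivation, assuming exact arithmetic:

$$
\sum_{p=1}^{k} A_{ip}\,\bigl(B_{pj}\,s_j\bigr) \;=\; s_j \sum_{p=1}^{k} A_{ip}\,B_{pj}
$$

So any difference between the two placements comes only from finite-precision effects, i.e. the rounding and range of the intermediate products and partial sums, which is consistent with the explanation about activation ranges above.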

@alexsamardzic
Contributor Author

Thanks for the further hints! I do know the low-level stuff about tensor operations, memory access optimization etc., but not much about CUTLASS internals (I have a PR merged, but that was mainly following what @hwu36 told me to do). So let me see if I can understand what the FT guys did, and try to come up with a PR here.

@rhenry-nv

Sounds great. I wrote the code in FT, so feel free to ask any questions about it here. As Haicheng mentioned, we can also have a meeting if it is helpful.

@manishucsd
Contributor

manishucsd commented Oct 4, 2023

Catching up on the thread here...

Thanks @alexsamardzic for your interest and efforts on mixed-input work.

Enumerating the Requests on this Thread:

(a) Scaling support for mixed-input (f16 * s8) by supplying a vector of f16xN for the s8 operand B.
(b) cutlass_library exposure of mixed-input kernels to enable

  1. profiling, and
  2. use in PyTorch (?)

(c) Support for mixed-input with int4 (f16 * s4)
(d) Support for canonical layouts (TN/Row-Col) for the weight or integer operand without requiring reordering in global memory.

Current Status
NVIDIA/FasterTransformer provides (a) and (c), but not (b) and (d).
NVIDIA/CUTLASS provides (b) and (d), but not (a) and (c).

Requests
(a), (b), and (d) for f16*s8 in one place is P0.
Additionally, having (c), i.e. f16*s4, is P1.

Clarification Request
@alexsamardzic Does PyTorch use CUTLASS kernels from cutlass_library? Is the reason for (b) (b.1), (b.2), or both?

My immediate focus in the coming weeks will be to start targeting mixed-input on H100, but I am interested in all of the above. I am happy to help with the design and with any questions you may have about mixed-input work in general, and especially about PR #1084.

@alexsamardzic
Contributor Author

Thanks @manishucsd. PyTorch uses cutlass_library for compiling models (for faster execution), so the reason for (b) is (b.2).

@alexsamardzic
Contributor Author

A couple of questions/comments:

  1. In the meantime, a mixed dtypes GEMM implementation based on the CUTLASS extensions from the FasterTransformer project has been merged into PyTorch. I also have a PyTorch PR with an alternative implementation, based on what's currently available in CUTLASS upstream (i.e. on PR #1084, "Support for Mixed Input TensorOp", here). So I was able to do some benchmarking of these two implementations for the F16 * S8 case (the benchmark scripts are provided in a comment on the mentioned PyTorch PR; admittedly it's not apples to apples, but rather the linear operator for the FasterTransformer-based version vs. just MM for the CUTLASS upstream-based version). In my benchmarking, the FasterTransformer-based version seems to be faster (I mean: closer to cuBLAS timings for the same operation). So: is there any direct comparison already available, and is my assumption correct that this may be expected, as the CUTLASS upstream version seems to do the re-shuffling of the second operand's elements into the arrangement appropriate for the HMMA instruction itself, while the FasterTransformer version expects the elements of this operand to be pre-shuffled?

  2. I was looking into adding de-quantization support to the CUTLASS upstream MM version. It seems to me that changes are needed at all levels: at the warp level, according to what MmaTensorOpDequantizer does in the FasterTransformer version, but also at the threadblock and kernel levels, again similar to the FasterTransformer version, so that the scales vector is passed down from the kernel level to the threadblock level and then loaded into shared memory at the threadblock level. So an implementation would practically mean copying the related code from the FasterTransformer project, and then adjusting it to the fact that the second matrix is expected in plain column-major layout, instead of the interleaved column-major layout used by FasterTransformer. Any suggestions here?

  3. I understand that doing the scaling in an epilogue was found to cause precision issues. Still, I've been asked to try this approach first, as it seems it could be easier/quicker to implement. So, is there any interest in eventually having this kind of epilogue supported by CUTLASS, and also: is it possible at all, i.e. is there actually a way to pass an additional vector to the epilogue? Namely, I would still want to be able to add C, i.e. to compute alpha * ((A @ B) * scale) + beta * C (here @ denotes matrix multiplication, and * element-wise multiplication).
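Element-wise, the epilogue requested in point 3 needs only the accumulator value for output (i, j), the broadcast scale entry for column j, and the matching element of C. A minimal plain C++ sketch of such a per-element functor follows; the struct name and members are hypothetical and this is not a CUTLASS API. It computes in float, as an F32 accumulator/compute type would.

```cpp
#include <cassert>

// Hypothetical per-element epilogue: d = alpha * (acc * scale_j) + beta * c.
//   acc     - accumulator value for output element (i, j)
//   scale_j - entry j of the broadcast scale vector (same for every row i)
//   c       - element (i, j) of the source matrix C
struct ScaleLinearCombination {
  float alpha;
  float beta;

  float operator()(float acc, float scale_j, float c) const {
    return alpha * (acc * scale_j) + beta * c;
  }
};
```

Because `scale_j` depends only on the column, a single n-entry vector is reused for every row of the output, which is what makes this a broadcast-style epilogue fusion rather than a second full matrix input.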

@hwu36
Collaborator

hwu36 commented Oct 12, 2023

is my assumption correct that this may be expected as CUTLASS upstream version seems to be doing the re-shuffling of second operand elements into an arrangement appropriate for HMMA instruction itself, while FasterTransformer version expects that elements of this operand are pre-shuffled?

correct

So an implementation would practically mean copying related stuff from FasterTransformer project, and then making the changes to adjust to the fact that second matrix is expected in plain column-major layout, instead of interleaved column-major as with FasterTransformer. Any suggestions here?

I don't think the interleaved format will change anything here. We just need to load in a vector. If it is an interleaved B, the scale data is already pre-processed too. @rhenry-nv?

i.e. is there actually a way to pass an additional vector to the epilogue?

From a performance perspective, fusion in the epilogue is always preferred; fusion in the mainloop hurts performance in most cases. CUTLASS epilogue broadcast fusion needs to load an additional vector. You can take a look at https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu @apuaaChen pay attention to #1120. This example uses the newly introduced EVT. You can also use the old way of doing broadcast fusion, as in https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu

@rhenry-nv

if it is an interleaved B, scale data is already pre-processed too

The scale data layout is the same regardless of interleaving, so you don't need to do anything special there.

@manishucsd
Contributor

manishucsd commented Oct 13, 2023

@alexsamardzic, I see that you compare cuBLAS vs. NVIDIA/CUTLASS vs. OpenAI/Triton vs. NVIDIA/FasterTransformer.

Can you please confirm the following?

A) You are running the following layouts and data types with four providers.

  1. cuBLAS: row-col layout for F16 <= F16*F16 + F32 datatype
  2. NVIDIA/CUTLASS : row-col layout for F16 <= F16*S8 + F32 datatype
  3. OpenAI/Triton : row-col layout for F16 <= F16*S8 + F32 datatype
  4. NVIDIA/FasterTransformer : row-col_interleaved layout for F16 <= F16*S8 + F32 datatype

B) I see that you are running a bunch of matmul shapes. Do you autotune across various ThreadBlockShape, NumStages, and SplitK for NVIDIA/CUTLASS, OpenAI/Triton, and NVIDIA/FasterTransformer?

NVIDIA/CUTLASS PR #1132 adds new ThreadBlockShapes for autotuning. The original PR had only two threadblock tile shapes, which may not be sufficient for all GEMM shapes.

C) Can you please share your profiling numbers on matmul shape of 3456x4096x8192 with ThreadBlockShape of 128x128x64 and 3 NumStages with all four providers?

@alexsamardzic
Contributor Author

A) The datatypes and layouts are as you mentioned, except that cuBLAS is running the row-row layout, and that the C operand is F16 instead of F32 everywhere. But there are a number of other caveats in my benchmarks. For example, there are actually two of them: the first benchmark compares cuBLAS with NVIDIA/FasterTransformer (with Python code compiled to Triton by the Torch compiler also thrown into the mix, out of curiosity) for the linear operator implementation, with scaling applied too. The second benchmark compares cuBLAS with NVIDIA/CUTLASS just for the MM operation. My intention with these benchmarks is primarily to verify that mixed datatypes GEMM is roughly on par with cuBLAS for the same datatypes. My impression from the results was that the NVIDIA/CUTLASS speedup numbers vs. cuBLAS are somewhat lower than NVIDIA/FasterTransformer vs. cuBLAS, so I just wanted to check whether you have produced any numbers for NVIDIA/CUTLASS vs. NVIDIA/FasterTransformer yourselves. For my purposes, the NVIDIA/CUTLASS approach is much preferred, as it doesn't require reordering the S8 matrix.

B) Nothing gets auto-tuned in my benchmarking, aside from the Triton code generated by the Torch compiler. Namely, there are two aspects of supporting mixed datatypes MM in PyTorch. The first one, which I'm working on at the moment, is for the eager mode of execution, i.e. when a Python script is executed line by line. The second one is for the case when the same script is pre-compiled by the Torch compiler. The Torch compiler primarily generates Triton code, but a CUTLASS back-end was added recently, so it's able to generate CUTLASS code too (when it encounters an operation in the Python script that is supported by CUTLASS), utilizing cutlass_library. In principle, I'll need to support both, but I'll come to the second aspect later (this is another reason to prefer NVIDIA/CUTLASS over NVIDIA/FasterTransformer, as with the former, mixed datatypes MM is already supported by cutlass_library). The Torch compiler does auto-tuning, so for the CUTLASS back-end it will be able to try different shapes, different numbers of stages and so on. But for the code that I'm writing at the moment, for eager mode execution, I have to either hard-code these or implement some kind of simple heuristic for choosing them. So far, I haven't been able to come up with such a heuristic, and at the moment I'm just hard-coding some values that I mostly got from running benchmarks like these on a number of shape combinations.

C) For the given shapes, the first benchmark gives an NVIDIA/FasterTransformer speedup over cuBLAS of about 0.9, while the Triton code generated by the Torch compiler has the same speedup (but that's deceptive, as the Torch compiler may decide to insert a call to the available pre-compiled eager mode code instead of generating Triton kernel(s), which it does in this case; this is why the "Triton" performance is reported as the same as NVIDIA/FasterTransformer). The second benchmark gives an NVIDIA/CUTLASS speedup of about 0.75 (the exact time reported on A100 for NVIDIA/CUTLASS is 1.165 ms, but note that this includes passing arguments from Python down to C++), so it's an example of why I asked about NVIDIA/CUTLASS vs. NVIDIA/FasterTransformer performance. (For reference, I used warp shape 64x64x64 here.)

@alexsamardzic
Contributor Author

alexsamardzic commented Oct 13, 2023

Thanks for the pointers; I tried the "old" way and it is indeed easy to apply scaling. Am I correct that the scaling matrix has to be in the same layout as the output matrix? Is this the case for EVT too? If so, would the change needed to support, say, applying scaling factors given in column-major layout to output in row-major layout (i.e. to have the vector of scaling factors row-broadcast in this case) be similar to what I did in #951?

Edit: OK, I think I was able to make this work for the EVT case, using VisitorColBroadcast. Is this EVT stuff supported for sparse GEMM?

@alexsamardzic
Contributor Author

alexsamardzic commented Oct 18, 2023

In the meantime, I've updated my mixed dtypes PR for PyTorch with support for de-quantization in an EVT epilogue (as well as support for adding bias and applying activation functions). If I understand EVT correctly, one could opt, for example, for an F32 accumulator in the case of F16 inputs, and then keep F32 as the datatype for the epilogue operations, where the inputs to these operations get upcast to F32 if needed and the result is eventually downcast to F16 for the final output. This way, the precision of de-quantization seems quite satisfactory in my experiments so far; of course, there is some performance penalty for upcasting/downcasting, but it is still better than doing de-quantization during the MMA.

Let me ask again: are EVT epilogues supported for sparse GEMM, and if not are there any plans to add support?

@hwu36
Collaborator

hwu36 commented Oct 18, 2023

Conceptually, the epilogues for sparse and dense GEMM are the same. You need to do some plumbing to connect sparse GEMM and the EVT epilogue, in the same way as for the dense one.

@hwu36
Collaborator

hwu36 commented Oct 18, 2023

Also, fusion in the epilogue is always preferred to mainloop fusion for performance's sake.

@alexsamardzic
Contributor Author

Thanks! Also to mention: after some profiling and parameters tuning (primarily instruction/warp/threadblock shapes) the performance of CUTLASS version of code (mixed dtypes MM + de-quantization in the epilogue) seems about on par with the FT version.

@manishucsd
Contributor

Thanks! Also to mention: after some profiling and parameters tuning (primarily instruction/warp/threadblock shapes) the performance of CUTLASS version of code (mixed dtypes MM + de-quantization in the epilogue) seems about on par with the FT version.

Thanks for this analysis.

Can you please share the new results of this profiling (a few rows of CSV: GEMM shape, Top Tile/Config (NVIDIA/FT), GFLOPs (NVIDIA/FT), Top Tile/Config (NVIDIA/CUTLASS), GFLOPs (NVIDIA/CUTLASS))?

@alexsamardzic
Contributor Author

The shapes used in both the NVIDIA/FT and NVIDIA/CUTLASS cases are: instruction shape 16x8x16, warp shape 64x64x32 and threadblock shape 128x128x64. The benchmark measures the latency of the linear operator (input @ weight.T) * scale + bias. The cuBLAS version computes input @ weight_scaled.T + bias, where weight_scaled = weight * scale.T is pre-calculated outside of the benchmark, while for the FT version the weight is arranged in the required layout, also outside of the benchmark; thus the comparison is actually still unfair to the CUTLASS version. The input has shape m x k, the weight is n x k, while scale and bias are 1 x n.

Here is a screenshot of the benchmark results (it's color-coded, so it may be easier to spot the best performers):

[Screenshot: benchmark results]

Once again, please note that the time for passing parameters from Python to C++ is included in the above latency numbers. Also, please ignore the "Triton" column for our purposes: it's actually PyTorch compiler output, which is not always compiled to pure Triton code.

I can calculate GFLOP/s if needed; for example, counting numerical operations, for m, n, k = 2048, 2048, 2048 it should be:

  1. For the cuBLAS version: 2048^2 * (2048 + 2047 + 1) / (89.0 * 10^(-6)) * 10^(-9) = 193032.24 GFLOP/s.
  2. For the FT version, I'm not sure exactly how to count the operations with the scale tensor.
  3. For the CUTLASS version: 2048^2 * (2048 + 2047 + 2) / (107.6 * 10^(-6)) * 10^(-9) = 159703.19 GFLOP/s.
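The arithmetic above can be reproduced with a tiny helper. It follows the counting convention used in the comment: per output element, k multiplies and k - 1 adds for the dot product, plus one operation per epilogue step (bias add for cuBLAS; scale multiply and bias add for CUTLASS). The name `gflops` is hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Throughput in GFLOP/s of an m x n x k GEMM with `epilogue_ops` extra
// operations per output element, given the measured latency in microseconds.
double gflops(double m, double n, double k, double epilogue_ops, double latency_us) {
  double ops = m * n * (k + (k - 1) + epilogue_ops);  // multiplies + adds + epilogue
  return ops / (latency_us * 1e-6) * 1e-9;
}
```

With one epilogue operation the count is exactly 2·m·n·k, the usual GEMM flop count, so `gflops(2048, 2048, 2048, 1, 89.0)` gives about 193032.24 and `gflops(2048, 2048, 2048, 2, 107.6)` about 159703.19.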

@rhenry-nv

rhenry-nv commented Oct 19, 2023

I think we might be able to optimize @manishucsd's version in CUTLASS to improve the performance with the canonical layout. The FT version does two transformations of the weights:

  1. An interleave so we can use LDSM without shuffle. Manish's code is showing that this transformation may not be necessary.
  2. An interleave so when we issue LDG.128, we use all the data in the 128B cache line.

We can achieve the benefit of 2) without interleaving the columns. However, we will require a separate mainloop (or a generalization of the current one) in CUTLASS 2. The core idea would be to have a GMEM-to-SMEM K-tile of 128 for int8 data and 64 for FP16 data, which would let us utilize cache lines better. In general, we just want to load a 128B K-tile regardless of the types.
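The "128B K-tile regardless of the types" rule reduces to simple arithmetic: the K extent of a GMEM-to-SMEM tile, in elements, is 128 bytes divided by the element size. A plain C++ sketch (`k_tile_elements` is a hypothetical name), which also shows what the same rule would give for 4-bit data:

```cpp
#include <cassert>

// K extent (in elements) of a tile whose K-row spans `tile_bytes` bytes,
// for elements that are `element_bits` wide.
constexpr int k_tile_elements(int element_bits, int tile_bytes = 128) {
  return tile_bytes * 8 / element_bits;
}

static_assert(k_tile_elements(16) == 64, "fp16: 64 elements per 128 B");
static_assert(k_tile_elements(8) == 128, "int8: 128 elements per 128 B");
static_assert(k_tile_elements(4) == 256, "int4: 256 elements per 128 B");
```

With fp16 A and int8 B, B's tile then covers twice as much K as A's, so the mainloop would have to issue B's GMEM-to-SMEM copies half as often as A's.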

This adds some complications to the mainloop as well but I think it should be doable. @manishucsd / @alexsamardzic is this something you would be interested in collaborating on?

@manishucsd
Contributor

manishucsd commented Oct 19, 2023

@rhenry-nv, certainly; let us chalk out everything that needs to change in CUTLASS 2.0 to support 2).

Here is my list (thinking out loud):

  • ThreadblockShape in terms of bytes instead of number of elements, OR supply ThreadblockShapeK as a tuple (TileK_a, TileK_b).
  • gemm-k-iterations will need to skip GMEM-to-SMEM for int8 on every other iteration.
    I am sure there is more, but let us talk more about how to design it and go about it.

Is your CUTLASS 3.3 work along similar lines?

@rhenry-nv

The one in v3.3 doesn't have this optimization, but I am hoping to have bandwidth to add it in v3.4.

I will try to schedule a meeting so we can design it together and use consistent APIs in 2.x and 3.x for specifying the thread block shape.

I am also certain there is more, but we can brainstorm when we meet. Does that sound good?

@manishucsd
Contributor

yup! sounds good!

@alexsamardzic
Contributor Author

Please let me know if you guys have any more specific instructions on what to do/change to:

  1. Make sparse GEMM working with EVT epilogues.
  2. Add int4 support for mixed datatypes GEMM.

I appreciate the hints already provided above, and I've honestly tried my best in the meantime to figure out how to do this, but to no avail.

@alexsamardzic
Contributor Author

I'm working on a FragmentShuffler specialization for int4b_t, and then tried to create a test case by copying SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8 and changing int8_t to int4b_t. But if I keep the rest of the test case the same, I get a division by 0 while building the MmaTensorOpMultiplicandTileIterator specialized for 32-thread TensorOps; namely, InstructionShape::kContiguous is 16 here, while kLdsmOpOuter is 32, so LdsmShapeContiguous evaluates to 0. Any suggestions about adjusting the test case for int4? Am I right to assume that ColumnMajorTensorOpMultiplicandCrosswise would work for int4 too?

@hwu36
Collaborator

hwu36 commented Oct 25, 2023

You likely need to change the k dimension of the threadblock and warp shapes. You also need to change the alignment of the int4 operand.

@alexsamardzic
Contributor Author

I've updated my 4-bit mixed dtypes branch: alexsamardzic@f88a889 (I'm force pushing, so this commit contains all of my changes). I temporarily reverted to doing the S4 to F16 numeric conversion in a loop; I know how to make it faster, but I just want to connect the dots first. The conversion, and also the new S4 fragment shuffler, seem to work, or at least the outputs of these operations, when printed, are as expected for the given inputs.

The problem I mentioned above is still there: when MmaTensorOpMultiplicandTileIterator is instantiated, it turns out that InstructionShape::kContiguous is 16, which is less than kLdsmOpOuter, which is 32 (kLdsmOpOuter is equal to TensorOpMultiplicand::kAccessSize, which comes out as 32 because TensorOpMultiplicand is written for 128b access and S4 is 4b), and then LdsmShapeContiguous evaluates to 0. Thus, changing the threadblock and warp shapes doesn't help, and I don't think the alignment of S4 is the problem either. I've made a "fix" (search for FIXME in my changes), so that I'm able to compile and run the test, but of course the test results come out wrong.
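The compile-time failure can be reproduced with the arithmetic alone. The sketch below is a simplified reconstruction based only on the numbers given in this comment (128-bit access, `InstructionShape::kContiguous` = 16, `LdsmShapeContiguous` = kContiguous / kLdsmOpOuter); it is not the actual CUTLASS `Policy` code, and the function names are hypothetical.

```cpp
#include <cassert>

// Elements fetched by one 128-bit shared-memory access (kLdsmOpOuter above).
constexpr int ldsm_op_outer(int element_bits) { return 128 / element_bits; }

// LdsmShapeContiguous = InstructionShape::kContiguous / kLdsmOpOuter,
// using integer division as in a compile-time constant expression.
constexpr int ldsm_shape_contiguous(int inst_contiguous, int element_bits) {
  return inst_contiguous / ldsm_op_outer(element_bits);
}

static_assert(ldsm_shape_contiguous(16, 8) == 1, "s8 operand: 16 / 16 == 1");
static_assert(ldsm_shape_contiguous(16, 4) == 0, "s4 operand: 16 / 32 == 0");
// A pseudo instruction shape with k = 32 for the s4 operand (as suggested
// later in this thread) brings the quotient back to 1.
static_assert(ldsm_shape_contiguous(32, 4) == 1, "s4 with k = 32");
```

Any later expression that divides by the zero quotient is what turns into the reported division by zero during template instantiation.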

I hope this may be the only remaining obstacle to having something that works, but at the moment I have no idea how to fix it. Apparently, InstructionShape::kContiguous cannot be increased, and the whole problem actually comes from the fact that we're accessing smaller elements than the ones that will actually be used for the multiplication, which these pieces of code are not prepared for (and the S8/U8 case did not encounter this problem, kind of by accident). In any case, as mentioned above, any suggestion on what to do here would be much appreciated.

@hwu36
Collaborator

hwu36 commented Oct 27, 2023

Can you paste your device-level template instantiation code?

@alexsamardzic
Contributor Author

It's struct Testbed, in test/unit/gemm/warp/testbed.h, and it's used from the SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i4 test, in gemm_mixed_input_sm80.cu in the same directory. So basically, to see the mentioned build problem: check out my commit, then in include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h remove the line after the FIXME comment, uncomment the next line, and then run make cutlass_test_unit_gemm_warp.

@hwu36
Collaborator

hwu36 commented Oct 27, 2023

you need to change all k dimensions of different shapes to be 128 instead of 64 and then make it work.

see this int4 unit test


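```cpp
// Includes assumed, following the other tests in test/unit/gemm/warp/.
#include "../../common/cutlass_unit_test.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "testbed.h"

// int4 x int4 warp-level GEMM: warp tile 64x64x128, MMA instruction 16x8x64.
TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x64x128_16x8x64) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 128>;           // warp tile
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;  // int4 MMA
  using Element = cutlass::int4b_t;
  using ElementC = int;
  // Crosswise shared-memory layouts with a crosswise (k) extent of 128.
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<Element>::value, 128>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<Element>::value, 128>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type;

  // The 128x128x128 threadblock shape is passed to the testbed.
  test::gemm::warp::Testbed<MmaTensorOp,
                            cutlass::gemm::GemmShape<128, 128, 128> >()
      .run();
}
```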

@alexsamardzic
Contributor Author

That doesn't help; the problem is that the k dimension of InstructionShape cannot be increased in my case, and this is exactly what causes the problem I tried to describe in my previous comment.

@hwu36
Collaborator

hwu36 commented Oct 28, 2023

don't change the InstructionShape. it is still 16x8x16. warp shape k is 128. the numbers in RowMajorTensorOpMultiplicandCrosswise and ColumnMajorTensorOpMultiplicandCrosswise are also 128. you may need to change the iterator of fp16 one to make it work.

@hwu36
Collaborator

hwu36 commented Oct 28, 2023

for fp16 one, you may still need to use 64, but advance the warp iterator one more time and do the warp load one more time to get the next 64 data in the k dimension.

@alexsamardzic
Contributor Author

I'm sorry if I wasn't clear enough in explaining the problem. I retried your suggestions, and changing the mentioned values doesn't fix the reported build error. The problem is with the MmaTensorOpMultiplicandTileIterator instance specialized for the TensorOpMultiplicandCrosswise layout, in include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h (the code for this iterator class starts around line 1300 of this file). The mentioned layout assumes 128-bit memory accesses, so for 4-bit elements it accesses 32 elements at once, but the Policy of this iterator is just not written to handle the case when this number is greater than the "inner" dimension of InstructionShape. For example, for the mixed F16/S4 case, the InstructionShape has to be 16x8x16, as this is what is supported for F16 GEMM (which is what is actually performed after up-casting S4 to F16). But then, as 32 elements (what one memory access fetches) is more than the 16 elements needed along the "contiguous" dimension of the 16x8 fragment (in case the S4 matrix is matrix B of the GEMM), a division by zero occurs during the compile-time calculations within the mentioned Policy structure, and the iterator in general is not usable. Thus, at the moment, I'm trying to understand the internals of this iterator class, in order to eventually make the changes needed to accommodate this particular case.

@hwu36
Collaborator

hwu36 commented Oct 30, 2023

you need 2 hmma.fp16 to handle 32 elements in k.

@hwu36
Collaborator

hwu36 commented Oct 30, 2023

hi @alexsamardzic , I talked with both @manishucsd and @rhenry-nv . I can elaborate what i said above in more details.

First, let us take a look at A: f16 x B: s8. A is stored as RowMajorTensorOpMultiplicandCrosswise<16, 64> and B is stored as ColumnMajorTensorOpMultiplicandCrosswise<8, 64>. Suppose the warp tile size is 64x64x64: every time we do a warp_iterator_A load, we load 64(m) x 16(k) fp16 data. Every time we do a warp_iterator_B load, we load 64(n) x 16(k) int8 data. We extend the int8 data to fp16 and do one 64x64x16 warp-level mma.

In the case of A: fp16 x B: s4, A is still stored as RowMajorTensorOpMultiplicandCrosswise<16, 64> and B is stored as ColumnMajorTensorOpMultiplicandCrosswise<4, 64>. Still supposing the warp tile size is 64x64x64, every time we do a warp_iterator_A load, we still load 64(m) x 16(k) fp16 data. Every time we do a warp_iterator_B load, we have to load at least 64(n) x 32(k) int4b data. To make this happen, you can give B a pseudo instruction shape of 16x8x32, just like @rhenry-nv did in his FT version (https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h#L84, https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h#L185, and https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h#L93). However, we still have more B data than A data in the K dimension. So we can either let A do one more warp load (more register pressure, higher ILP) or let B skip the next warp load (less register pressure, but lower ILP). One more thing to note: since B has 32 consecutive int4b values and one mma only needs 16 of them, the shuffle algorithm needs to move the data to the right place.
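To make the "let B skip the next warp load" option concrete, here is a plain C++ simulation of the issue order along K for one warp tile. It only models which loads and MMAs are issued, not any data movement; the function name and event strings are hypothetical. A is loaded in k16 slices and B in k32 slices, so B is loaded on every other k16 step and each B fragment feeds two 16x8x16 MMAs.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Issue order for one warp tile along K when A (fp16) is loaded 16 k-elements
// at a time and B (int4) 32 k-elements at a time: B skips every other load.
std::vector<std::string> warp_k_schedule(int warp_tile_k) {
  const int kInstK = 16;   // k of the f16 MMA instruction (16x8x16)
  const int kBLoadK = 32;  // k covered by one int4 warp load of B
  std::vector<std::string> events;
  for (int k = 0; k < warp_tile_k; k += kInstK) {
    events.push_back("load_A k=" + std::to_string(k));
    if (k % kBLoadK == 0) events.push_back("load_B k=" + std::to_string(k));
    // Which half (0 or 1) of the current B fragment this MMA consumes.
    events.push_back("mma half=" + std::to_string((k % kBLoadK) / kInstK));
  }
  return events;
}
```

For a 64x64x64 warp tile this issues four A loads, two B loads and four MMAs. The alternative from the comment, letting A do one more warp load, would instead issue two A loads back to back per B load, trading extra registers for more ILP.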

@zwshan

zwshan commented Oct 31, 2023

Hello, may I ask a question?
From this picture, cuBLAS (fp16 * fp16) is faster than FT (fp16 * int8), so why do we need FT (fp16 * int8)?

@alexsamardzic
Contributor Author

Thanks @hwu36, that makes it clear, I'll try it this way.

@zwshan: It saves memory. Think of LLMs: if you quantize the weights, your model can be larger and still fit in memory.

@manishucsd
Contributor

manishucsd commented Nov 1, 2023

Hi @alexsamardzic, I have some cycles to look into F16 * S4. Do you have a branch you can share that we can collaborate on?

@alexsamardzic
Contributor Author

@manishucsd: The branch is here. I applied the above suggestions by @hwu36 and am now trying to make it work, for now only in the context of the new SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i4 test case; I will see how to generalize it later.

@alexsamardzic
Contributor Author

I just updated my branch; it seems to work properly now in the context of the new SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i4 test case. To try it, check out the branch and then run:

make cutlass_test_unit_gemm_warp
./test/unit/gemm/warp/cutlass_test_unit_gemm_warp --gtest_filter="SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i4.*"

My changes are based on @hwu36's idea of skipping S4 loads every other time. Pretty much all aspects of these changes still need improvement:

  1. The S4 to F16 conversion needs to be vectorized.
  2. The new fragment shuffler code needs to be rewritten, as it's very clunky and inefficient.
  3. Skipping loads still needs to be handled at the threadblock level (at the moment, it's a kludge in the testbed code).

But at least it all seems to work together.

On the other hand, I can see that some related changes have been made in main in the meantime (I noticed some in the transform() method of the MmaMixedInputTensorOp class). @manishucsd: Any update on this? If you think it would still be worthwhile, would you mind checking my branch and providing feedback?

@manishucsd
Contributor

Thanks @alexsamardzic for the progress on this. I will start looking into it soon and keep you posted.

Which exact changes in transform() are you referring to? Is it the splitting of the operand A transform() into two parts? (@hwu36)

@alexsamardzic
Contributor Author

Yes, I meant this change.

@alexsamardzic
Contributor Author

In the meantime, I've addressed items 1 and 2 from the list above; the corresponding commits are pushed to my branch. @manishucsd: Please let me know if you haven't looked into my branch yet, so that I can squash the commits and you don't have to look at obsolete code.

@manishucsd
Contributor

manishucsd commented Nov 8, 2023

Feel free to squash the commits.

@alexsamardzic
Contributor Author

@apuaaChen: My changes to make sparse GEMM work with EVT are here; would you mind taking a look?

@alexsamardzic
Contributor Author

I've created PRs for easier review:

  1. Here is the PR for 4-bit mixed dtypes support.
  2. Here is the PR for EVT epilogue support for sparse GEMM.

@mnicely mnicely modified the milestones: Backlog, CUTLASS 3.4 Dec 5, 2023

github-actions bot commented Jan 4, 2024

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

@alexsamardzic
Copy link
Contributor Author

Closing the issue, as the discussion continued in the above-mentioned PRs.

Projects
None yet
Development

No branches or pull requests

7 participants