Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support MLA decode #551

Merged
merged 17 commits into from
Nov 2, 2024
Merged

feat: support MLA decode #551

merged 17 commits into from
Nov 2, 2024

Conversation

tsu-bin
Copy link
Contributor

@tsu-bin tsu-bin commented Oct 23, 2024

Hi, this PR implements MLA decode algorithm, I would love to hear your thoughts on this design and implementation.

The mystery Mat Absorb algorithm

In the DeepSeekV2 paper, there was no specific formulas for how to do param matrixes absorption, but it just vaguely said that

Fortunately, due to the associative law of matrix multiplication, we can absorb 𝑊𝑈𝐾 into 𝑊𝑈𝑄, and 𝑊𝑈𝑉 into 𝑊𝑂

I know there were also some discussion on this, but I still can't find convinced answer.
Here is my conclusion on this topic, Mat Absorb is only suitable for decode, do not use Mat Absorb for prefill, which means MLA decode and prefill should have different computation graph and different set of params, and Mat Absorb should merge param matrixes offline, materialize the merged param matrixes. You can find the two sets of Mat Absorb are two einsum ops in test_mla_decode_kernel.py.

# Now we merge W_UQ and W_UK (absorb W_UK into W_UQ)
# q~q_lora_rank  n~num_heads  d~qk_nope_head_dim  l~kv_lora_rank
self.W_UQ_UK = torch.einsum("q n d, l n d -> q n l", W_UQ, W_UK).flatten(start_dim=1) # [1536, 65536]

# Merge W_UV and W_O (absorb W_UV into W_O)
# l~kv_lora_rank  n~num_heads  d~v_head_dim  h~hidden_size
self.W_UV_O = torch.einsum("l n d, h n d -> n l h", W_UV, W_O).flatten(start_dim=0, end_dim=1) # [65536, 5120]

I'm going to state my reason below. First let me depict the original MLA algorithm in computation graph (The final o_proj is omitted). It can be regarded as a 128 heads / (128+64) dim MHA algorithm.
image
And after Mat Absorb, MLA become a special 128 heads / (512+64) dim MQA algorithm, please note that the compressed_kv is used as both K and V directly without any projection. The detailed Mat Absorb algorithm can be found in test_mla_decode_kernel.py, in which DeepseekV2AttentionVanilla is the original DeepSeekV2 MLA inference implementation copied from huggingface and modified slightly, we take DeepseekV2AttentionVanilla as a reference to verify the correctness our Mat Absorb implementation. The DeepseekV2AttentionMatAbsorbDecode is our Mat Absorb implementation, it has two versions of inference function(run_proof_of_concept), one is implemented purely by torch, which can help you to make it clear how the Mat Absorb version of MLA inference works, and the other uses our new flashinfer MLA decode kernel, you can also take it as an usage example.
image

Now let's do some calculation to see the if Mat Absorb version is performant (for the sake of convenience, we call the original MLA algo as Vanilla version) .

# We calculate the number of float ops needed by the part of MLA computation graph,
# the input tensors are c_Q and cached k_pe and compressed_kv, the output tensor is the output hidden states.
# We omitted the calculation from input hidden states to c_Q and cached k_pe and compressed_kv, 
# because it's the same for both vanilla version and mat-absorb version.
def num_float_ops_vanilla(q_len, kv_len):
    return ( q_len*1536*(128*192) + # from c_Q to q_pe and q_nope, corresponding to q_b_proj
                kv_len * 512 * (128*(128+128)) + # from compressed_kv to k_nop and value_states, corresponding to kv_b_proj
                128 * (q_len*64*kv_len + q_len*128*kv_len + q_len*kv_len*128) + # 128 heads MHA
                q_len * (128*128) * 5120 ) # from MHA output to output hidden states, corresponding to o_proj
def mem_footprint_vanilla(q_len, kv_len):
    return ( q_len*1536 + 1536*(128*192) + 
                kv_len*512 + 512*(128*(128+128)) + 
                128 * ((q_len*64 + 64*kv_len) + (q_len*128 + 128*kv_len)) + 
                q_len * (128*128) + (128*128) * 5120 ) 
def num_float_ops_mat_absorb(q_len, kv_len):
    return ( q_len*1536*(128*64) + # from c_Q to q_pe, corresponding to W_QR
                q_len*1536*(128*512) + # from c_Q to q_nope, corresponding to W_UQUK
                128 * (q_len*64*kv_len + q_len*512*kv_len + q_len*kv_len*512) + # 128 heads MQA
                q_len * (128*512) * 5120 ) # from MHA output to output hidden states, corresponding to W_UV_O
def mem_footprint_mat_absorb(q_len, kv_len):
    return ( q_len*1536 + 1536*(128*64) +
                1536*(128*512) +
                128 * (q_len*64 + q_len*512) + 1*(64*kv_len + 512*kv_len) +
                q_len * (128*512) + (128*512) * 5120 ) # from MHA output to output hidden states, corresponding to W_UV_O
kv_len = 10000
print(f"prefill: num_float_ops mat_absorb vs vanilla ratio  ~ {num_float_ops_mat_absorb(kv_len, kv_len) / num_float_ops_vanilla(kv_len, kv_len)}"
print(f"prefill: mem_footprint mat_absorb vs vanilla ratio  ~ {mem_footprint_mat_absorb(kv_len, kv_len) / mem_footprint_vanilla(kv_len, kv_len)}")
print(f"decode: num_float_ops mat_absorb vs vanilla ratio  ~ {num_float_ops_mat_absorb(1, kv_len) / num_float_ops_vanilla(1, kv_len)}")
print(f"decode: mem_footprint mat_absorb vs vanilla ratio  ~ {mem_footprint_mat_absorb(1, kv_len) / mem_footprint_vanilla(1, kv_len)}")

The output is:

prefill: num_float_ops mat_absorb vs vanilla ratio  ~ 3.3602009088734754
prefill: mem_footprint mat_absorb vs vanilla ratio  ~ 2.2874373717252205
decode: num_float_ops mat_absorb vs vanilla ratio  ~ 0.010941137164898957
decode: mem_footprint mat_absorb vs vanilla ratio  ~ 1.167867978048944

So we can conclude from the result above, for decode case Mat Absorb version only use about 1% computation compared to Vanilla version, and the memory footprint is at the same level with Vanilla version, but for prefill case, both computation and memory footprint are much higher than Vanilla version, so there is no reason to use Mat Absorb for prefill, but it's worth a try for decode.

The kernel implementation design

The new MLA decode kernel actually follows the same design concept as the current decode kernel, also reuse much of existing code base, we add some helper functions, such as dim3_offset for better code readability.
The scheduling policy is also the same as the current one, we split task by kv-len dimension, and because the num_kv_heads is 1 now, we can't split num_kv_heads dimension across blocks now. There is one problem that the 128heads / (512+64)dim Q data is too large to fit into one SM's register file or even smem, which means we can't use only one SM/block to process one Q data, we have to tile the num_qo_heads dimension into gridDim.y, which can cause kv-cache data movement from gmem to smem multiple times, though this is inevitable.

Further improvement

  • Tensor-core version implementation, since current MLA models (DeepSeek-V2-Lite, DeepSeek-V2, MiniCPM3) all have large num_qo_heads, which is large enough to feed data into mma fragment, but in my opinion maybe this can have limited performance improvement, because consider the above analysis, the bottle neck is IO bandwidth not the computation intensity.
  • Load more Q head data per thread and per block. The more we load Q head data, the less block number is needed, the less kv data movement from gmem to smem is needed. We can add more q_nope_vec per thread, also we can use smem to store more q_nope_vec. I would love to hear inputs from others.

BTW, the new variable and function naming may be not follow current convention, I'm willing to change according to your advice.

@zhyncs zhyncs requested a review from yzh119 October 23, 2024 19:04
@zhyncs
Copy link
Member

zhyncs commented Oct 23, 2024

cc @ispobock @merrymercy

@ispobock
Copy link
Contributor

ispobock commented Oct 24, 2024

The append(extend) kernel also need to be considered w/ absorption. Because we read kv cache in the append kernel. The kv cache should be directly the compressed_kv.

@tsu-bin
Copy link
Contributor Author

tsu-bin commented Oct 24, 2024

The append(extend) kernel also need to be considered w/ absorption. Because we read kv cache in the append kernel. The kv cache should be directly the compressed_kv.

Maybe we can set a threshold value on the append token number based on some experience value to decide use w/ or w/o Mat Absorb. You can tweak the numbers of the above equations to see how the q_len, kv_len values affect the ratio values.

@yzh119
Copy link
Collaborator

yzh119 commented Oct 25, 2024

Hi @tsu-bin thank you for contributing MLA, this is a great feature to have.

Do you have any performance number yet?

@tsu-bin
Copy link
Contributor Author

tsu-bin commented Oct 26, 2024

Hi @tsu-bin thank you for contributing MLA, this is a great feature to have.

Do you have any performance number yet?

hi @yzh119 I just added the benchmark code. Here is the result from my workspace (The cards were not fully vacant, so the result may vibrate a little). It proved what I said above, the bottle neck is IO bandwidth. The current scheduling design tiles num_qo_heads into gridDim.y and blockDim.y, blockDim.y is hard coded to empirical value 8, larger value will cause kernel launch failed because of insufficient register. So when num_qo_heads is 8, all kv cache will move from gmem to smem just once, when num_qo_heads is 16, there will be one additional kv cache movement from gmem to smem, when num_qo_heads is 24, total 3 times movements are needed, this is inevitable.
I will improve the kernel, that one thread can process multiple vec_t<float, vec_size_ckv> q_nope_vec and state_t<vec_size_ckv> st to maximise the reusage of kv cache in smem. I mean, we can tile num_qo_heads into gridDim.y blockDim.y and one thread's local variable, and load multiple q_nope_vec into thread's local variable or smem.
BTW, currently, DeepSeek-V2-Lite's num_qo_heads is 16, DeepSeek-V2 ~ 128, MiniCPM3 ~ 40.

### [0] NVIDIA GeForce RTX 4090
| page_size | batch_size | seqlen | num_qo_heads |    Read     |    Write    | Samples |  CPU Time  |  Noise  |  GPU Time  | Noise  | GlobalMem BW | BWUtil | Samples | Batch GPU  |
|-----------|------------|--------|--------------|-------------|-------------|---------|------------|---------|------------|--------|--------------|--------|---------|------------|
|        64 |         16 |   1024 |            8 |  18.142 MiB | 128.000 KiB |  13168x |  59.077 us | 496.60% |  37.999 us |  1.81% | 504.064 GB/s | 50.00% |  20559x |  26.903 us |
|        64 |        256 |   1024 |            8 | 290.268 MiB |   2.000 MiB |   1568x | 459.435 us | 100.20% | 387.508 us |  0.64% | 790.860 GB/s | 78.45% |   1569x | 621.003 us |
|        64 |         16 |  16384 |            8 | 288.156 MiB | 128.000 KiB |   1680x | 420.953 us |  64.27% | 383.623 us |  0.81% | 787.974 GB/s | 78.16% |   1681x | 564.941 us |
|        64 |        256 |  16384 |            8 |   4.502 GiB |   2.000 MiB |     98x |   5.130 ms |   0.14% |   5.126 ms |  0.11% | 943.623 GB/s | 93.60% |    103x |   5.070 ms |
|        64 |         16 |   1024 |           16 |  18.282 MiB | 256.000 KiB |  11696x |  70.112 us | 510.53% |  42.767 us |  2.13% | 454.385 GB/s | 45.07% |  13175x |  37.953 us |
|        64 |        256 |   1024 |           16 | 292.518 MiB |   4.000 MiB |    700x | 718.662 us |   0.63% | 714.300 us |  0.15% | 435.281 GB/s | 43.18% |    795x | 659.022 us |
|        64 |         16 |  16384 |           16 | 288.297 MiB | 256.000 KiB |   2448x | 639.263 us |  89.75% | 530.420 us |  1.11% | 570.422 GB/s | 56.58% |   2449x | 538.524 us |
|        64 |        256 |  16384 |           16 |   4.505 GiB |   4.000 MiB |     50x |  10.194 ms |   0.09% |  10.189 ms |  0.07% | 475.134 GB/s | 47.13% |     51x |  10.136 ms |
|        64 |         16 |   1024 |           32 |  18.564 MiB | 512.000 KiB |   6608x |  80.194 us |   7.34% |  75.770 us |  0.87% | 263.820 GB/s | 26.17% |   7239x |  85.765 us |
|        64 |        256 |   1024 |           32 | 297.018 MiB |   8.000 MiB |    369x |   1.911 ms |  80.21% |   1.358 ms |  0.41% | 235.565 GB/s | 23.37% |    401x |   1.671 ms |
|        64 |         16 |  16384 |           32 | 288.578 MiB | 512.000 KiB |   1408x |   1.203 ms |  54.55% |   1.054 ms |  0.68% | 287.539 GB/s | 28.52% |   1409x |   1.309 ms |
|        64 |        256 |  16384 |           32 |   4.509 GiB |   8.000 MiB |    544x |  24.394 ms |  30.96% |  24.101 ms | 28.84% | 201.230 GB/s | 19.96% |    545x |  25.029 ms |
|        64 |         16 |   1024 |           64 |  19.126 MiB |   1.000 MiB |   3664x | 240.412 us | 337.84% | 136.767 us |  2.15% | 154.304 GB/s | 15.31% |   3981x | 247.266 us |
|        64 |        256 |   1024 |           64 | 306.018 MiB |  16.000 MiB |   1152x |   3.161 ms |  45.16% |   2.947 ms | 25.35% | 114.565 GB/s | 11.36% |   1153x |   3.114 ms |
|        64 |         16 |  16384 |           64 | 289.141 MiB |   1.000 MiB |    461x |   2.113 ms |   0.88% |   2.107 ms |  0.50% | 144.368 GB/s | 14.32% |    462x |   2.108 ms |
|        64 |        256 |  16384 |           64 |   4.518 GiB |  16.000 MiB |    300x |  49.642 ms |  31.06% |  49.291 ms | 29.89% |  98.755 GB/s |  9.80% |    275x |  54.162 ms |
|        64 |         16 |   1024 |          128 |  20.251 MiB |   2.000 MiB |   1897x | 273.052 us |  13.59% | 263.698 us |  0.29% |  88.480 GB/s |  8.78% |   2015x | 292.347 us |
|        64 |        256 |   1024 |          128 | 324.018 MiB |  32.000 MiB |     80x |   9.885 ms |  31.27% |   8.637 ms | 20.73% |  43.223 GB/s |  4.29% |     91x |   5.522 ms |
|        64 |         16 |  16384 |          128 | 290.266 MiB |   2.000 MiB |   1056x |   5.033 ms |  43.37% |   4.798 ms | 31.07% |  63.871 GB/s |  6.34% |   1057x |   5.207 ms |
|        64 |        256 |  16384 |          128 |   4.535 GiB |  32.000 MiB |    141x | 106.693 ms |  27.41% | 106.178 ms | 26.91% |  46.181 GB/s |  4.58% |    142x |  98.496 ms |
### [1] NVIDIA GeForce RTX 4090
| page_size | batch_size | seqlen | num_qo_heads |    Read     |    Write    | Samples |  CPU Time  |  Noise  |  GPU Time  | Noise  | GlobalMem BW | BWUtil | Samples | Batch GPU  |
|-----------|------------|--------|--------------|-------------|-------------|---------|------------|---------|------------|--------|--------------|--------|---------|------------|
|        64 |         16 |   1024 |            8 |  18.142 MiB | 128.000 KiB |  11248x | 110.257 us | 251.77% |  44.468 us | 12.26% | 430.734 GB/s | 42.73% |  20446x |  24.455 us |
|        64 |        256 |   1024 |            8 | 290.268 MiB |   2.000 MiB |   1296x | 534.398 us |  42.42% | 387.465 us |  1.58% | 790.949 GB/s | 78.46% |   1538x | 400.293 us |
|        64 |         16 |  16384 |            8 | 288.156 MiB | 128.000 KiB |   1312x | 541.207 us |  48.81% | 381.861 us |  1.69% | 791.611 GB/s | 78.53% |   1489x | 395.982 us |
|        64 |        256 |  16384 |            8 |   4.502 GiB |   2.000 MiB |    688x |   6.364 ms |   5.50% |   6.096 ms |  2.86% | 793.435 GB/s | 78.71% |    689x |   6.098 ms |
|        64 |         16 |   1024 |           16 |  18.282 MiB | 256.000 KiB |  11248x | 129.133 us | 293.78% |  44.503 us |  2.28% | 436.657 GB/s | 43.31% |  11517x |  43.418 us |
|        64 |        256 |   1024 |           16 | 292.518 MiB |   4.000 MiB |   1024x | 856.933 us |  22.62% | 715.342 us |  0.88% | 434.647 GB/s | 43.12% |   1025x | 738.848 us |
|        64 |         16 |  16384 |           16 | 288.297 MiB | 256.000 KiB |    952x | 654.580 us |  29.90% | 525.261 us |  0.45% | 576.024 GB/s | 57.14% |    999x | 572.551 us |
|        64 |        256 |  16384 |           16 |   4.505 GiB |   4.000 MiB |    640x |  12.508 ms |   3.84% |  12.238 ms |  2.95% | 395.566 GB/s | 39.24% |    641x |  12.160 ms |
|        64 |         16 |   1024 |           32 |  18.564 MiB | 512.000 KiB |   6336x | 144.316 us | 136.74% |  78.986 us |  2.63% | 253.080 GB/s | 25.10% |   6762x |  77.674 us |
|        64 |        256 |   1024 |           32 | 297.018 MiB |   8.000 MiB |    369x |   1.623 ms |  20.81% |   1.356 ms |  0.40% | 235.863 GB/s | 23.40% |    399x |   1.570 ms |
|        64 |         16 |  16384 |           32 | 288.578 MiB | 512.000 KiB |    483x |   1.266 ms |  23.83% |   1.035 ms |  0.29% | 292.769 GB/s | 29.04% |    507x |   1.232 ms |
|        64 |        256 |  16384 |           32 |   4.509 GiB |   8.000 MiB |    593x |  24.920 ms |   2.44% |  24.632 ms |  2.01% | 196.895 GB/s | 19.53% |    594x |  24.626 ms |
|        64 |         16 |   1024 |           64 |  19.126 MiB |   1.000 MiB |   3504x | 267.030 us | 117.04% | 142.910 us |  2.82% | 147.672 GB/s | 14.65% |   3724x | 146.479 us |
|        64 |        256 |   1024 |           64 | 306.018 MiB |  16.000 MiB |   1184x |   3.190 ms |   9.88% |   3.010 ms |  5.56% | 112.182 GB/s | 11.13% |   1185x |   3.087 ms |
|        64 |         16 |  16384 |           64 | 289.141 MiB |   1.000 MiB |    243x |   2.322 ms |  13.51% |   2.060 ms |  0.17% | 147.721 GB/s | 14.65% |    255x |   2.500 ms |
|        64 |        256 |  16384 |           64 |   4.518 GiB |  16.000 MiB |    300x |  49.729 ms |   2.19% |  49.467 ms |  2.05% |  98.404 GB/s |  9.76% |    300x |  49.721 ms |
|        64 |         16 |   1024 |          128 |  20.251 MiB |   2.000 MiB |   1856x | 389.309 us |  60.45% | 271.146 us |  1.19% |  86.049 GB/s |  8.54% |   1925x | 304.575 us |
|        64 |        256 |   1024 |          128 | 324.018 MiB |  32.000 MiB |   1152x |   6.321 ms |   4.97% |   6.114 ms |  2.91% |  61.054 GB/s |  6.06% |   1153x |   6.165 ms |
|        64 |         16 |  16384 |          128 | 290.266 MiB |   2.000 MiB |    624x |   4.821 ms |   7.00% |   4.548 ms |  2.80% |  67.379 GB/s |  6.68% |    625x |   4.634 ms |
|        64 |        256 |  16384 |          128 |   4.535 GiB |  32.000 MiB |    150x |  99.795 ms |   2.27% |  99.519 ms |  2.20% |  49.271 GB/s |  4.89% |    150x |  99.623 ms |

@jason-huang03
Copy link

jason-huang03 commented Oct 28, 2024

hi @tsu-bin is it possible to get your email or wechat? Recently I am also very interested in the MLA kernel and I think we can work on it together. You can find my email in my profile.

@tsu-bin
Copy link
Contributor Author

tsu-bin commented Oct 28, 2024

I just added tile_size_qo_heads, now it's hard coded to value 2, ideally this change can boost performance to twice for cases which has num_qo_heads greater than 8. The actual result do reflect this expectation.
61CDEDC7-1DAF-4AFC-AB0E-5A31C559B8C8
When I increase tile_size_qo_heads to 4, the actual performance starts to deteriorate, the actual performance is worse than the previous version before this change. This indicates the performance bottle neck becomes to the computation efficiency, when tile_size_qo_heads >= 4. I think this suggests tensor-core implementation can further improve the performance.

@tsu-bin
Copy link
Contributor Author

tsu-bin commented Oct 28, 2024

hi @tsu-bin is it possible to get your email or wechat? Recently I am also very interested in the MLA kernel and I think we can work on it together. You can find my email in my profile.

hi @jason-huang03, I'm very glad to discuss with you, already sent mail to you. Hope we can work out the most optimized solution for MLA kernels.

@yzh119
Copy link
Collaborator

yzh119 commented Oct 28, 2024

Hi @tsu-bin

When I increase tile_size_qo_heads to 4, the actual performance starts to deteriorate, the actual performance is worse than the previous version before this change. This indicates the performance bottle neck becomes to the computation efficiency, when tile_size_qo_heads >= 4. I think this suggests tensor-core implementation can further improve the performance.

I think that make sense, I can take over the work of implementing it with tensor cores (I'm refactoring the codebase with cutlass and I think it will be easier after that), but we can merge the cuda-cores implementation first.

@tsu-bin
Copy link
Contributor Author

tsu-bin commented Oct 29, 2024

Hi @yzh119 That's great, it's still a challenge to write mma code manually, are you planing to use CUTE to refactor current prefill implementation?
I will rebase my changes ASAP.

@yzh119
Copy link
Collaborator

yzh119 commented Oct 29, 2024

are you planing to use CUTE to refactor current prefill implementation?

Yes I'm working on that.

I will rebase my changes ASAP.

Sounds great!

@tsu-bin
Copy link
Contributor Author

tsu-bin commented Oct 30, 2024

hi @yzh119 rebase is done, please note that there are still some recent features, such as 'improve plan performance by using non-blocking memcpy #547', that still need to be applied to new code from this PR.

BTW, when you are refactoring prefill code, maybe you can make some room for code reuse to ease the upcoming MLA prefill implementation.

@yzh119
Copy link
Collaborator

yzh119 commented Oct 31, 2024

BTW, when you are refactoring prefill code, maybe you can make some room for code reuse to ease the upcoming MLA prefill implementation.

Sure, we are trying to unify all attention variants into data structures like this:
https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/attention/variants.cuh
The current form is only the beta version and I'll have some updates soon, and it'll be interesting to see how MLA fit into this vision.

include/flashinfer/utils.cuh Show resolved Hide resolved
include/flashinfer/utils.cuh Show resolved Hide resolved
include/flashinfer/utils.cuh Show resolved Hide resolved
output_mat_absorbed_use_torch = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=False)
output_mat_absorbed_use_flashinfer = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=True)

cos_sim_use_torch = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_torch.reshape(-1), dim=0)
Copy link
Collaborator

@yzh119 yzh119 Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use cosine similarity here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually during my development of decode kernel, I tried hard to align the kernel implementation with the correct computation semantic, I found this metric is a good indicator to guide me if I was in the right direction to fix the discrepancy, the value changed from 0.1 -> 0.95 -> 0.9999.
But now I just tried the MSE as the metric, it value is rather large than I expected, it seems that the cosine similarity is a relaxed standard.
So there are still some discrepancy, I will look into this issue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @yzh119 I just updated the test case, added both MSE and WMAPE metrics. Below is the output from one run, it seems that the f32 to f16 conversion can cause some precision loss. I still can't find wrong implementation in decode kernel.

cos_use_torch_f32 = 1.0
wmape_use_torch_f32 = 9.26729843382549e-06
mse_use_torch_f32=0.0008153514936566353

cos_use_torch_f16 = 0.9997764825820923
wmape_use_torch_f16 = 0.016850485358918303
mse_use_torch_f16 = 3793.904296875

cos_use_flashinfer = 0.9999209642410278
wmape_use_flashinfer = 0.012483024544464835
mse_use_flashinfer = 1346.939453125

The MSE value is rather large, because the elements in the output tensor are at the level of thousands, you can try it by yourself.
Do you think WMAPE (corresponds to comparison of tensor's magnitude) and Cosine Similarity (corresponds to comparison of tensor's angle) together can be enough proof for algorithm correctness?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for confirmation, I think cosine similarity is okay in this case.
Let's merge this PR first and investigate the numerical issue later.

Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great contribution, thank you @tsu-bin !

@yzh119 yzh119 merged commit 5d454ed into flashinfer-ai:main Nov 2, 2024
@tsu-bin
Copy link
Contributor Author

tsu-bin commented Nov 2, 2024

Look forward to your CUTE refactor of prefill kernel, then we can continue to work on MLA prefill kernel. Hope soon the complete implementation can be ready for production.

@tsu-bin tsu-bin deleted the mla_decode_dev branch November 2, 2024 13:03
@fengyang95
Copy link

cc @ispobock @merrymercy

@zhyncs Hi,are there plans to port this to sglang in the future?

@fengyang95
Copy link

Look forward to your CUTE refactor of prefill kernel, then we can continue to work on MLA prefill kernel. Hope soon the complete implementation can be ready for production.

Hi @tsu-bin @yzh119 is there any news about the prefill kernel?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants