
vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and FlashAttention2 #10206

Open · wants to merge 1 commit into base: master

Conversation

@jeffbolznv (Collaborator) commented Nov 7, 2024

This change adds support for VK_NV_cooperative_matrix2 (https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#VK_NV_cooperative_matrix2, https://github.com/KhronosGroup/GLSL/blob/main/extensions/nv/GLSL_NV_cooperative_matrix2.txt) to the Vulkan backend. This is a recent Vulkan extension supported by NVIDIA drivers that enables matrix multiplies using the tensor cores, while being easier to use and supporting more operations than VK_KHR_cooperative_matrix.

While this PR is code complete and passes testing, it will remain a draft for a while because the build system relies on the Vulkan SDK, and the tooling for this extension will land in the next Vulkan SDK release. In the meantime, if you're interested in trying this out locally, you can clone https://github.com/KhronosGroup/Vulkan-Headers, grab the latest glslc CI build from https://github.com/google/shaderc, and set the cmake variables Vulkan_INCLUDE_DIR to point to Vulkan-Headers/include and Vulkan_GLSLC_EXECUTABLE to point to glslc. You'll need the most recent Vulkan beta driver from https://developer.nvidia.com/vulkan-driver. (Note: there is one more driver optimization forthcoming, but the results you'll see with this driver should be close.)

The two main additions in this change are a coopmat2 mul_mat shader and a coopmat2 FlashAttention2 shader. The mul_mat shader supports normal matrix multiplies and mixture of experts, and supports a variety of quantization formats using the "decode" callback functionality in coopMatLoadTensorNV. The decode callback functions are in dequant_funcs_cm2.comp and decode one element at a time. The FlashAttention2 shader also supports quantization formats and could theoretically use different formats for K and V, but the compilation cost for supporting all those combinations was too high and I don't know if this is ever used in practice. Note that the mul_mat approach to quantization formats is analogous to the existing Vulkan mul_mat shader in that it converts to fp16 before multiplying, whereas I believe the CUDA path converts to int8 and applies the scale/bias per-tile.
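For readers unfamiliar with the decode-callback idea, here is a purely illustrative sketch of decoding one Q4_0 element to fp16. The struct, function name, and signature are hypothetical (the real callbacks in dequant_funcs_cm2.comp follow the signature required by coopMatLoadTensorNV's decode hook); only the Q4_0 math is taken from ggml:

    // Illustrative sketch only -- not this PR's actual callback signature.
    // Needs GL_EXT_shader_explicit_arithmetic_types_float16/int8.
    // Q4_0: blocks of 32 weights, one fp16 scale plus 16 bytes of packed 4-bit quants.
    struct block_q4_0 {
        float16_t d;
        uint8_t qs[16];
    };

    float16_t dequant_q4_0_one(const in block_q4_0 blk, uint idx_in_block) {
        // Low nibbles hold elements 0..15, high nibbles hold elements 16..31.
        uint byte_idx = idx_in_block & 15u;
        uint q = (idx_in_block < 16u) ? (uint(blk.qs[byte_idx]) & 0xFu)
                                      : (uint(blk.qs[byte_idx]) >> 4u);
        // Quants are stored with a +8 bias, so the decoded value is d * (q - 8).
        return blk.d * float16_t(int(q) - 8);
    }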

I've also optimized some other shaders, including mul_mat_vec and split_k_reduce, and added a "linear vec4 copy" shader. With the much higher cooperative matrix 2 perf for mul_mat, these other state buckets become relatively more expensive. I'll split these out into smaller changes to review separately, but I wanted to include them all here for context and for perf measurements.

[Edit: The previous performance comparisons against the existing Vulkan path were broken. The comparison between Vulkan w/coopmat2 and CUDA was, I believe, accurate.]

The coopmat2 path helps significantly with prompt processing. FA helps a little bit in token gen, but only a few percent here and there.

(Performance charts attached: RTX 4070 pp512, RTX 4070 tg128, RTX 6000 pp512, RTX 6000 tg128.)

@github-actions bot added the labels testing (Everything test related), Vulkan (Issues specific to the Vulkan backend), and ggml (changes relating to the ggml tensor library for machine learning) on Nov 7, 2024
@0cc4m (Collaborator) commented Nov 8, 2024

Wow, this looks really impressive. I'll take a closer look later.

In the interest of compatibility, how much work do you think it is to also add a VK_KHR_cooperative_matrix codepath, for AMD RDNA3 and maybe Intel ARC? Obviously not something you have to do, but I'd have to try it sooner or later.

@0cc4m 0cc4m self-requested a review November 8, 2024 17:33
@jeffbolznv (Collaborator, Author)

how much work do you think it is to also add a VK_KHR_cooperative_matrix codepath

It may not be too bad in the mul_mm shader since you already have the code to dequantize and copy to shared memory. Porting the flash attention shader would be more involved, both because there's no existing scalar shader to start from, and also because it does reductions and needs to clear padding elements, and there's no way to do that other than to dump the matrix out to shared memory.
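To illustrate the shared-memory round trip, here is a rough shader fragment of how a row reduction might look with VK_KHR_cooperative_matrix, which has no per-element access. All names and the 16x16 tile size are hypothetical; this is not code from this PR or from the existing shaders:

    #extension GL_KHR_cooperative_matrix : require
    #extension GL_KHR_memory_scope_semantics : require

    // Stage the accumulator tile through shared memory so ordinary invocations
    // can reduce rows and skip the padding columns.
    shared float scratch[16 * 16];
    shared float row_max[16];

    void reduce_rows(coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> acc,
                     uint valid_cols) {
        coopMatStore(acc, scratch, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
        barrier();
        if (gl_LocalInvocationID.x < 16) {
            float mx = scratch[gl_LocalInvocationID.x * 16];
            for (uint c = 1; c < 16; ++c) {
                if (c < valid_cols) { // ignore padding columns
                    mx = max(mx, scratch[gl_LocalInvocationID.x * 16 + c]);
                }
            }
            row_max[gl_LocalInvocationID.x] = mx;
        }
        barrier();
    }

With coopmat2, the extension provides built-in reduction and per-element operations plus tensor-layout clamping, so no explicit staging like this is needed.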

@lin72h commented Nov 8, 2024

Very impressive. Just to let you know, Mesa's amdgpu driver is implementing a polyfill for pre-RDNA3 generation, so VK_KHR_cooperative_matrix optimization could benefit all AMD GPUs running Linux.

@0cc4m (Collaborator) commented Nov 9, 2024

Very impressive. Just to let you know, Mesa's amdgpu driver is implementing a polyfill for pre-RDNA3 generation, so VK_KHR_cooperative_matrix optimization could benefit all AMD GPUs running Linux.

Probably not; emulating the matrix operations will likely just be slower than a non-coopmat codepath. You can see an example of that at the bottom of the merge request.

@0cc4m (Collaborator) commented Nov 11, 2024

There are more parts to this that could be upstreamed separately, before the driver and Vulkan support for the coopmat2 extension lands. I'm thinking of the vector copy shader and the matrix-vector multiplication shader improvements.

Additionally, if we upstream the f16acc and f32acc switch earlier we could adapt my matrix multiplication shader to it. I had previously hardcoded the accumulator to float due to precision issues.

In the meantime I'll look into a VK_KHR_cooperative_matrix implementation in the regular matmul shader.

@jeffbolznv (Collaborator, Author)

I'm currently working on the optimized copy shader and will make a PR soon. I had trouble reproducing the gains from the mat-vec mul shader in isolation, but I think it may have only really benefited Q8_0. I'll try that again soon.

If you're going to look into the KHR path, I did a very basic prototype of it recently that you could use as a starting point: jeffbolznv@3416010. The first real bit of nastiness I ran into is needing to bound-check the store (see the comment).

@netrunnereve (Collaborator)

@jeffbolznv do you happen to know approximately how much of the tg improvement is due to tensor cores and how much is due to flash attention? Dedicated matrix-matrix multipliers are probably going to help a lot with prompt processing, but the matrix-vector multiplications for inference are limited by memory bandwidth in most cases. From the graphs it looks like the original Vulkan implementation is severely compute bound.

@jeffbolznv (Collaborator, Author)

Oh no, I'm afraid there was a mistake in our testing methodology where we had -fa 1 set for the existing Vulkan path, and it causes the FA nodes to quietly fall back to the CPU. There's little gain from the cooperative matrix 2 path for token generation (likely a few percent from FA). Still a good gain for prompt processing, but those results are also heavily skewed. I'm going to edit the first comment to retract them and I'll try to update with more accurate results in the next day or so. My apologies, everybody.

@jeffbolznv (Collaborator, Author)

I've replaced the RTX 4070 results with something I think is more correct. Very little gain for tg, still large gain for pp. Will try to get updated RTX 6000 results tomorrow.

@ggerganov (Owner)

Very little gain for tg, still large gain for pp. Will try to get updated RTX 6000 results tomorrow.

FA for TG starts to make a significant difference only at large contexts, so this is expected for the tg128 test. You can use the llama-batched-bench tool to estimate the TG speedup at larger contexts (see the OP in #10171 for the commands).

@0cc4m (Collaborator) commented Nov 13, 2024

Does someone have a theory for the DeepSeek Coder V2 Lite outlier in the benchmarks? Why is Vulkan significantly outperforming CUDA there, in tg even without the tensor cores? Some issue with the MUL_MAT_ID implementation?

@slaren (Collaborator) commented Nov 13, 2024

Possibly because the expert selection is done on the CPU.

@netrunnereve (Collaborator)

So I tried out the mul_mat_vec.comp changes only and got 6% faster inference speed with Q4_0. Pretty much all of the improvement came from the barrier-free subgroup add, and the additional unrolling and optimizations didn't seem to help much. Interestingly enough, I got a negligible difference in performance when I tried to implement the subgroup add in the separate K-quant mat mul shaders.

BTW our mat vec shaders have a subgroup size of 32 which is fine for Nvidia and new AMD cards but bad for my old W8100 😏.

@jeffbolznv (Collaborator, Author)

I'm working on splitting out the mul_mat_vec changes and making some additional optimizations; I hope to make a PR for that tomorrow. I haven't actually seen much gain from the subgroupAdd, and using that if the subgroup size is not equal to 32 is tricky, so I wasn't planning to leave that in. Was it on W8100 that you saw it helped?

@0cc4m (Collaborator) commented Nov 14, 2024

I haven't used the subgroup operations so far cause they caused driver crashes on Intel. But I'll test that again once you make the PR. If that problem is still around we might have to make them optional.

GCN cards are definitely relevant, as for many of them Vulkan is the last API they have left. But also modern RDNA defaults to subgroups of size 64, though you can manually reduce that to 32 using VK_EXT_subgroup_size_control. Currently that is not done. I would like to support subgroups of 32 and 64 where possible.

@sorasoras commented Nov 14, 2024

I haven't used the subgroup operations so far cause they caused driver crashes on Intel. But I'll test that again once you make the PR. If that problem is still around we might have to make them optional.

GCN cards are definitely relevant, as for many of them Vulkan is the last API they have left. But also modern RDNA defaults to subgroups of size 64, though you can manually reduce that to 32 using VK_EXT_subgroup_size_control. Currently that is not done. I would like to support subgroups of 32 and 64 where possible.

I wonder what kind of performance we would see if RDNA3 ran at dual-issue wave32 with VOPD.


@netrunnereve (Collaborator) commented Nov 14, 2024

I haven't actually seen much gain from the subgroupAdd, and using that if the subgroup size is not equal to 32 is tricky, so I wasn't planning to leave that in. Was it on W8100 that you saw it helped?

Yep I got a 6% improvement on my W8100 by replacing the final barrier sum with something like this. I guess it might be due to the fact that my card is slower at thread synchronization and gets hung up by the two barriers, but I don't know.

    // sum up partial sums and write back result
    tmp = subgroupAdd(tmp);
    if (tid == 0) {
        data_d[d_offset + row] = D_TYPE(tmp);
    }
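For reference, a rough sketch of the barrier-based tree reduction this replaces (tmpsh here stands for a shared FLOAT_TYPE array of BLOCK_SIZE partial sums; names approximate, not the exact upstream code):

    // Each invocation publishes its partial sum, then the active threads are
    // halved each step, with a barrier between steps.
    tmpsh[tid] = tmp;
    barrier();
    for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            tmpsh[tid] += tmpsh[tid + s];
        }
        barrier();
    }
    if (tid == 0) {
        data_d[d_offset + row] = D_TYPE(tmpsh[0]);
    }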

I haven't used the subgroup operations so far cause they caused driver crashes on Intel. But I'll test that again once you make the PR. If that problem is still around we might have to make them optional.

Honestly I don't think it's necessary to invest a lot of effort for a 6% gain on the old quant types, though it's possible that there's a larger difference on certain GPU models. What I'll probably look into instead is the 64 subgroup size, as my card is potentially only using half its capabilities with 32, and that should hopefully be a big improvement. The shaders look pretty simple to hack on as long as I don't need to touch ggml-vulkan.cpp.

@JohannesGaessler (Collaborator)

The Vulkan performance vs. CUDA with this PR is definitely much better than I would have expected would be possible using a more generic API. I'll benchmark the performance myself once the feature becomes available via package managers.

The decode callback functions are in dequant_funcs_cm2.comp and decode one element at a time.

This is an on-the-fly dequantization to FP16 into SRAM, right? Context: llama.cpp training support is one of my next goals and I think memory use will be the biggest bottleneck. So good performance for small batch sizes would be desirable and dequantization into VRAM would be too slow I think.

The FlashAttention2 shader also supports quantization formats and could theoretically use different formats for K and V, but the compilation cost for supporting all those combinations was too high and I don't know if this is ever used in practice.

The K cache seems to need more precision than the V cache so an asymmetric setup could make sense for users that like fiddling.

Note that the mul_mat approach to quantization formats is analogous to the existing Vulkan mul_mat shader in that it converts to fp16 before multiplying, whereas I believe the CUDA path converts to int8 and applies the scale/bias per-tile.

That is correct; the CUDA code uses either __dp4a or int8 tensor cores. I did this under the expectation that I would be able to get better throughput than FP16 cuBLAS GEMM, but for the kernel itself (ignoring any auxiliary kernels) I was unfortunately not able to get better performance. I think the reason is the comparatively small block size of 16/32 in the current GGML quantization formats. At some point I want to try making a format specifically for training directly in quantized form, where I think it would make sense to use larger blocks for better performance.
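For intuition only (not the exact CUDA kernel), the int8 route amounts to accumulating an integer dot product per 32-element block and applying the two block scales afterwards, e.g. for a Q8_0 x Q8_0 block pair:

$$
C_{mn} \mathrel{+}= d_A \, d_B \sum_{i=0}^{31} q_{A,i} \, q_{B,i}
$$

whereas the fp16 route in this PR expands each element to fp16 first (e.g. d * (q - 8) for Q4_0) and feeds fp16 values to the tensor cores.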

FA helps a little bit in token gen, but only a few percent here and there.

Try setting -n 8192; you need a large enough context for the results to be sensitive to the performance of the attention.

@jeffbolznv (Collaborator, Author)

I've rebased this and it's a bit more readable now. I still want to split out the types.comp changes, will do after #10387 lands.

@jeffbolznv (Collaborator, Author)

The decode callback functions are in dequant_funcs_cm2.comp and decode one element at a time.

This is an on-the-fly dequantization to FP16 into SRAM, right?

Yes, the intent is that the compiler should be staging the matrices through shared memory. That may not happen in all cases (it depends on the implementation), but that's the goal.

@0cc4m (Collaborator) commented Nov 29, 2024

If you're going to look into the KHR path, I did a very basic prototype of it recently that you could use as a starting point: jeffbolznv@3416010. The first real bit of nastiness I ran into is needing to bound-check the store (see the comment).

Thank you @jeffbolznv. I also had a basic non-functional version I could start out with. I compared them and built something that looks mostly correct to me, but I can't seem to get correct results and I'm running out of ideas about what the issue is.

I've uploaded the current state to https://github.com/ggerganov/llama.cpp/tree/0cc4m/vulkan-coopmat / f5cae8d. If built with GGML_VK_RUN_TESTS it runs some small basic matrix multiplications and checks the results.

It seems to be some issue with loading the coopmats, but it's very hard to debug. I think at least the partial store is correct. The load does not seem to work as I expect, since I get the results I want only when I stage the load from buf_a through coopmat_stage shared memory and then into the coopmat, not if I load directly from buf_a. Some of what I tried is commented out in the file.

I'm using an RTX 3090 to test it. Let me know if you see something. This is kinda off-topic for this PR though, I don't mean to hijack it for something else.

@jeffbolznv (Collaborator, Author)

I'm using an RTX 3090 to test it. Let me know if you see something. This is kinda off-topic for this PR though, I don't mean to hijack it for something else.

Looks like each warp is only computing one tile of the result (WMxWN = 16x16), which doesn't line up with the wgdenoms/workgroupsize it's using. I think that at least explains the first failing test (32x32x16).

@0cc4m (Collaborator) commented Nov 29, 2024

Looks like each warp is only computing one tile of the result (WMxWN = 16x16), which doesn't line up with the wgdenoms/workgroupsize it's using. I think that at least explains the first failing test (32x32x16).

I haven't looked into that yet because my assumption was that it should work for matrices small enough, like the 4x4x4 test I was using. Only one coopmat multiplication is required for that.

@jeffbolznv (Collaborator, Author)

Strange, 4x4x4 is passing for me, the first failure I saw was in 32x32x16.

@0cc4m (Collaborator) commented Nov 29, 2024

Strange, 4x4x4 is passing for me, the first failure I saw was in 32x32x16.

Okay, that's odd. Implementation difference maybe? You're using Ada, right, not Ampere?

@jeffbolznv (Collaborator, Author)

Okay, that's odd. Implementation difference maybe?

I wonder if this test uses a stride that isn't a multiple of 16B. That's required by the spec, but I think recent drivers with the coopmat2 implementation are more forgiving about this.

@0cc4m (Collaborator) commented Nov 29, 2024

Okay, that's odd. Implementation difference maybe?

I wonder if this test uses a stride that isn't a multiple of 16B. That's required by the spec, but I think recent drivers with the coopmat2 implementation are more forgiving about this.

Oh yeah, that must be it. I'm running into undefined behaviour because the stride isn't aligned: buf_a and buf_b use a stride of BK + 1 to avoid bank conflicts, and that would be 17 in my first tests. I missed that detail in the spec; it would probably have taken me a long time to find that. Thank you!
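For anyone else hitting this, a minimal illustration of the arithmetic (assuming 16-bit elements; the buffer names and the 64-row factor are made up, only BK and the 16-byte rule come from the discussion above):

    // Stride must be a multiple of 16 bytes; with float16_t that is 8 elements.
    const uint BK = 16;
    shared float16_t buf_bad[64 * (BK + 1)]; // stride 17 elems = 34 bytes -> undefined behaviour
    shared float16_t buf_ok [64 * (BK + 8)]; // stride 24 elems = 48 bytes -> satisfies the spec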

@0cc4m (Collaborator) left a review comment


Apologies for the number of comments.

ggml/src/ggml-vulkan/ggml-vulkan.cpp (review thread, resolved)
ggml/src/ggml-vulkan/ggml-vulkan.cpp (review thread, outdated, resolved)
    - // Try to find a non-graphics compute queue and transfer-focused queues
    - const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
    + // Try to find a queue that supports compute and transfer-focused queues
    + const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlags{}, -1, 1);
Collaborator

Was it wrong to prefer a non-graphics compute queue? My assumption was that by doing this I would keep it more separate from graphics tasks the GPU might be doing simultaneously.

Collaborator (Author)

It wasn't wrong, I think performance will be the same for either type of queue. I can change it back if you want.

ggml/src/ggml-vulkan/ggml-vulkan.cpp (4 review threads, outdated, resolved)
@jeffbolznv
Copy link
Collaborator Author

I was originally thinking we'd not merge this until the next Vulkan SDK is released, but I recently realized that Android has a separate SDK on a different release schedule, so the code needs to build against older headers for a while regardless. I've updated things so it can build against older or newer headers. For the shader compiles, it checks the Vulkan header for the presence of the extension and assumes glslc will support coopmat2 if the extension is present in the headers. I don't know of a better way to check for support.

I've removed the debug code and will remove the "draft" label, I think this is ready for review/merge. I'm fine with waiting for #10597 to be merged first.

@jeffbolznv changed the title from "Draft: vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and FlashAttention2" to "vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and FlashAttention2" on Dec 4, 2024
@jeffbolznv jeffbolznv requested a review from 0cc4m December 4, 2024 15:13
ggml/src/ggml-vulkan/ggml-vulkan.cpp (review thread, outdated, resolved)
@0cc4m (Collaborator) commented Dec 4, 2024

I've removed the debug code and will remove the "draft" label, I think this is ready for review/merge. I'm fine with waiting for #10597 to be merged first.

I'll do a review in a few hours. If no bigger issues come up, I think we can merge this one first; after all, my PR was inspired by this one.
