[QST] How do I do a matrix multiply inside a Cuda kernel? #1291
CUTLASS provides utilities to implement linear algebra at all layers of the GPU hierarchy. It is certainly possible to "fuse" GEMM into some existing kernel; however, it requires more work than using a GEMM kernel off the shelf. Can you describe your problem in a bit more detail? What is your problem shape, and do you know it at compile time? What are the data types? What are the layouts of the input tensors, and do you have control over them? What memory space do these input tensors live in? What architecture are you targeting? What are your perf goals? The more details you provide, the more we can help you.
I made this diagram to make it easier to understand. I want to use Cutlass to implement the matrix multiplies needed for the neural net model. What I had in mind is that at the start of the program I would launch a bunch of grids and blocks. A minority of them would be dedicated to running the game itself. The rest would go past the game loop into the NN model loop and wait at a device-wide barrier. They would be joined by the game threads whenever those request an action, and the neural net model would execute. At the end of it, the game threads would return from the function invocation with the results and resume the game, while the NN threads would go back to waiting on the barrier. With this kind of architecture, it should be possible to dedicate the entirety of the GPU's resources to the neural net model. I think this is the way RL research should be done in the future.
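A device-wide barrier of that kind can be expressed with cooperative groups. Here is a minimal sketch of the split between game blocks and NN blocks; all the names and the block split are hypothetical, and the kernel has to be started with a cooperative launch (cudaLaunchCooperativeKernel):

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

constexpr unsigned NUM_GAME_BLOCKS = 4;  // hypothetical split point

__global__ void agent_kernel(const volatile int *done /* host-set stop flag */) {
  cg::grid_group grid = cg::this_grid();
  bool is_game_block = blockIdx.x < NUM_GAME_BLOCKS;

  while (*done == 0) {
    if (is_game_block) {
      // ... run the game until an action is requested ...
    }
    grid.sync();  // game blocks and NN blocks meet here
    // ... all blocks cooperate on the neural net inference ...
    grid.sync();  // game blocks pick up the results and resume the game
  }
}

Note that grid.sync() requires a cooperative launch with the whole grid resident on the device at once, which happens to fit the plan of sizing the grid to the GPU.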
They will all be statically known: the batch size as well as the input and output dimensions. The data types would be floats, also statically known; I'll have to experiment to see which ones work best. As for memory space, they should be in global memory.
An RTX 4060 for now. I got it just so I could program this.
Hopefully something far better than doing the game in Python and using PyTorch for the NN model. That was slow as a snail. I don't know exactly how many hands per second I'll get until I try this, but the PyTorch version that I did in 2021 was on the order of tens of thousands of them per second, which was pretty horrible. This was with various kinds of models, so I think I was being bottlenecked by the bandwidth of transferring data between the CPU and the GPU.
I forgot to address this. Basically, everything is known and controllable at compile time.
If you need GEMM in the CUDA kernel, you may want to investigate using cuBLASDx. It should be available through early access very soon. Here is a link to the Math Libraries Early Access page. It sounds like your use case meets all the requirements. Take a look at the cuFFTDx documentation to get an idea of the API.
I heard about that a month ago, but I put it out of my mind because it wasn't out yet. I'll check it out again. Thanks!
Ok, so you just want to call an entire GEMM kernel, but from within a CUDA kernel that you have. I think you can just stamp out a kernel-level GEMM operator and invoke it from your own kernel.
An example would make that a lot easier. The Cutlass library right now is not for the faint of heart.
I unfortunately do not have the time to write a full example, but I can try to stub something out for you here. The kernel-layer GEMM is simply the Operator invoked by the generic Kernel template below. What I am suggesting is that you hijack this function and use everything else off the shelf. You can run whatever code you want alongside the GEMM if you write something similar:

/// Generic CUTLASS kernel template.
template <typename Operator>
__global__
void Kernel(typename Operator::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];
  // Declare pointer to dynamic shared memory.
  typename Operator::SharedStorage *shared_storage =
      reinterpret_cast<typename Operator::SharedStorage *>(SharedStorageBase);
  Operator op;
  while (game_loop) {
    // do whatever before
    op(params, *shared_storage);
    // do whatever after
  }
}
Thanks. I'll give it a try. Let me just get the Spiral tensors out of the way first. I am not doing this in C++, so there is more work needed to get the fundamentals in place.
extern __shared__ int SharedStorageBase[];
// Declare pointer to dynamic shared memory.
typename Operator::SharedStorage *shared_storage =
    reinterpret_cast<typename Operator::SharedStorage *>(SharedStorageBase);

I just want to ask ahead of time, before I try doing it: it feels like there is missing information in this fragment. How will the GEMM kernel know what the size of the shared storage is?
We use dynamic shared memory allocation; see this blog: https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/. Dynamic shared memory allocation is required when the allocated shared memory is big (IIRC the bar is 48 KB).
I know what dynamic shared memory is. The only problem is that, as written, the kernel won't know what the size of the shared array is. ...Never mind my question. Now that I think about it, the size of the dynamically allocated shared memory is probably passed in as part of the kernel arguments.
When you launch the kernel, you need to put the shared memory size inside the launch configuration (the third argument of the <<< >>> execution configuration).
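Concretely, for the Kernel stub above that would look something like the following sketch (the grid and block sizes here are placeholders, not values from this thread):

// The required size is the size of the operator's SharedStorage type.
int smem_size = int(sizeof(Operator::SharedStorage));

dim3 grid(108, 1, 1);   // placeholder
dim3 block(128, 1, 1);  // placeholder

Kernel<Operator><<<grid, block, smem_size>>>(params);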
Sorry for the long delay. I managed to implement the tensor functionality in Spiral that will be necessary for passing the multidimensional arrays from the Python backend to the Cuda one, and now I am finally ready to start work on the ML library. Finding a matrix multiply that can work on the device is now my main priority. I registered for the Early Access for the CUDA Math library two weeks ago but didn't get an answer, so I am now looking at the Cutlass library again. As expected, I am having difficulty getting started. The basic device API example that Cutlass provides makes a lot of sense and resembles a cuBLAS call, which I am familiar with, but the kernel API is a lot more complex.
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
class GemmUniversal<
Mma_,
Epilogue_,
ThreadblockSwizzle_,
void,
// 3.x kernels use the first template argument to define the ProblemShape
// We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API
cute::enable_if_t<not (cute::is_tuple<Mma_>::value || IsCutlass3ArrayKernel<Mma_>::value)>
> {

As was suggested, I need to first figure out how to stamp one of these out. What should I be passing in for the Mma_, Epilogue_, and ThreadblockSwizzle_ parameters? Secondly...
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params epilogue,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride stride_a,
typename LayoutB::Stride stride_b,
typename LayoutC::Stride stride_c,
typename LayoutC::Stride stride_d,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr)

I need a reference for how to construct the Arguments struct. There are indices, batch strides, regular strides, mode, problem size, batch count, and so on that I would need to understand in order to make this struct.
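For what it's worth, here is a rough sketch of how those parameters could line up for a single (non-batched) row-major float GEMM. The type Gemm, the pointers, and the extents M, N, K are placeholders of mine, not anything from this thread:

// Hypothetical: Gemm is some cutlass::gemm::device::GemmUniversal<...> instantiation,
// alpha/beta are floats, d_A..d_D are device pointers, M/N/K are the problem extents.
Gemm::Arguments args(
    cutlass::gemm::GemmUniversalMode::kGemm,  // mode: a plain, non-batched GEMM
    {M, N, K},                                // problem_size
    1,                                        // batch_count
    {alpha, beta},                            // epilogue params (linear combination)
    d_A, d_B, d_C, d_D,                       // ptr_A, ptr_B, ptr_C, ptr_D
    int64_t(M) * K,                           // batch_stride_A (unused when batch_count == 1)
    int64_t(K) * N,                           // batch_stride_B
    int64_t(M) * N,                           // batch_stride_C
    int64_t(M) * N,                           // batch_stride_D
    Gemm::LayoutA::Stride(K),                 // stride_a: leading dimension of row-major A
    Gemm::LayoutB::Stride(N),                 // stride_b
    Gemm::LayoutC::Stride(N),                 // stride_c
    Gemm::LayoutC::Stride(N));                // stride_d
// The gather/scatter index pointers keep their nullptr defaults.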
One other thing. Looking at the Arguments struct above, I have to re-raise the question about the shared memory size, as my earlier assumption that it'd be passed as an argument wasn't valid.
I am also concerned, is using an ...
I found this in a different thread, and it is a good example of how to construct one.
That having been said, as things are, I am going to dig out a regular matrix multiply that uses shared memory and use that instead of Cutlass. I just can't afford to keep spending time that should be going into building the ML library on trying to figure out Cutlass. Since it is a C++ project with the authors' own build setup, I can't open it in something like VS to get the types via IntelliSense, or run the examples and step through them in the debugger. If the authors want to provide a concrete example of how to do what I asked in this thread, that would be appreciated, but I'll just let the issue linger for the time being.
Do you want to adopt the 2.x API or the 3.x API? It mostly depends on the architectures you are targeting; for pre-Hopper, I would recommend 2.x. As for how to generate the right set of kernels, I would encourage you to look at how the CUTLASS library generates kernel configurations and emits the corresponding C++ to stamp out kernels. You can also lift the entire device-level configuration and reuse the kernel type it stamps out.
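As an illustration of that last suggestion, one possible 2.x sketch is to let the device layer assemble everything from its defaults and then pull the kernel-level type out of it (a plain SIMT float GEMM is used here purely for simplicity; none of these choices come from this thread):

#include "cutlass/gemm/device/gemm_universal.h"

// Let the device layer pick Mma_, Epilogue_, and ThreadblockSwizzle_ via its defaults.
using GemmDevice = cutlass::gemm::device::GemmUniversal<
    float, cutlass::layout::RowMajor,   // ElementA, LayoutA
    float, cutlass::layout::RowMajor,   // ElementB, LayoutB
    float, cutlass::layout::RowMajor,   // ElementC, LayoutC
    float>;                             // ElementAccumulator

// The kernel-level type that could serve as Operator in the Kernel stub earlier in the thread.
using Operator = GemmDevice::GemmKernel;

The tile sizes, stage count, and epilogue then come from the library's default configuration for the chosen architecture and operator class, so nothing has to be picked by hand unless the defaults disappoint.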
As for shared memory, the kernels in CUTLASS only ever work with a static amount of shared memory that the kernel knows about based on the template config. The kernel must be launched with that much dynamic smem, otherwise it will not run correctly. https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/device/gemm_universal_adapter.h#L237
If the amount of shared memory needed is known by the Cutlass kernel template at compile time, then instead of declaring dynamic shared memory arrays, wouldn't it be better to declare static ones instead?
This would take a ton of time, and I can't realistically run the huge C++ project examples in an IDE to step through them and get a sense of the program flow as well as the types, since I am doing my development on Windows. The question of which version of the API to use is largely irrelevant to me; if there were good examples of how to use the kernel API for my use case, I'd use either one. What I'd really want is a BLAS-like API for Cutlass that could actually be used on the device. I do not want to pick tile sizes or deal with similar kinds of low-level details. I am not interested in researching matrix multiplication itself, but in doing reinforcement learning purely on the GPU. Given my circumstances, it would be a lot faster to pick the CUDA matmul sample, translate it into Spiral, and use that while I wait for cuBLASDx to be released. Putting in a lot more effort to make Cutlass work is only something I'd be willing to do if cuBLASDx is not out by the time I've built out everything else. That will take me a few months.
No, because this often exceeds the portable 48 KB or so of smem, and we need to use the driver API to opt into larger smem carveouts, which requires using dynamic smem.
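In code, that opt-in amounts to something like the following sketch, shown here with the runtime-API equivalent (cudaFuncSetAttribute) and the Kernel/Operator names from the stub earlier in the thread:

int smem_size = int(sizeof(Operator::SharedStorage));

// Above the default 48 KB limit, the launch fails unless the cap is raised first.
if (smem_size >= (48 << 10)) {
  cudaError_t result = cudaFuncSetAttribute(
      Kernel<Operator>,
      cudaFuncAttributeMaxDynamicSharedMemorySize,
      smem_size);
  if (result != cudaSuccess) {
    // The requested carveout exceeds what this GPU can provide.
  }
}

Kernel<Operator><<<grid, block, smem_size>>>(params);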
If you look at the runtime requirements, cuBLASDx is designed to fuse much smaller matmuls into an existing program; it is intended for use cases where you have small matmuls that you can run out of shared memory. Is that the case for you? It sounds like you have large GEMMs to run instead of small matmuls.
It seems that the cuBLASDx preview is already out, so I'll close this issue. Thank you for taking the time to answer some of my questions.
cuBLASDx examples have been posted: https://github.com/NVIDIA/CUDALibrarySamples/tree/master/MathDx
Rather than using either Cutlass or cuBLASDx, I ended up spending the last two months building my own matrix multiply kernel as part of the Spiral playlist. I think I did well, but the only thing that I couldn't quite get down is the async loads for Ampere. No matter what I did when trying to interleave the loads with the computation, I couldn't get an improvement. I wonder if there is something I don't understand about how to use async loads? I tried cuBLASDx, but for some reason its performance was horrible, and even apart from that, I thought that I could have done better (and I have) given its interface constraints. There is a cuBLASDx example here, along with the compiled output, in case @mnicely wants to have a look.
Your PTX needs to be very similar to what the CUTLASS kernels produce.
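For reference, the load/compute overlap those kernels get on Ampere comes from cp.async copies staged through a multi-buffered shared memory pipeline. Below is a stripped-down sketch using the CUDA pipeline primitives; the tile size, thread count, and names are illustrative only and not taken from either kernel:

#include <cuda_pipeline_primitives.h>

// Double-buffered global->shared staging with cp.async (sm_80+).
// Assumes blockDim.x == 256 and gmem holds num_tiles * 256 float4 elements.
__global__ void async_copy_sketch(const float4 *gmem, int num_tiles) {
  __shared__ float4 smem[2][256];  // two stages of 256 float4 each
  int t = threadIdx.x;

  // Prologue: start copying tile 0 into stage 0.
  __pipeline_memcpy_async(&smem[0][t], &gmem[t], sizeof(float4));
  __pipeline_commit();

  for (int tile = 0; tile < num_tiles; ++tile) {
    int cur = tile & 1;
    int nxt = (tile + 1) & 1;

    // Issue the copy of the next tile before touching the current one.
    if (tile + 1 < num_tiles) {
      __pipeline_memcpy_async(&smem[nxt][t], &gmem[(tile + 1) * 256 + t],
                              sizeof(float4));
      __pipeline_commit();
    }

    // Wait until the current tile's copy has landed, then make it visible block-wide.
    __pipeline_wait_prior(tile + 1 < num_tiles ? 1 : 0);
    __syncthreads();

    float4 v = smem[cur][t];
    // ... do the MMA work that consumes v / smem[cur] here ...
    (void)v;

    __syncthreads();  // everyone is done reading before this stage gets overwritten
  }
}

The key detail is that the cp.async for tile N+1 is committed before anything waits on or reads tile N; if the wait happens first, the copies and the math serialize and the async loads buy nothing.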
Hmmm, that actually isn't a bad idea, I mean using Cutlass as a reference. I've only been using the CuPy matmul (from the cuBLAS library) as a reference, and it doesn't use tensor cores. It didn't use async loads when I looked into it either, though. That having been said, I'll just upgrade my kernel to use TMA and bulk async transfers when the 50xx cards come out, and that will take care of the memory transfer issues.
What is your question?
I want to implement a game + ML library directly on the GPU without any of the inefficiencies of going back and forth with the host. But as far as I can tell, most of the Cutlass examples call it from the host. I thought that the kernel API would allow me to call it from the device, but it doesn't seem like that is the case.
The one example that I found calls it from the host.
I spent some time studying the examples, but they are all extremely complex, which is why I'm asking here. How do I do the equivalent of a BLAS GEMM call inside a kernel?