[QST] How to make the performance of cutlass kernel achieve what cutlass_profiler claims? #1298
Comments
The cutlass profiler uses 8-way parallel split-K, while you are using 16-way; could that be the difference? Your nvcc command line is also different: you don't have any …
Sorry for the difference, but after I changed …
Not just that: the cutlass profiler uses gemm_universal (https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/device/gemm_universal.h) with …
Also, the cutlass profiler runs multiple times after some warmup runs; yours seems to be a cold run.
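For illustration, here is a minimal sketch of that kind of measurement loop (a few untimed warmup launches, then many timed launches averaged with CUDA events). The iteration counts and the run_gemm callable are placeholders for this thread, not the profiler's actual code:

#include <cuda_runtime.h>

// Rough timing harness: discard cold-start effects with warmup launches,
// then report the average of many timed launches measured with CUDA events.
// `run_gemm` stands in for whatever launches the CUTLASS kernel
// (e.g. a lambda wrapping gemm_op()).
template <typename F>
float average_time_ms(F run_gemm, int warmup = 10, int iters = 100) {
  for (int i = 0; i < warmup; ++i) run_gemm();   // untimed warmup runs

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) run_gemm();    // timed runs
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float total_ms = 0.0f;
  cudaEventElapsedTime(&total_ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return total_ms / iters;                       // average time per launch
}
// GFLOPS would then be 2.0 * M * N * K / (average_time_ms(...) * 1e6).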
See https://github.com/NVIDIA/cutlass/blob/main/CUDA.cmake#L318 for applying the necessary compiler options. If you are using CMake, you can just call cutlass_add_library() or cutlass_add_executable() to get the appropriate setup for your build.
Thank you for your reply, I will use CMake according to the comments below. Do you mean I should use gemm_universal instead? The warmup is inside …
I think these two should have the same performance. But if you need to run exactly the same thing as the profiler, you can try to use gemm_universal.
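For reference, a sketch of what an equivalent setup through gemm_universal might look like for this FP64 tensor-op case. The template parameters mirror the GemmSplitKParallel configuration used later in this thread; the Arguments layout follows the device-level GemmUniversal API as I understand it, and M, N, K and the pointers A, B, C, D are placeholders. Treat this as a starting point rather than the profiler's exact configuration:

#include "cutlass/gemm/device/gemm_universal.h"

// Same data types, layouts and tile shapes as the GemmSplitKParallel kernel below,
// expressed with the device-level GemmUniversal API.
using GemmUni = cutlass::gemm::device::GemmUniversal<
    double, cutlass::layout::RowMajor,       // A
    double, cutlass::layout::RowMajor,       // B
    double, cutlass::layout::RowMajor,       // C / D
    double,                                  // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<64, 32, 16>,    // threadblock tile
    cutlass::gemm::GemmShape<32, 16, 16>,    // warp tile
    cutlass::gemm::GemmShape<8, 8, 4>>;      // FP64 tensor-op instruction

double alpha = 1.0, beta = 0.0;
int split_k_slices = 8;
typename GemmUni::Arguments args(
    cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,
    {M, N, K},
    split_k_slices,              // batch_count doubles as the split-K factor
    {alpha, beta},
    A, B, C, D,                  // device pointers
    0, 0, 0, 0,                  // batch strides (unused for a single GEMM)
    K, N, N, N);                 // leading dimensions for row-major
// Note (my understanding): in kGemmSplitKParallel mode the partial products go to
// the workspace and a separate split-K reduction pass is still needed to produce D;
// the profiler's library layer runs that reduction, whereas device::GemmSplitKParallel
// launches its own reduction kernel internally.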
Sorry, I don't know CMake very well and would like to ask you a question. The content of my CMakeLists.txt is as follows:
cmake_minimum_required(VERSION 3.19 FATAL_ERROR)
project(cupy_helper LANGUAGES CXX CUDA)
include(CUDA.cmake) # Copy function cutlass_apply_standard_compile_options and cutlass_apply_cuda_gencode_flags in CMakeLists.txt to CUDA.cmake
set(CUDA_ARCHITECTURES 80)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fPIC")
include_directories(cutlass/tools/util/include cutlass/include)
# add_library(cupy_helper SHARED dgemm_cutlass2.cu) # ok, .so
# cutlass_add_library(cupy_helper dgemm_cutlass2.cu) # ok, .a
cutlass_add_library(cupy_helper SHARED dgemm_cutlass2.cu) # error, /usr/bin/ld: cannot find -lCUTLASS
set_target_properties(cupy_helper PROPERTIES CUDA_ARCHITECTURES ${CUDA_ARCHITECTURES})
I encountered this error while using cutlass_add_library: /usr/bin/ld: cannot find -lCUTLASS. CUTLASS is a header-only library, so there should be no libcutlass.so for me to link against. How should I fix this error?
That looks very close. The part you are missing is importing CUTLASS. The example 60_cutlass_import shows how to do this for an installed CUTLASS package. https://github.com/NVIDIA/cutlass/blob/main/examples/60_cutlass_import/CMakeLists.txt#L47 If you want to build against CUTLASS completely from source, you can just treat it like a subdirectory in your build but put it in header-only mode.
Using … , I recently noticed that when I commented out the parts of the code that allocate memory and instead allocated the memory in Python and passed it into the CUDA program as pointers, the GFLOPS went up to 5000+ (cupy is able to save time on allocation through techniques such as memory pooling). So I'm wondering whether CUTLASS is not performing as expected because there are other additional overheads in the code, such as … . I'm wondering if I'm on the right track (that the problem is in the code, not in CMake) and how I can avoid the overhead of … . I'll post the full code below (the CUDA source, the updated CMakeLists.txt, and the Python benchmark) to avoid having to scroll through the context.
#include <cuda_runtime.h>
#include <stdio.h>
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/core_io.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/device_memory.h"
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at line: " << __LINE__ \
<< std::endl; \
exit(EXIT_FAILURE); \
} \
}
namespace cutlass {
namespace gemm {
namespace device {
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Epilogue output operator
typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert<
ElementAccumulator_,
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementAccumulator_,
ElementAccumulator_>::EpilogueOutputOp::kCount,
ElementAccumulator_>,
/// Reduction operator
typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd<
ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator,
EpilogueOutputOp_::kCount>,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ =
threadblock::GemmSplitKHorizontalThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int kAlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int kAlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator>
using GemmSplitKWithStages4 = typename cutlass::gemm::device::GemmSplitKParallel<
ElementA_,
LayoutA_,
ElementB_,
LayoutB_,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ConvertScaledOp_,
ReductionOp_,
ThreadblockSwizzle_,
4,
kAlignmentA,
kAlignmentB,
Operator_
>;
}}}
using Gemm = typename cutlass::gemm::device::GemmSplitKWithStages4<
// Data type and layout of operand A
double, cutlass::layout::RowMajor,
// Data type and layout of operand B
double, cutlass::layout::RowMajor,
// Data type and layout of operand C
double, cutlass::layout::RowMajor,
// Data type of accumulator
double,
// Class of operation
cutlass::arch::OpClassTensorOp,
// Compute capability of the target kernel
cutlass::arch::Sm80,
// Threadblock tile shape
cutlass::gemm::GemmShape<64, 32, 16>,
// Warp tile shape
cutlass::gemm::GemmShape<32, 16, 16>,
// Instruction shape
cutlass::gemm::GemmShape<8, 8, 4>
>;
Gemm gemm_op;
cutlass::Status status;
extern "C" {
int dgemm(cudaStream_t stream, double *out, const double *x, const double *y, uint8_t *cupy_workspace, int M, int N, int K)
{
double alpha = 1.0;
double beta = 0.0;
int split_k_slices = 8; // Split K dimension into split_k_slices partitions
typename Gemm::Arguments arguments{
{M, N, K},
{x, K},
{y, N},
{out, N},
{out, N},
{alpha, beta},
split_k_slices
};
// size_t workspace_size = Gemm::get_workspace_size(arguments);
// printf("workspace_size = %lu\n", workspace_size);
// cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// status = gemm_op.initialize(arguments, workspace.get());
status = gemm_op.initialize(arguments, cupy_workspace);
CUTLASS_CHECK(status);
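// Note: the `stream` argument passed into dgemm() is currently unused, so the GEMM
// launches on the default stream. If I read the device-level API correctly, the call
// operator accepts an optional cudaStream_t, e.g. gemm_op(stream), if the kernel
// should run on the caller's stream instead.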
status = gemm_op();
CUTLASS_CHECK(status);
cudaError_t err = cudaGetLastError();
// printf("%s\n", cudaGetErrorString(err));
if (err != cudaSuccess)
return 1;
return 0;
}
}
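As an aside, rather than hard-coding the 1048576-byte workspace size on the Python side, one option is to expose the size query through another extern "C" entry point and allocate the workspace once up front. This is only a sketch: dgemm_workspace_size is a hypothetical helper name, it reuses the Gemm type defined above, and it assumes the size query only inspects the problem shape (so null tensor refs are acceptable):

extern "C" {
// Returns the workspace size in bytes that CUTLASS needs for the given problem
// shape and split-K configuration, so the caller (e.g. Python) can allocate and
// reuse a single workspace buffer instead of hard-coding its size.
size_t dgemm_workspace_size(int M, int N, int K, int split_k_slices)
{
    double alpha = 1.0;
    double beta = 0.0;
    typename Gemm::Arguments arguments{
        {M, N, K},
        {nullptr, K},        // pointers are assumed unnecessary for the size query
        {nullptr, N},
        {nullptr, N},
        {nullptr, N},
        {alpha, beta},
        split_k_slices
    };
    return Gemm::get_workspace_size(arguments);
}
}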
cmake_minimum_required(VERSION 3.19 FATAL_ERROR)
project(cupy_helper LANGUAGES CXX CUDA)
find_package(NvidiaCutlass REQUIRED)
message(STATUS "CUTLASS: ${NvidiaCutlass_DIR}")
# include(CUDA.cmake) # Copy function cutlass_apply_standard_compile_options and cutlass_apply_cuda_gencode_flags in CMakeLists.txt to CUDA.cmake
set(CUDA_ARCHITECTURES 80)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fPIC")
# include_directories(cutlass/tools/util/include cutlass/include)
# set(CUTLASS_ENABLE_HEADERS_ONLY ON)
# add_subdirectory(cutlass)
add_library(cupy_helper SHARED dgemm_cutlass2.cu) # ok, .so
# cutlass_add_library(cupy_helper dgemm_cutlass2.cu) # ok, .a
# cutlass_add_library(cupy_helper SHARED dgemm_cutlass2.cu) # error, /usr/bin/ld: cannot find -lCUTLASS
target_link_libraries(cupy_helper PUBLIC nvidia::cutlass::cutlass nvidia::cutlass::library)
set_target_properties(cupy_helper PROPERTIES CUDA_ARCHITECTURES ${CUDA_ARCHITECTURES})
# omit other code...
def benchmark_cutlass():
    m = 128
    n = m
    k = 128*128
    a = cupy.random.rand(m, k)
    b = cupy.random.rand(k, n)
    c = cupy.random.rand(m, n)
    d = cupy.zeros((m, n))
    # warmup_count = 10
    repeat_count = 10
    perf = cupyx.profiler.benchmark(
        lambda x, y: cupy.dot(x, y), (a, b), n_repeat=repeat_count
    )
    total_flops = 2 * m * k * n
    elapsed = perf.gpu_times.mean()
    print(perf.gpu_times)
    print(elapsed)
    # res_cupy = total_flops / elapsed / 1e9
    print("CUPY cupy.dot GFLOPS: {}".format(total_flops / elapsed / 1e9))
    # cutlass
    def f(x, y, out=None):
        m = x.shape[0]
        n = y.shape[1]
        out = cupy.empty((m, n))
        workspace = cupy.empty((int(1048576 / 8)))  # 1048576 is the value of workspace_size
        stream = cupy.cuda.get_current_stream()
        err = libcupy_helper.dgemm(
            ctypes.cast(stream.ptr, ctypes.c_void_p),
            ctypes.cast(out.data.ptr, ctypes.c_void_p),
            ctypes.cast(x.data.ptr, ctypes.c_void_p),
            ctypes.cast(y.data.ptr, ctypes.c_void_p),
            ctypes.cast(workspace.data.ptr, ctypes.c_void_p),
            ctypes.c_int(m),
            ctypes.c_int(n),
            ctypes.c_int(k)
        )
        if err != 0:
            raise RuntimeError('failed in dgemm kernel')
        return out
    perf = cupyx.profiler.benchmark(
        f, (a, b), n_repeat=repeat_count
    )
    total_flops = 2 * m * k * n
    elapsed = perf.gpu_times.mean()
    print(perf.gpu_times)
    print(elapsed)
    # res_cupy = total_flops / elapsed / 1e9
    print("cutlass GFLOPS: {}".format(total_flops / elapsed / 1e9))
    ans = cupy.dot(a, b)
    res = f(a, b)
    assert cupy.allclose(res, ans)  # check that the CUTLASS result matches cupy.dot
if __name__ == '__main__':
    benchmark_cutlass()
Finally, the output of the program is: …
The cutlass profiler does not include …
I wrote a single CUDA program to test the performance, and the FLOPS I got were very similar to the results given by the cutlass profiler, so I think the problem is at the Python level; I probably shouldn't have raised an issue because of that.
What is your question?
Any help is greatly appreciated!
After running
cutlass_profiler --operation=Gemm --m=128 --n=128 --k=16384 --A=f64:row --B=f64:row --op_class=tensorop --split_k_mode=parallel --split_k_slices=8 --output=dgemm_cutlass.csv
I get an ideal result like this: …

My kernel implementation is in dgemm_cutlass.cu as follows: …
I run the following command to compile this kernel into a dynamic library: …
Then I load the dynamic library in Python: …
Finally, I got this result: …
I'd like to know why the performance is so much lower than what cutlass_profiler gets. Am I making any mistakes in using it?