Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QST] How to make the performance of cutlass kernel achieve what cutlass_profiler claims? #1298

Closed
Miroier opened this issue Jan 10, 2024 · 12 comments
Labels
question Question

Comments

@Miroier
Copy link

Miroier commented Jan 10, 2024

What is your question?

Any help is greatly appreciated!

After run cutlass_profiler --operation=Gemm --m=128 --n=128 --k=16384 --A=f64:row --B=f64:row --op_class=tensorop --split_k_mode=parallel --split_k_slices=8 --output=dgemm_cutlass.csv, I get a ideal result like this:

image

My kernel implementation is in dgemm_cutlass.cu as follows:

#include <cuda_runtime.h>
#include <stdio.h>
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/core_io.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/device_memory.h"

#define CUTLASS_CHECK(status)                                                                    \
  {                                                                                              \
    cutlass::Status error = status;                                                              \
    if (error != cutlass::Status::kSuccess) {                                                    \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at line: " << __LINE__ \
                << std::endl;                                                                    \
      exit(EXIT_FAILURE);                                                                        \
    }                                                                                            \
  }

namespace cutlass {
namespace gemm {
namespace device {

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_ = ElementC_,
    /// Operator class tag
    typename OperatorClass_ = arch::OpClassSimt,
    /// Tag indicating architecture to tune for.  This is the minimum SM that
      /// supports the intended feature. The device kernel can be built
      /// targeting any SM larger than this number.
    typename ArchTag_ = arch::Sm70,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::EpilogueOutputOp,
    /// Epilogue output operator
    typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert<
        ElementAccumulator_,
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementAccumulator_,
                                 ElementAccumulator_>::EpilogueOutputOp::kCount,
        ElementAccumulator_>,
    /// Reduction operator
    typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd<
        ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator,
        EpilogueOutputOp_::kCount>,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ =
        threadblock::GemmSplitKHorizontalThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kStages,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentA,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentB,
    /// Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator>
using GemmSplitKWithStages4 = typename cutlass::gemm::device::GemmSplitKParallel<
    ElementA_,
    LayoutA_,
    ElementB_,
    LayoutB_,
    ElementC_,
    LayoutC_,
    ElementAccumulator_,
    OperatorClass_,
    ArchTag_,
    ThreadblockShape_,
    WarpShape_,
    InstructionShape_,
    EpilogueOutputOp_,
    ConvertScaledOp_,
    ReductionOp_,
    ThreadblockSwizzle_,
    4,
    kAlignmentA,
    kAlignmentB,
    Operator_
>;
}}}

extern "C" {
int dgemm(cudaStream_t stream, double *out, const double *x, const double *y, int M, int N, int K)
{
    using Gemm = typename cutlass::gemm::device::GemmSplitKWithStages4<
        // Data type and layout of operand A
        double, cutlass::layout::RowMajor,
        // Data type and layout of operand B
        double, cutlass::layout::RowMajor,
        // Data type and layout of operand C
        double, cutlass::layout::RowMajor,
        // Data type of accumulator
        double,
        // Class of operation
        cutlass::arch::OpClassTensorOp,
        // Compute capability of the target kernel
        cutlass::arch::Sm80,
        // Threadblock tile shape
        cutlass::gemm::GemmShape<64, 32, 16>,
        // Warp tile shape
        cutlass::gemm::GemmShape<32, 16, 16>,
        // Instruction shape
        cutlass::gemm::GemmShape<8, 8, 4>
    >;

    Gemm gemm_op;
    cutlass::Status status;

    double alpha = 1.0;
    double beta = 0.0;
    int split_k_slices = 16; // Split K dimension into split_k_slices partitions

    typename Gemm::Arguments arguments{
        {M, N, K},
        {x, K},
        {y, N},
        {out, N},
        {out, N},
        {alpha, beta},
        split_k_slices
    };

    size_t workspace_size = Gemm::get_workspace_size(arguments);

    cutlass::device_memory::allocation<double> workspace(workspace_size);

    status = gemm_op.initialize(arguments, workspace.get());
    CUTLASS_CHECK(status);

    status = gemm_op();
    CUTLASS_CHECK(status);

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        return 1;
    return 0;
}
}

I will run the following command to compile this kernel into a dynamic library

nvcc -shared -o ../libcupy_helper.so dgemm_cutlass.cu -arch=sm_80 -Xcompiler -fPIC -Icutlass/tools/util/include -Icutlass/include -std=c++17

Then load the dynamic library in python

# omit other code ......

def benchmark_cutlass():
    m = 128
    n = 128
    k = 128*128
    a = cupy.random.rand(m, k)
    b = cupy.random.rand(k, n)

    repeat_count = 10

    perf = cupyx.profiler.benchmark(
        lambda x, y: cupy.dot(x, y), (a, b), n_repeat=repeat_count
    )
    total_flops = 2 * m * k * n
    elapsed = perf.gpu_times.mean()
    print(elapsed)
    print("CUPY cupy.dot GFLOPS: {}".format(total_flops / elapsed / 1e9))
    
    def f(x, y):
        m = x.shape[0]
        n = y.shape[1]
        out = cupy.empty((m, n))
        stream = cupy.cuda.get_current_stream()
        err = libcupy_helper.dgemm(
            ctypes.cast(stream.ptr, ctypes.c_void_p),
            ctypes.cast(out.data.ptr, ctypes.c_void_p),
            ctypes.cast(x.data.ptr, ctypes.c_void_p),
            ctypes.cast(y.data.ptr, ctypes.c_void_p),
            ctypes.c_int(m),
            ctypes.c_int(n),
            ctypes.c_int(k)
        )
        if err != 0:
            raise RuntimeError('failed in dgemm kernel')
        return out  

    perf = cupyx.profiler.benchmark(
        f, (a, b), n_repeat=repeat_count
    )
    total_flops = 2 * m * k * n
    elapsed = perf.gpu_times.mean()
    print(elapsed)
    print("cutlass GFLOPS: {}".format(total_flops / elapsed / 1e9))
    
    ans = cupy.dot(a, b)
    res = f(a, b)
    assert(cupy.linalg.norm(res - ans))

if __name__ == '__main__':
    benchmark_cutlass()

Finally got this result.

7.690240144729616e-05
CUPY cupy.dot GFLOPS: 6981.198270744978
0.0009633792042732239
cutlass GFLOPS: 557.2789090927254

I'd like to know why the performance is so much lower than what cutlass_profiler gets. Am I making any mistakes in using it?

@hwu36
Copy link
Collaborator

hwu36 commented Jan 10, 2024

cutlass profiler uses 8 way parallel splitk, you are using 16 way. could that be the difference? your nvcc cmd line is also different. you don't have any -O.

@Miroier
Copy link
Author

Miroier commented Jan 10, 2024

cutlass profiler uses 8 way parallel splitk, you are using 16 way. could that be the difference? your nvcc cmd line is also different. you don't have any -O.

Sorry for the difference, but after I changed split_k_slices = 16 to split_k_slices = 8 and added -O3 to the nvcc compile options, the result was virtually no different.

@hwu36
Copy link
Collaborator

hwu36 commented Jan 10, 2024

not just -O3. we have a very long list of flags.

cutlass profiler uses gemm_universal (https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/device/gemm_universal.h) with mode = kGemmSplitKParallel to run parallel splitk. see

params.mode == GemmUniversalMode::kGemmSplitKParallel) {

@hwu36
Copy link
Collaborator

hwu36 commented Jan 10, 2024

also cutlass profiler runs multiple times after some warmup runs. yours seems to be a cold run.

@d-k-b
Copy link
Collaborator

d-k-b commented Jan 10, 2024

See https://github.com/NVIDIA/cutlass/blob/main/CUDA.cmake#L318 for applying the necessary compiler options. If you are using CMake, you can just call cutlass_add_library() or cutlass_add_exectuble() to get the appropriate setup for your build.

@Miroier
Copy link
Author

Miroier commented Jan 11, 2024

not just -O3. we have a very long list of flags.

cutlass profiler uses gemm_universal (main/include/cutlass/gemm/device/gemm_universal.h) with mode = kGemmSplitKParallel to run parallel splitk. see

params.mode == GemmUniversalMode::kGemmSplitKParallel) {

thank you for your reply, I will use cmake according to the comments below.

you mean I should use cutlass/gemm/kernel/gemm_universal.h instead of cutlass/gemm/device/gemm_splitk_parallel.h(which I found in examples/06_splitK_gemm)?

warmup is inside cupyx.profile.benchmark and its default value is 10 due to https://github.com/cupy/cupy/blob/v12.3.0/cupyx/profiler/_time.py#L82

@hwu36
Copy link
Collaborator

hwu36 commented Jan 11, 2024

you mean I should use cutlass/gemm/kernel/gemm_universal.h instead of cutlass/gemm/device/gemm_splitk_parallel.h(which I found in examples/06_splitK_gemm)?

I think these two should have the same performance. but if you need to run the exact same thing as the profiler, you can try to use gemm_universal.

@Miroier
Copy link
Author

Miroier commented Jan 11, 2024

See main/CUDA.cmake#L318 for applying the necessary compiler options. If you are using CMake, you can just call cutlass_add_library() or cutlass_add_exectuble() to get the appropriate setup for your build.

Sorry, I don't know cmake very well and would like to ask you a question. The content of my CMakeLists.txt is as follows

cmake_minimum_required(VERSION 3.19 FATAL_ERROR)

project(cupy_helper LANGUAGES CXX CUDA)

include(CUDA.cmake) # Copy function cutlass_apply_standard_compile_options and cutlass_apply_cuda_gencode_flags in CMakeLists.txt to CUDA.cmake

set(CUDA_ARCHITECTURES 80)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fPIC")

include_directories(cutlass/tools/util/include cutlass/include)

# add_library(cupy_helper SHARED dgemm_cutlass2.cu) # ok, .so
# cutlass_add_library(cupy_helper dgemm_cutlass2.cu) # ok, .a 
cutlass_add_library(cupy_helper SHARED dgemm_cutlass2.cu) # error, /usr/bin/ld: cannot find -lCUTLASS

set_target_properties(cupy_helper PROPERTIES CUDA_ARCHITECTURES ${CUDA_ARCHITECTURES})

I encountered this error while using cutlass_add_library: /usr/bin/ld: cannot find -lCUTLASS

cutlass is a header-only library, so there should be no libcutlass.so for me to link to, how should I fix this error?

@d-k-b
Copy link
Collaborator

d-k-b commented Jan 11, 2024

That looks very close. The part you are missing is importing CUTLASS. The example 60_cutlass_import shows how to do this for an installed CUTLASS package.

https://github.com/NVIDIA/cutlass/blob/main/examples/60_cutlass_import/CMakeLists.txt#L47

If you want to build against CUTLASS completely from source, you can just treat it like a subdirectory in your build but put it in header-only mode.

... 
set(CUTLASS_ENABLE_HEADERS_ONLY ON)
add_subdirectory(cutlass)

cutlass_add_library(cupy_helper SHARED dgemm_cutlass2.cu)
target_link_libraries(cupy_helper PUBLIC nvidia::cutlass::cutlass)
...

@Miroier
Copy link
Author

Miroier commented Jan 14, 2024

Using cutlass_add_library still throws an error, but I've always assumed that the reason for the inefficiency was that it wasn't compiled and linked correctly.

Recently I noticed that when I commented out the parts of the code that allocated memory and instead allocated the memory in python and passed it into the cuda program as a pointer, the gflops went up to 5000+ (cupy is able to save time allocating memory through techniques like memory pooling)

So I'm wondering if cutlass is not performing as expected because there are other additional overheads in the code, such as gemm_op.initialize(arguments, cupy_workspace), where arguments depend on the function's arguments, so there's no way to omit them.

I'm wondering if I'm on the right track (the problem is in the code, not in cmake) and how I can save the overhead of gemm_op.initialize.

I'll post the full code below to avoid having to scroll through the context.

dgemm_cutlass2.cu:

#include <cuda_runtime.h>
#include <stdio.h>
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/core_io.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/device_memory.h"

#define CUTLASS_CHECK(status)                                                                    \
  {                                                                                              \
    cutlass::Status error = status;                                                              \
    if (error != cutlass::Status::kSuccess) {                                                    \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at line: " << __LINE__ \
                << std::endl;                                                                    \
      exit(EXIT_FAILURE);                                                                        \
    }                                                                                            \
  }

namespace cutlass {
namespace gemm {
namespace device {

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_ = ElementC_,
    /// Operator class tag
    typename OperatorClass_ = arch::OpClassSimt,
    /// Tag indicating architecture to tune for.  This is the minimum SM that
      /// supports the intended feature. The device kernel can be built
      /// targeting any SM larger than this number.
    typename ArchTag_ = arch::Sm70,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::EpilogueOutputOp,
    /// Epilogue output operator
    typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert<
        ElementAccumulator_,
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementAccumulator_,
                                 ElementAccumulator_>::EpilogueOutputOp::kCount,
        ElementAccumulator_>,
    /// Reduction operator
    typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd<
        ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator,
        EpilogueOutputOp_::kCount>,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ =
        threadblock::GemmSplitKHorizontalThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kStages,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentA,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentB,
    /// Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator>
using GemmSplitKWithStages4 = typename cutlass::gemm::device::GemmSplitKParallel<
    ElementA_,
    LayoutA_,
    ElementB_,
    LayoutB_,
    ElementC_,
    LayoutC_,
    ElementAccumulator_,
    OperatorClass_,
    ArchTag_,
    ThreadblockShape_,
    WarpShape_,
    InstructionShape_,
    EpilogueOutputOp_,
    ConvertScaledOp_,
    ReductionOp_,
    ThreadblockSwizzle_,
    4,
    kAlignmentA,
    kAlignmentB,
    Operator_
>;
}}}

using Gemm = typename cutlass::gemm::device::GemmSplitKWithStages4<
    // Data type and layout of operand A
    double, cutlass::layout::RowMajor,
    // Data type and layout of operand B
    double, cutlass::layout::RowMajor,
    // Data type and layout of operand C
    double, cutlass::layout::RowMajor,
    // Data type of accumulator
    double,
    // Class of operation
    cutlass::arch::OpClassTensorOp,
    // Compute capability of the target kernel
    cutlass::arch::Sm80,
    // Threadblock tile shape
    cutlass::gemm::GemmShape<64, 32, 16>,
    // Warp tile shape
    cutlass::gemm::GemmShape<32, 16, 16>,
    // Instruction shape
    cutlass::gemm::GemmShape<8, 8, 4>
>;

Gemm gemm_op;
cutlass::Status status;

extern "C" {
int dgemm(cudaStream_t stream, double *out, const double *x, const double *y, uint8_t *cupy_workspace, int M, int N, int K)
{
    double alpha = 1.0;
    double beta = 0.0;
    int split_k_slices = 8; // Split K dimension into split_k_slices partitions

    typename Gemm::Arguments arguments{
        {M, N, K},
        {x, K},
        {y, N},
        {out, N},
        {out, N},
        {alpha, beta},
        split_k_slices
    };

    // size_t workspace_size = Gemm::get_workspace_size(arguments);
    // printf("workspace_size = %lu\n", workspace_size);

    // cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    // status = gemm_op.initialize(arguments, workspace.get());
    status = gemm_op.initialize(arguments, cupy_workspace);
    CUTLASS_CHECK(status);

    status = gemm_op();
    CUTLASS_CHECK(status);

    cudaError_t err = cudaGetLastError();
    // printf("%s\n", cudaGetErrorString(err));
    if (err != cudaSuccess)
        return 1;
    return 0;
}
}

CMakeLists.txt:

cmake_minimum_required(VERSION 3.19 FATAL_ERROR)

project(cupy_helper LANGUAGES CXX CUDA)

find_package(NvidiaCutlass REQUIRED)
message(STATUS "CUTLASS: ${NvidiaCutlass_DIR}")

# include(CUDA.cmake) # Copy function cutlass_apply_standard_compile_options and cutlass_apply_cuda_gencode_flags in CMakeLists.txt to CUDA.cmake

set(CUDA_ARCHITECTURES 80)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fPIC")

# include_directories(cutlass/tools/util/include cutlass/include)

# set(CUTLASS_ENABLE_HEADERS_ONLY ON)
# add_subdirectory(cutlass)

add_library(cupy_helper SHARED dgemm_cutlass2.cu) # ok, .so
# cutlass_add_library(cupy_helper dgemm_cutlass2.cu) # ok, .a 
# cutlass_add_library(cupy_helper SHARED dgemm_cutlass2.cu) # error, /usr/bin/ld: cannot find -lCUTLASS
target_link_libraries(cupy_helper PUBLIC nvidia::cutlass::cutlass nvidia::cutlass::library)

set_target_properties(cupy_helper PROPERTIES CUDA_ARCHITECTURES ${CUDA_ARCHITECTURES})

cupy_helper.py:

# omit other code...

def benchmark_cutlass():
    m = 128
    n = m
    k = 128*128
    a = cupy.random.rand(m, k)
    b = cupy.random.rand(k, n)
    c = cupy.random.rand(m, n)
    d = cupy.zeros((m ,n))
    
    # warpup_count = 10
    repeat_count = 10

    perf = cupyx.profiler.benchmark(
        lambda x, y: cupy.dot(x, y), (a, b), n_repeat=repeat_count
    )
    total_flops = 2 * m * k * n
    elapsed = perf.gpu_times.mean()
    print(perf.gpu_times)
    print(elapsed)
    # res_cupy = total_flops / elapsed / 1e9
    print("CUPY cupy.dot GFLOPS: {}".format(total_flops / elapsed / 1e9))
    
    # cutlass
    def f(x, y, out=None):
        m = x.shape[0]
        n = y.shape[1]
        out = cupy.empty((m, n))
        workspace = cupy.empty((int(1048576 / 8))) # 1048576 is the value of workspace_size
        stream = cupy.cuda.get_current_stream()
        err = libcupy_helper.dgemm(
            ctypes.cast(stream.ptr, ctypes.c_void_p),
            ctypes.cast(out.data.ptr, ctypes.c_void_p),
            ctypes.cast(x.data.ptr, ctypes.c_void_p),
            ctypes.cast(y.data.ptr, ctypes.c_void_p),
            ctypes.cast(workspace.data.ptr, ctypes.c_void_p),
            ctypes.c_int(m),
            ctypes.c_int(n),
            ctypes.c_int(k)
        )
        if err != 0:
            raise RuntimeError('failed in dgemm kernel')
        return out  

    perf = cupyx.profiler.benchmark(
        f, (a, b), n_repeat=repeat_count
    )
    total_flops = 2 * m * k * n
    elapsed = perf.gpu_times.mean()
    print(perf.gpu_times)
    print(elapsed)
    # res_cupy = total_flops / elapsed / 1e9
    print("cutlass GFLOPS: {}".format(total_flops / elapsed / 1e9))
    
    ans = cupy.dot(a, b)
    res = f(a, b)
    assert(cupy.linalg.norm(res - ans))  

if __name__ == '__main__':
    benchmark_cutlass()

Finally, the output of the program is:

[[6.65600002e-05 7.25760013e-05 6.46720007e-05 6.26240000e-05
  6.28800020e-05 6.26880005e-05 6.23040013e-05 7.12959990e-05
  6.43199980e-05 6.32319972e-05]]
6.531520001590252e-05
CUPY cupy.dot GFLOPS: 8219.693300629657
[[9.46239978e-05 9.31840017e-05 9.16799977e-05 9.67999995e-05
  9.46879983e-05 9.21280012e-05 9.12960023e-05 9.16159973e-05
  9.12320018e-05 9.26719978e-05]]
9.299199953675271e-05
cutlass GFLOPS: 5773.30216227704

@hwu36
Copy link
Collaborator

hwu36 commented Jan 16, 2024

cutlass profiler does not include initialize time, it only measures run time. as to gemm, you know the tensor capacity before hand, you don't need to initialize right before calling run. memory pooling you mentioned is the right way to do.

@Miroier
Copy link
Author

Miroier commented Jan 17, 2024

I wrote a single cuda program to test the performance, and the flops I got were very similar to the results given by the cutlass profiler, so I think the problem is in the python level, and I probably shouldn't have raised an issue because of that.

@Miroier Miroier closed this as completed Jan 18, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
question Question
Projects
None yet
Development

No branches or pull requests

4 participants