[Opt] Enforce the UT Coverity and add benchmark for transpose #2421

Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions cpp/bench/prims/CMakeLists.txt
@@ -132,6 +132,7 @@ if(BUILD_PRIMS_BENCH)
linalg/reduce_rows_by_key.cu
linalg/reduce.cu
linalg/sddmm.cu
linalg/transpose.cu
main.cpp
)

85 changes: 85 additions & 0 deletions cpp/bench/prims/linalg/transpose.cu
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>

namespace raft::bench::linalg {

template <typename IdxT>
struct transpose_input {
IdxT rows, cols;
};

template <typename IdxT>
inline auto operator<<(std::ostream& os, const transpose_input<IdxT>& p) -> std::ostream&
{
os << p.rows << "#" << p.cols;
return os;
}

template <typename T, typename IdxT, typename Layout>
struct TransposeBench : public fixture {
TransposeBench(const transpose_input<IdxT>& p)
: params(p), in(p.rows * p.cols, stream), out(p.rows * p.cols, stream)
{
raft::random::RngState rng{1234};
raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0);
}

void run_benchmark(::benchmark::State& state) override
{
std::ostringstream label_stream;
label_stream << params;
state.SetLabel(label_stream.str());

    loop_on_state(state, [this]() {
      auto input_view =
        raft::make_device_matrix_view<T, IdxT, Layout>(in.data(), params.rows, params.cols);
      // The transpose of a rows x cols matrix is cols x rows, so the output is a
      // matrix view with the extents swapped (not a vector view of length rows).
      auto output_view =
        raft::make_device_matrix_view<T, IdxT, Layout>(out.data(), params.cols, params.rows);
      raft::linalg::transpose(handle,
                              input_view.data_handle(),
                              output_view.data_handle(),
                              params.rows,
                              params.cols,
                              handle.get_stream());
    });
}

private:
transpose_input<IdxT> params;
rmm::device_uvector<T> in, out;
}; // struct TransposeBench

const std::vector<transpose_input<int>> transpose_inputs_i32 =
raft::util::itertools::product<transpose_input<int>>({10, 128, 256, 512, 1024},
{10000, 100000, 1000000});

RAFT_BENCH_REGISTER((TransposeBench<float, int, raft::row_major>), "", transpose_inputs_i32);
RAFT_BENCH_REGISTER((TransposeBench<half, int, raft::row_major>), "", transpose_inputs_i32);

RAFT_BENCH_REGISTER((TransposeBench<float, int, raft::col_major>), "", transpose_inputs_i32);
RAFT_BENCH_REGISTER((TransposeBench<half, int, raft::col_major>), "", transpose_inputs_i32);

} // namespace raft::bench::linalg
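
For reference, raft::util::itertools::product expands the two lists into their full cross product, so the registration above covers 15 shapes. A hypothetical explicit spelling of the same input vector:

const std::vector<transpose_input<int>> transpose_inputs_i32 = {
  {10, 10000},   {10, 100000},   {10, 1000000},    // rows = 10
  {128, 10000},  {128, 100000},  {128, 1000000},   // rows = 128
  {256, 10000},  {256, 100000},  {256, 1000000},   // rows = 256
  {512, 10000},  {512, 100000},  {512, 1000000},   // rows = 512
  {1024, 10000}, {1024, 100000}, {1024, 1000000},  // rows = 1024
};

These correspond to the /0 through /14 benchmark indices and the rows#cols labels (produced by the operator<< above) in the tables quoted in the review thread below.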
67 changes: 53 additions & 14 deletions cpp/include/raft/linalg/detail/transpose.cuh
@@ -38,7 +38,9 @@ template <typename IndexType, int TILE_DIM, int BLOCK_ROWS>
RAFT_KERNEL transpose_half_kernel(IndexType n_rows,
IndexType n_cols,
const half* __restrict__ in,
half* __restrict__ out)
half* __restrict__ out,
const IndexType stride_in,
const IndexType stride_out)
{
__shared__ half tile[TILE_DIM][TILE_DIM + 1];
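  // The +1 padding on the tile's inner dimension staggers the columns across
  // shared-memory banks so that the column-wise reads in the write-out phase
  // below do not cause bank conflicts.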

@@ -49,7 +51,7 @@ RAFT_KERNEL transpose_half_kernel(IndexType n_rows,

for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
if (x < n_cols && (y + j) < n_rows) {
tile[threadIdx.y + j][threadIdx.x] = __ldg(&in[(y + j) * n_cols + x]);
tile[threadIdx.y + j][threadIdx.x] = __ldg(&in[(y + j) * stride_in + x]);
achirkin (Contributor) commented on Sep 16, 2024:

I know it's not part of this change, but it's advisable to use raft's helpers instead of __xxx functions unless you need a specific cache behavior (for which, maybe, we should add more helpers?..)

Suggested change:
      tile[threadIdx.y + j][threadIdx.x] = __ldg(&in[(y + j) * stride_in + x]);
      tile[threadIdx.y + j][threadIdx.x] = raft::ldg(&in[(y + j) * stride_in + x]);

rhdong (Member, Author) replied:

Hi @achirkin, thank you for your suggestion! Yeah, you're right. (The custom kernel exists because cublas<t>geam() does not support half.) As for cublasLtMatrixTransform, let me see how much I can change. Would you suggest doing it in this PR, or in a separate follow-up that changes all of transpose?

rhdong (Member, Author) commented on Sep 16, 2024:

Hi @achirkin, I just tested the cublasLtMatrixTransform benchmark and found that its performance would be around 10~20% lower than the current implementation. So, could I keep the current implementation temporarily?

                                                            transpose_half_kernel   cublasLtMatrixTransform    rows#cols
TransposeBench<half, int, raft::row_major>/0/manual_time    0.006 ms                0.009 ms                   10#10000
TransposeBench<half, int, raft::row_major>/1/manual_time    0.016 ms                0.017 ms                   10#100000
TransposeBench<half, int, raft::row_major>/2/manual_time    0.122 ms                0.108 ms                   10#1000000
TransposeBench<half, int, raft::row_major>/3/manual_time    0.011 ms                0.014 ms                   128#10000
TransposeBench<half, int, raft::row_major>/4/manual_time    0.084 ms                0.091 ms                   128#100000
TransposeBench<half, int, raft::row_major>/5/manual_time    0.762 ms                0.845 ms                   128#1000000
TransposeBench<half, int, raft::row_major>/6/manual_time    0.022 ms                0.023 ms                   256#10000
TransposeBench<half, int, raft::row_major>/7/manual_time    0.156 ms                0.186 ms                   256#100000
TransposeBench<half, int, raft::row_major>/8/manual_time     1.53 ms                 1.80 ms                   256#1000000
TransposeBench<half, int, raft::row_major>/9/manual_time    0.035 ms                0.041 ms                   512#10000
TransposeBench<half, int, raft::row_major>/10/manual_time   0.310 ms                0.395 ms                   512#100000
TransposeBench<half, int, raft::row_major>/11/manual_time    3.09 ms                 3.91 ms                   512#1000000
TransposeBench<half, int, raft::row_major>/12/manual_time   0.073 ms                0.076 ms                   1024#10000
TransposeBench<half, int, raft::row_major>/13/manual_time   0.642 ms                0.796 ms                   1024#100000
TransposeBench<half, int, raft::row_major>/14/manual_time    6.29 ms                 7.94 ms                   1024#1000000

TransposeBench<half, int, raft::col_major>/0/manual_time    0.006 ms                0.009 ms                   10#10000
TransposeBench<half, int, raft::col_major>/1/manual_time    0.017 ms                0.017 ms                   10#100000
TransposeBench<half, int, raft::col_major>/2/manual_time    0.125 ms                0.109 ms                   10#1000000
TransposeBench<half, int, raft::col_major>/3/manual_time    0.011 ms                0.014 ms                   128#10000
TransposeBench<half, int, raft::col_major>/4/manual_time    0.084 ms                0.091 ms                   128#100000
TransposeBench<half, int, raft::col_major>/5/manual_time    0.762 ms                0.847 ms                   128#1000000
TransposeBench<half, int, raft::col_major>/6/manual_time    0.022 ms                0.023 ms                   256#10000
TransposeBench<half, int, raft::col_major>/7/manual_time    0.156 ms                0.186 ms                   256#100000
TransposeBench<half, int, raft::col_major>/8/manual_time     1.53 ms                 1.80 ms                   256#1000000
TransposeBench<half, int, raft::col_major>/9/manual_time    0.035 ms                0.041 ms                   512#10000
TransposeBench<half, int, raft::col_major>/10/manual_time   0.310 ms                0.396 ms                   512#100000
TransposeBench<half, int, raft::col_major>/11/manual_time    3.09 ms                 3.91 ms                   512#1000000
TransposeBench<half, int, raft::col_major>/12/manual_time   0.073 ms                0.076 ms                   1024#10000
TransposeBench<half, int, raft::col_major>/13/manual_time   0.643 ms                0.796 ms                   1024#100000
TransposeBench<half, int, raft::col_major>/14/manual_time    6.29 ms                 7.95 ms                   1024#1000000

Member commented:

@rhdong let's create an issue for Artem's suggested change and reference it in a TODO comment on the corresponding kernel in the code. I think we should investigate this for sure, so that we are utilizing math libs wherever possible (and not having to maintain both math libs and our own custom impls), but I don't think the further investigation should hold up this PR.

rhdong (Member, Author) replied:

> @rhdong let's create an issue for Artem's suggested change and reference it in a TODO comment on the corresponding kernel in the code. [...]

Yeah, here it is: #2436
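
(For context, a cublasLtMatrixTransform-based half transpose along the lines of what was benchmarked above might look like the following minimal sketch. This is an assumption for illustration, not the code rhdong actually ran; error handling is elided and the function name transpose_half_lt is hypothetical.)

#include <cublasLt.h>
#include <cuda_fp16.h>

// out (n_cols x n_rows, row-major) = transpose(in (n_rows x n_cols, row-major)).
// cublasLt layouts are column-major, so a row-major R x C matrix is described
// as a column-major C x R matrix with leading dimension C.
void transpose_half_lt(cublasLtHandle_t lt, const __half* in, __half* out,
                       int n_rows, int n_cols, cudaStream_t stream)
{
  cublasLtMatrixTransformDesc_t desc;
  cublasLtMatrixTransformDescCreate(&desc, CUDA_R_32F);  // scale type for alpha/beta
  cublasOperation_t op = CUBLAS_OP_T;
  cublasLtMatrixTransformDescSetAttribute(
    desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &op, sizeof(op));

  cublasLtMatrixLayout_t a_desc, c_desc;
  cublasLtMatrixLayoutCreate(&a_desc, CUDA_R_16F, n_cols, n_rows, n_cols);  // `in`, viewed col-major
  cublasLtMatrixLayoutCreate(&c_desc, CUDA_R_16F, n_rows, n_cols, n_rows);  // `out`, viewed col-major

  float alpha = 1.0f, beta = 0.0f;  // beta == 0, so B may be NULL
  cublasLtMatrixTransform(lt, desc, &alpha, in, a_desc, &beta,
                          /*B=*/nullptr, /*Bdesc=*/nullptr, out, c_desc, stream);

  cublasLtMatrixLayoutDestroy(c_desc);
  cublasLtMatrixLayoutDestroy(a_desc);
  cublasLtMatrixTransformDescDestroy(desc);
}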

}
}
__syncthreads();
@@ -59,17 +61,41 @@

for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
if (x < n_rows && (y + j) < n_cols) {
out[(y + j) * n_rows + x] = tile[threadIdx.x][threadIdx.y + j];
out[(y + j) * stride_out + x] = tile[threadIdx.x][threadIdx.y + j];
}
}
__syncthreads();
}
}
}

/**
 * @brief Transposes a matrix of half-precision floating-point numbers (`half`).
 *
 * Both the input (`in`) and output (`out`) matrices are assumed to be stored in row-major order.
 *
 * @tparam IndexType The type used for indexing the matrix dimensions (e.g., int).
 * @param handle The RAFT resources handle.
 * @param n_rows The number of rows in the input matrix.
 * @param n_cols The number of columns in the input matrix.
 * @param in Pointer to the input matrix in row-major order.
 * @param out Pointer to the output matrix in row-major order, where the transposed matrix will be
 * stored.
 * @param stride_in The stride (leading dimension, i.e. the number of elements between consecutive
 * rows) of the input matrix. The default of 1 is a sentinel meaning "contiguous", in which case
 * the kernel is launched with a stride of n_cols.
 * @param stride_out The stride (leading dimension) of the output matrix. The default of 1 is a
 * sentinel meaning "contiguous", in which case the kernel is launched with a stride of n_rows.
 */

template <typename IndexType>
void transpose_half(
raft::resources const& handle, IndexType n_rows, IndexType n_cols, const half* in, half* out)
void transpose_half(raft::resources const& handle,
IndexType n_rows,
IndexType n_cols,
const half* in,
half* out,
const IndexType stride_in = 1,
const IndexType stride_out = 1)
{
if (n_cols == 0 || n_rows == 0) return;
auto stream = resource::get_cuda_stream(handle);
@@ -100,8 +126,13 @@ void transpose_half(

dim3 grids(adjusted_grid_x, adjusted_grid_y);

transpose_half_kernel<IndexType, block_dim_x, block_dim_y>
<<<grids, blocks, 0, stream>>>(n_rows, n_cols, in, out);
if (stride_in > 1 || stride_out > 1) {
transpose_half_kernel<IndexType, block_dim_x, block_dim_y>
<<<grids, blocks, 0, stream>>>(n_rows, n_cols, in, out, stride_in, stride_out);
} else {
transpose_half_kernel<IndexType, block_dim_x, block_dim_y>
<<<grids, blocks, 0, stream>>>(n_rows, n_cols, in, out, n_cols, n_rows);
}

RAFT_CUDA_TRY(cudaPeekAtLastError());
}
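
To illustrate the new stride parameters, a minimal usage sketch (the shapes, leading dimensions, and buffer setup here are hypothetical, and this detail-namespace function would normally be reached through raft::linalg::transpose):

// Transpose a 2x3 half submatrix stored in a row-major buffer with leading
// dimension 8 into a 3x2 submatrix with leading dimension 4:
//
//   in (stride 8):  a b c . . . . .      out (stride 4):  a d . .
//                   d e f . . . . .                       b e . .
//                                                         c f . .
raft::device_resources handle;
auto stream = raft::resource::get_cuda_stream(handle);
rmm::device_uvector<half> in(2 * 8, stream), out(3 * 4, stream);
// ... fill `in` ...
raft::linalg::detail::transpose_half<int>(
  handle, /*n_rows=*/2, /*n_cols=*/3, in.data(), out.data(),
  /*stride_in=*/8, /*stride_out=*/4);
// With the default strides of 1, the matrices are treated as contiguous and
// the kernel is launched with stride_in = n_cols and stride_out = n_rows.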
@@ -118,7 +149,7 @@ void transpose(raft::resources const& handle,
int out_n_cols = n_rows;

if constexpr (std::is_same_v<math_t, half>) {
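    // Assuming the cuBLAS-style column-major convention of this raw-pointer API:
    // a column-major n_rows x n_cols matrix occupies the same buffer as a
    // row-major n_cols x n_rows matrix, so the row-major transpose_half is
    // invoked with the dimensions swapped.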
transpose_half(handle, out_n_rows, out_n_cols, in, out);
transpose_half(handle, n_cols, n_rows, in, out);
} else {
cublasHandle_t cublas_h = resource::get_cublas_handle(handle);
RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream));
@@ -195,9 +226,13 @@ void transpose_row_major_impl(
raft::mdspan<half, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<half, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
auto out_n_rows = in.extent(1);
auto out_n_cols = in.extent(0);
transpose_half<IndexType>(handle, out_n_cols, out_n_rows, in.data_handle(), out.data_handle());
transpose_half<IndexType>(handle,
in.extent(0),
in.extent(1),
in.data_handle(),
out.data_handle(),
in.stride(0),
out.stride(0));
}

template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
@@ -233,9 +268,13 @@ void transpose_col_major_impl(
raft::mdspan<half, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<half, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
auto out_n_rows = in.extent(1);
auto out_n_cols = in.extent(0);
transpose_half<IndexType>(handle, out_n_rows, out_n_cols, in.data_handle(), out.data_handle());
transpose_half<IndexType>(handle,
in.extent(1),
in.extent(0),
in.data_handle(),
out.data_handle(),
in.stride(1),
out.stride(1));
}

}; // end namespace detail
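
Putting the pieces together, a minimal sketch of how the half specializations above are reached through the public mdspan API (the shapes here are arbitrary and the setup is an assumption for illustration):

#include <raft/core/device_mdarray.hpp>
#include <raft/linalg/transpose.cuh>

void transpose_half_example(raft::resources const& handle)
{
  // 4x3 row-major input, 3x4 row-major output; col-major works symmetrically.
  auto in  = raft::make_device_matrix<half, int, raft::row_major>(handle, 4, 3);
  auto out = raft::make_device_matrix<half, int, raft::row_major>(handle, 3, 4);
  // ... fill `in` ...
  // Dispatches to transpose_row_major_impl, which forwards the extents and
  // the views' strides to transpose_half.
  raft::linalg::transpose(handle, in.view(), out.view());
}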