From 8501d6208399e08d900b04ace3e829d5636825fa Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 27 Aug 2024 15:52:08 -0700 Subject: [PATCH] [Opt] Enforce the UT Coverity and add benchmark for `transpose` - Fix the `transpose_half` is not compatible with the sub-matrix cases. --- cpp/bench/prims/CMakeLists.txt | 1 + cpp/bench/prims/linalg/transpose.cu | 85 ++++ cpp/include/raft/linalg/detail/transpose.cuh | 67 ++- cpp/test/linalg/transpose.cu | 445 ++++++++++++++----- 4 files changed, 465 insertions(+), 133 deletions(-) create mode 100644 cpp/bench/prims/linalg/transpose.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 9d80cbaac2..52c63ad73b 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -132,6 +132,7 @@ if(BUILD_PRIMS_BENCH) linalg/reduce_rows_by_key.cu linalg/reduce.cu linalg/sddmm.cu + linalg/transpose.cu main.cpp ) diff --git a/cpp/bench/prims/linalg/transpose.cu b/cpp/bench/prims/linalg/transpose.cu new file mode 100644 index 0000000000..e60e50c125 --- /dev/null +++ b/cpp/bench/prims/linalg/transpose.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::bench::linalg { + +template +struct transpose_input { + IdxT rows, cols; +}; + +template +inline auto operator<<(std::ostream& os, const transpose_input& p) -> std::ostream& +{ + os << p.rows << "#" << p.cols; + return os; +} + +template +struct TransposeBench : public fixture { + TransposeBench(const transpose_input& p) + : params(p), in(p.rows * p.cols, stream), out(p.rows * p.cols, stream) + { + raft::random::RngState rng{1234}; + raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0); + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + loop_on_state(state, [this]() { + auto input_view = + raft::make_device_matrix_view(in.data(), params.rows, params.cols); + auto output_view = raft::make_device_vector_view(out.data(), params.rows); + raft::linalg::transpose(handle, + input_view.data_handle(), + output_view.data_handle(), + params.rows, + params.cols, + handle.get_stream()); + }); + } + + private: + transpose_input params; + rmm::device_uvector in, out; +}; // struct TransposeBench + +const std::vector> transpose_inputs_i32 = + raft::util::itertools::product>({10, 128, 256, 512, 1024}, + {10000, 100000, 1000000}); + +RAFT_BENCH_REGISTER((TransposeBench), "", transpose_inputs_i32); +RAFT_BENCH_REGISTER((TransposeBench), "", transpose_inputs_i32); + +RAFT_BENCH_REGISTER((TransposeBench), "", transpose_inputs_i32); +RAFT_BENCH_REGISTER((TransposeBench), "", transpose_inputs_i32); + +} // namespace raft::bench::linalg diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index ec60aacc9c..c5f0544b5c 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -38,7 +38,9 @@ template RAFT_KERNEL transpose_half_kernel(IndexType n_rows, IndexType n_cols, const half* __restrict__ in, - half* __restrict__ out) + half* __restrict__ out, + const IndexType stride_in, + const IndexType stride_out) { __shared__ half tile[TILE_DIM][TILE_DIM + 1]; @@ -49,7 +51,7 @@ RAFT_KERNEL transpose_half_kernel(IndexType n_rows, for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { if (x < n_cols && (y + j) < n_rows) { - tile[threadIdx.y + j][threadIdx.x] = __ldg(&in[(y + j) * n_cols + x]); + tile[threadIdx.y + j][threadIdx.x] = __ldg(&in[(y + j) * stride_in + x]); } } __syncthreads(); @@ -59,7 +61,7 @@ RAFT_KERNEL transpose_half_kernel(IndexType n_rows, for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { if (x < n_rows && (y + j) < n_cols) { - out[(y + j) * n_rows + x] = tile[threadIdx.x][threadIdx.y + j]; + out[(y + j) * stride_out + x] = tile[threadIdx.x][threadIdx.y + j]; } } __syncthreads(); @@ -67,9 +69,33 @@ RAFT_KERNEL transpose_half_kernel(IndexType n_rows, } } +/** + * @brief Transposes a matrix stored in row-major order. + * + * This function transposes a matrix of half-precision floating-point numbers (`half`). + * Both the input (`in`) and output (`out`) matrices are assumed to be stored in row-major order. + * + * @tparam IndexType The type used for indexing the matrix dimensions (e.g., int). + * @param handle The RAFT resource handle which contains resources. + * @param n_rows The number of rows in the input matrix. + * @param n_cols The number of columns in the input matrix. + * @param in Pointer to the input matrix in row-major order. + * @param out Pointer to the output matrix in row-major order, where the transposed matrix will be + * stored. + * @param stride_in The stride (number of elements between consecutive rows) for the input matrix. + * Default is 1, which means the input matrix is contiguous in memory. + * @param stride_out The stride (number of elements between consecutive rows) for the output matrix. + * Default is 1, which means the output matrix is contiguous in memory. + */ + template -void transpose_half( - raft::resources const& handle, IndexType n_rows, IndexType n_cols, const half* in, half* out) +void transpose_half(raft::resources const& handle, + IndexType n_rows, + IndexType n_cols, + const half* in, + half* out, + const IndexType stride_in = 1, + const IndexType stride_out = 1) { if (n_cols == 0 || n_rows == 0) return; auto stream = resource::get_cuda_stream(handle); @@ -100,8 +126,13 @@ void transpose_half( dim3 grids(adjusted_grid_x, adjusted_grid_y); - transpose_half_kernel - <<>>(n_rows, n_cols, in, out); + if (stride_in > 1 || stride_out > 1) { + transpose_half_kernel + <<>>(n_rows, n_cols, in, out, stride_in, stride_out); + } else { + transpose_half_kernel + <<>>(n_rows, n_cols, in, out, n_cols, n_rows); + } RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -118,7 +149,7 @@ void transpose(raft::resources const& handle, int out_n_cols = n_rows; if constexpr (std::is_same_v) { - transpose_half(handle, out_n_rows, out_n_cols, in, out); + transpose_half(handle, n_cols, n_rows, in, out); } else { cublasHandle_t cublas_h = resource::get_cublas_handle(handle); RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); @@ -195,9 +226,13 @@ void transpose_row_major_impl( raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { - auto out_n_rows = in.extent(1); - auto out_n_cols = in.extent(0); - transpose_half(handle, out_n_cols, out_n_rows, in.data_handle(), out.data_handle()); + transpose_half(handle, + in.extent(0), + in.extent(1), + in.data_handle(), + out.data_handle(), + in.stride(0), + out.stride(0)); } template @@ -233,9 +268,13 @@ void transpose_col_major_impl( raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { - auto out_n_rows = in.extent(1); - auto out_n_cols = in.extent(0); - transpose_half(handle, out_n_rows, out_n_cols, in.data_handle(), out.data_handle()); + transpose_half(handle, + in.extent(1), + in.extent(0), + in.data_handle(), + out.data_handle(), + in.stride(1), + out.stride(1)); } }; // end namespace detail diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index cbe869a9a5..22fc1c1d60 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -29,48 +29,104 @@ #include +#include + +namespace std { +template <> +struct is_floating_point : std::true_type {}; +} // namespace std + namespace raft { namespace linalg { template -struct TranposeInputs { +void initialize_array(T* data_h, size_t size) +{ + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0.0, 1.0); + + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + data_h[i] = __float2half(static_cast(dis(gen))); + } else { + data_h[i] = static_cast(dis(gen)); + } + } +} + +template +void cpu_transpose_row_major( + const T* input, T* output, int rows, int cols, int stride_in = -1, int stride_out = -1) +{ + stride_in = stride_in == -1 ? cols : stride_in; + stride_out = stride_out == -1 ? rows : stride_out; + if (stride_in) + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + output[j * stride_out + i] = input[i * stride_in + j]; + } + } +} + +template +void cpu_transpose_col_major( + const T* input, T* output, int rows, int cols, int stride_in = -1, int stride_out = -1) +{ + cpu_transpose_row_major(input, output, cols, rows, stride_in, stride_out); +} + +bool validate_half(const half* h_ref, const half* h_result, half tolerance, int len) +{ + bool success = true; + for (int i = 0; i < len; ++i) { + if (raft::abs(__half2float(h_result[i]) - __half2float(h_ref[i])) >= __half2float(tolerance)) { + success = false; + break; + } + if (!success) break; + } + return success; +} + +namespace transpose_regular_test { + +template +struct TransposeInputs { T tolerance; - int len; int n_row; int n_col; unsigned long long int seed; }; template -::std::ostream& operator<<(::std::ostream& os, const TranposeInputs& dims) -{ - return os; -} - -template -class TransposeTest : public ::testing::TestWithParam> { +class TransposeTest : public ::testing::TestWithParam> { public: TransposeTest() - : params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(resource::get_cuda_stream(handle)), - data(params.len, stream), - data_trans_ref(params.len, stream), - data_trans(params.len, stream) + data(params.n_row * params.n_col, stream), + data_trans_ref(params.n_row * params.n_col, stream), + data_trans(params.n_row * params.n_col, stream) { } protected: void SetUp() override { - int len = params.len; - ASSERT(params.len == 9, "This test works only with len=9!"); - T data_h[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; - raft::update_device(data.data(), data_h, len, stream); - T data_ref_h[] = {1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0}; - raft::update_device(data_trans_ref.data(), data_ref_h, len, stream); + int len = params.n_row * params.n_col; + std::vector data_h(len); + std::vector data_ref_h(len); + + initialize_array(data_h.data(), len); + + cpu_transpose_col_major(data_h.data(), data_ref_h.data(), params.n_row, params.n_col); + + raft::update_device(data.data(), data_h.data(), len, stream); + raft::update_device(data_trans_ref.data(), data_ref_h.data(), len, stream); transpose(handle, data.data(), data_trans.data(), params.n_row, params.n_col, stream); - transpose(data.data(), params.n_row, stream); + if (params.n_row == params.n_col) { transpose(data.data(), params.n_col, stream); } resource::sync_stream(handle, stream); } @@ -78,28 +134,45 @@ class TransposeTest : public ::testing::TestWithParam> { raft::resources handle; cudaStream_t stream; - TranposeInputs params; + TransposeInputs params; rmm::device_uvector data, data_trans, data_trans_ref; }; -const std::vector> inputsf2 = {{0.1f, 3 * 3, 3, 3, 1234ULL}}; - -const std::vector> inputsd2 = {{0.1, 3 * 3, 3, 3, 1234ULL}}; - -const std::vector> inputsh2 = {{0.1, 3 * 3, 3, 3, 1234ULL}}; +const std::vector> inputsf2 = {{0.1f, 3, 3, 1234ULL}, + {0.1f, 3, 4, 1234ULL}, + {0.1f, 300, 300, 1234ULL}, + {0.1f, 300, 4100, 1234ULL}, + {0.1f, 1, 13000, 1234ULL}, + {0.1f, 3, 130001, 1234ULL}}; + +const std::vector> inputsd2 = {{0.1f, 3, 3, 1234ULL}, + {0.1f, 3, 4, 1234ULL}, + {0.1f, 300, 300, 1234ULL}, + {0.1f, 300, 4100, 1234ULL}, + {0.1f, 1, 13000, 1234ULL}, + {0.1f, 3, 130001, 1234ULL}}; + +const std::vector> inputsh2 = {{0.1f, 3, 3, 1234ULL}, + {0.1f, 3, 4, 1234ULL}, + {0.1f, 300, 300, 1234ULL}, + {0.1f, 300, 4100, 1234ULL}, + {0.1f, 1, 13000, 1234ULL}, + {0.1f, 3, 130001, 1234ULL}}; typedef TransposeTest TransposeTestValF; TEST_P(TransposeTestValF, Result) { ASSERT_TRUE(raft::devArrMatch(data_trans_ref.data(), data_trans.data(), - params.len, + params.n_row * params.n_col, raft::CompareApproxAbs(params.tolerance))); - ASSERT_TRUE(raft::devArrMatch(data_trans_ref.data(), - data.data(), - params.len, - raft::CompareApproxAbs(params.tolerance))); + if (params.n_row == params.n_col) { + ASSERT_TRUE(raft::devArrMatch(data_trans_ref.data(), + data.data(), + params.n_row * params.n_col, + raft::CompareApproxAbs(params.tolerance))); + } } typedef TransposeTest TransposeTestValD; @@ -107,59 +180,47 @@ TEST_P(TransposeTestValD, Result) { ASSERT_TRUE(raft::devArrMatch(data_trans_ref.data(), data_trans.data(), - params.len, - raft::CompareApproxAbs(params.tolerance))); - - ASSERT_TRUE(raft::devArrMatch(data_trans_ref.data(), - data.data(), - params.len, + params.n_row * params.n_col, raft::CompareApproxAbs(params.tolerance))); -} - -bool validate_half(const half* h_ref, const half* h_result, half tolerance, int len) -{ - bool success = true; - for (int i = 0; i < len; ++i) { - if (raft::abs(__half2float(h_result[i]) - __half2float(h_ref[i])) >= __half2float(tolerance)) { - success = false; - break; - } - if (!success) break; + if (params.n_row == params.n_col) { + ASSERT_TRUE(raft::devArrMatch(data_trans_ref.data(), + data.data(), + params.n_row * params.n_col, + raft::CompareApproxAbs(params.tolerance))); } - return success; } typedef TransposeTest TransposeTestValH; TEST_P(TransposeTestValH, Result) { - half data_trans_ref_h[params.len]; - half data_trans_h[params.len]; - half data_h[params.len]; + auto len = params.n_row * params.n_col; - RAFT_CUDA_TRY(cudaMemcpyAsync(data_trans_ref_h, - data_trans_ref.data(), - params.len * sizeof(half), - cudaMemcpyDeviceToHost, - stream)); - - RAFT_CUDA_TRY(cudaMemcpyAsync( - data_trans_h, data_trans.data(), params.len * sizeof(half), cudaMemcpyDeviceToHost, stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync( - data_h, data.data(), params.len * sizeof(half), cudaMemcpyDeviceToHost, stream)); + std::vector data_trans_ref_h(len); + std::vector data_trans_h(len); + std::vector data_h(len); + raft::copy( + data_trans_ref_h.data(), data_trans_ref.data(), len, resource::get_cuda_stream(handle)); + raft::copy(data_trans_h.data(), data_trans.data(), len, resource::get_cuda_stream(handle)); + raft::copy(data_h.data(), data.data(), len, resource::get_cuda_stream(handle)); resource::sync_stream(handle, stream); - ASSERT_TRUE(validate_half(data_trans_ref_h, data_trans_h, params.tolerance, params.len)); - ASSERT_TRUE(validate_half(data_trans_ref_h, data_h, params.tolerance, params.len)); + ASSERT_TRUE(validate_half( + data_trans_ref_h.data(), data_trans_h.data(), params.tolerance, params.n_row * params.n_col)); + + if (params.n_row == params.n_col) { + ASSERT_TRUE(validate_half( + data_trans_ref_h.data(), data_h.data(), params.tolerance, params.n_row * params.n_col)); + } } INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValF, ::testing::ValuesIn(inputsf2)); - INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValD, ::testing::ValuesIn(inputsd2)); - INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValH, ::testing::ValuesIn(inputsh2)); +} // namespace transpose_regular_test + +namespace transpose_extra_test { -namespace { /** * We hide these functions in tests for now until we have a heterogeneous mdarray * implementation. @@ -230,79 +291,225 @@ template } } +template +struct TransposeMdspanInputs { + int n_row; + int n_col; + T tolerance = T{0.01}; +}; + template -void test_transpose_with_mdspan() +void test_transpose_with_mdspan(const TransposeMdspanInputs& param) { + auto len = param.n_row * param.n_col; + std::vector in_h(len); + std::vector out_ref_h(len); + + initialize_array(in_h.data(), len); + raft::resources handle; - auto v = make_device_matrix(handle, 32, 3); - T k{0}; - for (size_t i = 0; i < v.extent(0); ++i) { - for (size_t j = 0; j < v.extent(1); ++j) { - v(i, j) = k++; - } + auto stream = resource::get_cuda_stream(handle); + auto in = make_device_matrix(handle, param.n_row, param.n_col); + auto out_ref = make_device_matrix(handle, param.n_row, param.n_col); + resource::sync_stream(handle, stream); + if constexpr (std::is_same_v) { + cpu_transpose_row_major(in_h.data(), out_ref_h.data(), param.n_row, param.n_col); + } else { + cpu_transpose_col_major(in_h.data(), out_ref_h.data(), param.n_row, param.n_col); } - auto out = transpose(handle, v.view()); - static_assert(std::is_same_v); - ASSERT_EQ(out.extent(0), v.extent(1)); - ASSERT_EQ(out.extent(1), v.extent(0)); + raft::copy(in.data_handle(), in_h.data(), len, resource::get_cuda_stream(handle)); + raft::copy(out_ref.data_handle(), out_ref_h.data(), len, resource::get_cuda_stream(handle)); - k = 0; - for (size_t i = 0; i < out.extent(1); ++i) { - for (size_t j = 0; j < out.extent(0); ++j) { - ASSERT_EQ(out(j, i), k++); - } + auto out = transpose(handle, in.view()); + static_assert(std::is_same_v); + ASSERT_EQ(out.extent(0), in.extent(1)); + ASSERT_EQ(out.extent(1), in.extent(0)); + if constexpr (std::is_same_v) { + std::vector out_h(len); + raft::copy(out_h.data(), out.data_handle(), len, resource::get_cuda_stream(handle)); + ASSERT_TRUE(validate_half(out_ref_h.data(), out_h.data(), param.tolerance, len)); + } else { + ASSERT_TRUE(raft::devArrMatch( + out_ref.data_handle(), out.data_handle(), len, raft::CompareApproxAbs(param.tolerance))); } } -} // namespace -TEST(TransposeTest, MDSpan) +const std::vector> inputs_mdspan_f = {{3, 3}, + {3, 4}, + {300, 300}, + {300, 4100}, + {1, 13000}, + {3, 130001}, + {4100, 300}, + {13000, 1}, + {130001, 3}}; +const std::vector> inputs_mdspan_d = {{3, 3}, + {3, 4}, + {300, 300}, + {300, 4100}, + {1, 13000}, + {3, 130001}, + {4100, 300}, + {13000, 1}, + {130001, 3}}; +const std::vector> inputs_mdspan_h = {{3, 3}, + {3, 4}, + {300, 300}, + {300, 4100}, + {1, 13000}, + {3, 130001}, + {4100, 300}, + {13000, 1}, + {130001, 3}}; + +TEST(TransposeTest, MDSpanFloat) { - test_transpose_with_mdspan(); - test_transpose_with_mdspan(); - - test_transpose_with_mdspan(); - test_transpose_with_mdspan(); + for (const auto& p : inputs_mdspan_f) { + test_transpose_with_mdspan(p); + test_transpose_with_mdspan(p); + } +} +TEST(TransposeTest, MDSpanDouble) +{ + for (const auto& p : inputs_mdspan_d) { + test_transpose_with_mdspan(p); + test_transpose_with_mdspan(p); + } +} +TEST(TransposeTest, MDSpanHalf) +{ + for (const auto& p : inputs_mdspan_h) { + test_transpose_with_mdspan(p); + test_transpose_with_mdspan(p); + } } -namespace { +template +struct TransposeSubmatrixInputs { + int n_row; + int n_col; + int row_beg; + int row_end; + int col_beg; + int col_end; + T tolerance = T{0.01}; +}; + template -void test_transpose_submatrix() +void test_transpose_submatrix(const TransposeSubmatrixInputs& param) { + auto len = param.n_row * param.n_col; + auto sub_len = (param.row_end - param.row_beg) * (param.col_end - param.col_beg); + + std::vector in_h(len); + std::vector out_ref_h(sub_len); + + initialize_array(in_h.data(), len); + raft::resources handle; - auto v = make_device_matrix(handle, 32, 33); - T k{0}; - size_t row_beg{3}, row_end{13}, col_beg{2}, col_end{11}; - for (size_t i = row_beg; i < row_end; ++i) { - for (size_t j = col_beg; j < col_end; ++j) { - v(i, j) = k++; - } + auto stream = resource::get_cuda_stream(handle); + + auto in = make_device_matrix(handle, param.n_row, param.n_col); + auto out_ref = make_device_matrix( + handle, (param.row_end - param.row_beg), (param.col_end - param.col_beg)); + + if constexpr (std::is_same_v) { + auto offset = param.row_beg * param.n_col + param.col_beg; + cpu_transpose_row_major(in_h.data() + offset, + out_ref_h.data(), + (param.row_end - param.row_beg), + (param.col_end - param.col_beg), + in.extent(1), + (param.row_end - param.row_beg)); + } else { + auto offset = param.col_beg * param.n_row + param.row_beg; + cpu_transpose_col_major(in_h.data() + offset, + out_ref_h.data(), + (param.row_end - param.row_beg), + (param.col_end - param.col_beg), + in.extent(0), + (param.col_end - param.col_beg)); } - auto vv = v.view(); - auto submat = std::experimental::submdspan( - vv, std::make_tuple(row_beg, row_end), std::make_tuple(col_beg, col_end)); - static_assert(std::is_same_v); + raft::copy(in.data_handle(), in_h.data(), len, resource::get_cuda_stream(handle)); + raft::copy(out_ref.data_handle(), out_ref_h.data(), sub_len, resource::get_cuda_stream(handle)); + resource::sync_stream(handle, stream); - auto out = transpose(handle, submat); - ASSERT_EQ(out.extent(0), submat.extent(1)); - ASSERT_EQ(out.extent(1), submat.extent(0)); + auto in_submat = std::experimental::submdspan(in.view(), + std::make_tuple(param.row_beg, param.row_end), + std::make_tuple(param.col_beg, param.col_end)); - k = 0; - for (size_t i = 0; i < out.extent(1); ++i) { - for (size_t j = 0; j < out.extent(0); ++j) { - ASSERT_EQ(out(j, i), k++); - } + static_assert(std::is_same_v); + auto out = transpose(handle, in_submat); + + ASSERT_EQ(out.extent(0), in_submat.extent(1)); + ASSERT_EQ(out.extent(1), in_submat.extent(0)); + + if constexpr (std::is_same_v) { + std::vector out_h(sub_len); + + raft::copy(out_h.data(), out.data_handle(), sub_len, resource::get_cuda_stream(handle)); + ASSERT_TRUE(validate_half(out_ref_h.data(), out_h.data(), param.tolerance, sub_len)); + } else { + ASSERT_TRUE(raft::devArrMatch(out_ref.data_handle(), + out.data_handle(), + sub_len, + raft::CompareApproxAbs(param.tolerance))); } } -} // namespace - -TEST(TransposeTest, SubMatrix) +const std::vector> inputs_submatrix_f = { + {3, 3, 1, 2, 0, 2}, + {3, 4, 1, 3, 2, 3}, + {300, 300, 1, 299, 2, 239}, + {300, 4100, 3, 299, 101, 4001}, + {2, 13000, 0, 1, 3, 13000}, + {3, 130001, 0, 3, 3999, 129999}, + {4100, 300, 159, 4001, 125, 300}, + {13000, 5, 0, 11111, 0, 3}, + {130001, 3, 19, 130000, 2, 3}}; +const std::vector> inputs_submatrix_d = { + {3, 3, 1, 2, 0, 2}, + {3, 4, 1, 3, 2, 3}, + {300, 300, 1, 299, 2, 239}, + {300, 4100, 3, 299, 101, 4001}, + {2, 13000, 0, 1, 3, 13000}, + {3, 130001, 0, 3, 3999, 129999}, + {4100, 300, 159, 4001, 125, 300}, + {13000, 5, 0, 11111, 0, 3}, + {130001, 3, 19, 130000, 2, 3}}; +const std::vector> inputs_submatrix_h = { + {3, 3, 1, 2, 0, 2}, + {3, 4, 1, 3, 2, 3}, + {300, 300, 1, 299, 2, 239}, + {300, 4100, 3, 299, 101, 4001}, + {2, 13000, 0, 1, 3, 13000}, + {3, 130001, 0, 3, 3999, 129999}, + {4100, 300, 159, 4001, 125, 300}, + {13000, 5, 0, 11111, 0, 3}, + {130001, 3, 19, 130000, 2, 3}}; + +TEST(TransposeTest, SubMatrixFloat) { - test_transpose_submatrix(); - test_transpose_submatrix(); - - test_transpose_submatrix(); - test_transpose_submatrix(); + for (const auto& p : inputs_submatrix_f) { + test_transpose_submatrix(p); + test_transpose_submatrix(p); + } } +TEST(TransposeTest, SubMatrixDouble) +{ + for (const auto& p : inputs_submatrix_d) { + test_transpose_submatrix(p); + test_transpose_submatrix(p); + } +} +TEST(TransposeTest, SubMatrixHalf) +{ + for (const auto& p : inputs_submatrix_h) { + test_transpose_submatrix(p); + test_transpose_submatrix(p); + } +} + +} // namespace transpose_extra_test } // end namespace linalg } // end namespace raft