Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add identity matrix function #1548

Merged
merged 7 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/include/raft/linalg/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void norm(raft::resources const& handle,
{
RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous");

auto constexpr row_major = std::is_same_v<typename decltype(out)::layout_type, raft::row_major>;
auto constexpr row_major = std::is_same_v<LayoutPolicy, raft::row_major>;
auto along_rows = apply == Apply::ALONG_ROWS;

if (along_rows) {
Expand Down
28 changes: 14 additions & 14 deletions cpp/include/raft/matrix/detail/matrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -230,52 +230,52 @@ void copyUpperTriangular(const m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, c
/**
* @brief Copy a vector to the diagonal of a matrix
* @param vec: vector of length k = min(n_rows, n_cols)
* @param matrix: matrix of size n_rows x n_cols
* @param m: number of rows of the matrix
* @param n: number of columns of the matrix
* @param matrix: matrix of size n_rows x n_cols (leading dimension = lda)
* @param lda: leading dimension of the matrix
* @param k: dimensionality
*/
template <typename m_t, typename idx_t = int>
__global__ void copyVectorToMatrixDiagonal(const m_t* vec, m_t* matrix, idx_t m, idx_t n, idx_t k)
__global__ void copyVectorToMatrixDiagonal(const m_t* vec, m_t* matrix, idx_t lda, idx_t k)
{
idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;

if (idx < k) { matrix[idx + idx * m] = vec[idx]; }
if (idx < k) { matrix[idx + idx * lda] = vec[idx]; }
}

/**
* @brief Copy matrix diagonal to vector
* @param vec: vector of length k = min(n_rows, n_cols)
* @param matrix: matrix of size n_rows x n_cols
* @param m: number of rows of the matrix
* @param n: number of columns of the matrix
* @param matrix: matrix of size n_rows x n_cols (leading dimension = lda)
* @param lda: leading dimension of the matrix
* @param k: dimensionality
*/
template <typename m_t, typename idx_t = int>
__global__ void copyVectorFromMatrixDiagonal(m_t* vec, const m_t* matrix, idx_t m, idx_t n, idx_t k)
__global__ void copyVectorFromMatrixDiagonal(m_t* vec, const m_t* matrix, idx_t lda, idx_t k)
{
idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;

if (idx < k) { vec[idx] = matrix[idx + idx * m]; }
if (idx < k) { vec[idx] = matrix[idx + idx * lda]; }
}

template <typename m_t, typename idx_t = int>
void initializeDiagonalMatrix(
const m_t* vec, m_t* matrix, idx_t n_rows, idx_t n_cols, cudaStream_t stream)
const m_t* vec, m_t* matrix, idx_t n_rows, idx_t n_cols, bool row_major, cudaStream_t stream)
{
idx_t k = std::min(n_rows, n_cols);
idx_t lda = row_major ? n_cols : n_rows;
dim3 block(64);
dim3 grid((k + block.x - 1) / block.x);
copyVectorToMatrixDiagonal<<<grid, block, 0, stream>>>(vec, matrix, n_rows, n_cols, k);
copyVectorToMatrixDiagonal<<<grid, block, 0, stream>>>(vec, matrix, lda, k);
}

template <typename m_t, typename idx_t = int>
void getDiagonalMatrix(m_t* vec, const m_t* matrix, idx_t n_rows, idx_t n_cols, cudaStream_t stream)
void getDiagonalMatrix(m_t* vec, const m_t* matrix, idx_t n_rows, idx_t n_cols, bool row_major, cudaStream_t stream)
{
idx_t k = std::min(n_rows, n_cols);
idx_t lda = row_major ? n_cols : n_rows;
dim3 block(64);
dim3 grid((k + block.x - 1) / block.x);
copyVectorFromMatrixDiagonal<<<grid, block, 0, stream>>>(vec, matrix, n_rows, n_cols, k);
copyVectorFromMatrixDiagonal<<<grid, block, 0, stream>>>(vec, matrix, lda, k);
}

/**
Expand Down
25 changes: 25 additions & 0 deletions cpp/include/raft/matrix/diagonal.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/matrix/detail/matrix.cuh>
#include <raft/matrix/init.cuh>
#include <raft/util/input_validation.hpp>

namespace raft::matrix {

Expand All @@ -40,11 +42,13 @@ void set_diagonal(raft::resources const& handle,
{
RAFT_EXPECTS(vec.extent(0) == std::min(matrix.extent(0), matrix.extent(1)),
"Diagonal vector must be min(matrix.n_rows, matrix.n_cols)");
constexpr auto is_row_major = std::is_same_v<layout, layout_c_contiguous>;

detail::initializeDiagonalMatrix(vec.data_handle(),
matrix.data_handle(),
matrix.extent(0),
matrix.extent(1),
is_row_major,
resource::get_cuda_stream(handle));
}

Expand All @@ -61,10 +65,12 @@ void get_diagonal(raft::resources const& handle,
{
RAFT_EXPECTS(vec.extent(0) == std::min(matrix.extent(0), matrix.extent(1)),
"Diagonal vector must be min(matrix.n_rows, matrix.n_cols)");
constexpr auto is_row_major = std::is_same_v<layout, layout_c_contiguous>;
detail::getDiagonalMatrix(vec.data_handle(),
matrix.data_handle(),
matrix.extent(0),
matrix.extent(1),
is_row_major,
resource::get_cuda_stream(handle));
}

Expand All @@ -83,6 +89,25 @@ void invert_diagonal(raft::resources const& handle,
inout.data_handle(), inout.extent(0), resource::get_cuda_stream(handle));
}

/**
* @brief create an identity matrix
* @tparam math_t data-type upon which the math operation will be performed
* @tparam idx_t indexing type used for the output
* @tparam layout_t layout of the matrix data (must be row or col major)
* @param[in] handle: raft handle
* @param[out] out: output matrix
*/
template <typename math_t, typename idx_t, typename layout_t>
void eye(const raft::resources& handle, raft::device_matrix_view<math_t, idx_t, layout_t> out)
{
RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");

auto diag = raft::make_device_vector<math_t, idx_t>(handle, min(out.extent(0), out.extent(1)));
RAFT_CUDA_TRY(cudaMemsetAsync(out.data_handle(), 0, out.size() * sizeof (math_t), resource::get_cuda_stream(handle)));
raft::matrix::fill(handle, diag.view(), math_t(1));
set_diagonal(handle, raft::make_const_mdspan(diag.view()), out);
}

/** @} */ // end of group matrix_diagonal

} // namespace raft::matrix
6 changes: 3 additions & 3 deletions cpp/include/raft/matrix/matrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ void copyUpperTriangular(m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, cudaStr
}

/**
* @brief Initialize a diagonal matrix with a vector
* @brief Initialize a diagonal col-major matrix with a vector
* @param vec: vector of length k = min(n_rows, n_cols)
* @param matrix: matrix of size n_rows x n_cols
* @param matrix: matrix of size n_rows x n_cols (col-major)
* @param n_rows: number of rows of the matrix
* @param n_cols: number of columns of the matrix
* @param stream: cuda stream
Expand All @@ -232,7 +232,7 @@ template <typename m_t, typename idx_t = int>
void initializeDiagonalMatrix(
m_t* vec, m_t* matrix, idx_t n_rows, idx_t n_cols, cudaStream_t stream)
{
detail::initializeDiagonalMatrix(vec, matrix, n_rows, n_cols, stream);
detail::initializeDiagonalMatrix(vec, matrix, n_rows, n_cols, false, stream);
}

/**
Expand Down
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ if(BUILD_TESTS)
test/matrix/columnSort.cu
test/matrix/diagonal.cu
test/matrix/gather.cu
test/matrix/eye.cu
test/matrix/linewise_op.cu
test/matrix/math.cu
test/matrix/matrix.cu
Expand Down
92 changes: 92 additions & 0 deletions cpp/test/matrix/eye.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../test_utils.cuh"
#include <gtest/gtest.h>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/cuda_stream.hpp>

#include <raft/matrix/diagonal.cuh>
#include <raft/util/cudart_utils.hpp>

namespace raft::matrix {

template <typename T>
struct InitInputs {
int n_row;
int n_col;
};

template <typename T>
::std::ostream& operator<<(::std::ostream& os, const InitInputs<T>& dims)
{
return os;
}

template <typename T>
class InitTest : public ::testing::TestWithParam<InitInputs<T>> {
public:
InitTest()
: params(::testing::TestWithParam<InitInputs<T>>::GetParam()),
stream(resource::get_cuda_stream(handle))
{
}

protected:
void test_eye()
{
ASSERT_TRUE(params.n_row == 4 && params.n_col == 5);
auto eyemat_col =
raft::make_device_matrix<T, int, raft::col_major>(handle, params.n_row, params.n_col);
raft::matrix::eye(handle, eyemat_col.view());
std::vector<T> eye_exp{1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0};
std::vector<T> eye_act(params.n_col * params.n_row);
raft::copy(eye_act.data(), eyemat_col.data_handle(), eye_act.size(), stream);
resource::sync_stream(handle, stream);
ASSERT_TRUE(hostVecMatch(eye_exp, eye_act, raft::Compare<T>()));

auto eyemat_row =
raft::make_device_matrix<T, int, raft::row_major>(handle, params.n_row, params.n_col);
raft::matrix::eye(handle, eyemat_row.view());
eye_exp = std::vector<T>{1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0};
raft::copy(eye_act.data(), eyemat_row.data_handle(), eye_act.size(), stream);
resource::sync_stream(handle, stream);
ASSERT_TRUE(hostVecMatch(eye_exp, eye_act, raft::Compare<T>()));
}

void SetUp() override { test_eye(); }

protected:
raft::resources handle;
cudaStream_t stream;

InitInputs<T> params;
};

const std::vector<InitInputs<float>> inputsf1 = {{4, 5}};

const std::vector<InitInputs<double>> inputsd1 = {{4, 5}};

typedef InitTest<float> InitTestF;
TEST_P(InitTestF, Result) {}

typedef InitTest<double> InitTestD;
TEST_P(InitTestD, Result) {}

INSTANTIATE_TEST_SUITE_P(InitTests, InitTestF, ::testing::ValuesIn(inputsf1));
INSTANTIATE_TEST_SUITE_P(InitTests, InitTestD, ::testing::ValuesIn(inputsd1));

} // namespace raft::matrix