Skip to content

Commit

Permalink
Add tests for raft::matrix (#937)
Browse files Browse the repository at this point in the history
Linking #877.
Implementation of the remaining tests for `raft::matrix`

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #937
  • Loading branch information
lowener authored Oct 27, 2022
1 parent ced7bce commit b0a7064
Show file tree
Hide file tree
Showing 26 changed files with 1,746 additions and 29 deletions.
1 change: 1 addition & 0 deletions cpp/include/raft/core/device_mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <cstdint>
#include <raft/core/detail/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/mdarray.hpp>
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/core/device_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <cstdint>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>

Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/core/host_mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <cstdint>
#include <raft/core/host_mdspan.hpp>

#include <raft/core/detail/host_mdarray.hpp>
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/core/host_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <cstdint>
#include <raft/core/mdspan.hpp>

#include <raft/core/host_device_accessor.hpp>
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/core/mdspan_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ using vector_extent = std::experimental::extents<IndexType, dynamic_extent>;
template <typename IndexType>
using matrix_extent = std::experimental::extents<IndexType, dynamic_extent, dynamic_extent>;

template <typename IndexType = std::uint32_t>
template <typename IndexType>
using scalar_extent = std::experimental::extents<IndexType, 1>;

/**
Expand Down
194 changes: 194 additions & 0 deletions cpp/include/raft/linalg/matrix_vector.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/linalg/linalg_types.hpp>
#include <raft/matrix/detail/math.cuh>
#include <raft/util/input_validation.hpp>

namespace raft::linalg {

/**
* @brief multiply each row or column of matrix with vector, skipping zeros in vector
* @param [in] handle: raft handle for managing library resources
* @param[inout] data: input matrix, results are in-place
* @param[in] vec: input vector
* @param[in] apply whether the broadcast of vector needs to happen along
* the rows of the matrix or columns using enum class raft::linalg::Apply
*/
template <typename math_t, typename idx_t, typename layout_t>
void binary_mult_skip_zero(const raft::handle_t& handle,
raft::device_matrix_view<math_t, idx_t, layout_t> data,
raft::device_vector_view<const math_t, idx_t> vec,
Apply apply)
{
bool row_major = raft::is_row_major(data);
auto bcast_along_rows = apply == Apply::ALONG_ROWS;

idx_t vec_size = bcast_along_rows ? data.extent(1) : data.extent(0);

RAFT_EXPECTS(
vec.extent(0) == vec_size,
"If `bcast_along_rows==true`, vector size must equal number of columns in the matrix."
"If `bcast_along_rows==false`, vector size must equal number of rows in the matrix.");

matrix::detail::matrixVectorBinaryMultSkipZero(data.data_handle(),
vec.data_handle(),
data.extent(0),
data.extent(1),
row_major,
bcast_along_rows,
handle.get_stream());
}

/**
* @brief divide each row or column of matrix with vector
* @param[in] handle: raft handle for managing library resources
* @param[inout] data: input matrix, results are in-place
* @param[in] vec: input vector
* @param[in] apply whether the broadcast of vector needs to happen along
* the rows of the matrix or columns using enum class raft::linalg::Apply
*/
template <typename math_t, typename idx_t, typename layout_t>
void binary_div(const raft::handle_t& handle,
raft::device_matrix_view<math_t, idx_t, layout_t> data,
raft::device_vector_view<const math_t, idx_t> vec,
Apply apply)
{
bool row_major = raft::is_row_major(data);
auto bcast_along_rows = apply == Apply::ALONG_ROWS;

idx_t vec_size = bcast_along_rows ? data.extent(1) : data.extent(0);

RAFT_EXPECTS(
vec.extent(0) == vec_size,
"If `bcast_along_rows==true`, vector size must equal number of columns in the matrix."
"If `bcast_along_rows==false`, vector size must equal number of rows in the matrix.");

matrix::detail::matrixVectorBinaryDiv(data.data_handle(),
vec.data_handle(),
data.extent(0),
data.extent(1),
row_major,
bcast_along_rows,
handle.get_stream());
}

/**
* @brief divide each row or column of matrix with vector, skipping zeros in vector
* @param[in] handle: raft handle for managing library resources
* @param[inout] data: input matrix, results are in-place
* @param[in] vec: input vector
* @param[in] apply whether the broadcast of vector needs to happen along
* the rows of the matrix or columns using enum class raft::linalg::Apply
* @param[in] return_zero: result is zero if true and vector value is below threshold, original
* value if false
*/
template <typename math_t, typename idx_t, typename layout_t>
void binary_div_skip_zero(const raft::handle_t& handle,
raft::device_matrix_view<math_t, idx_t, layout_t> data,
raft::device_vector_view<const math_t, idx_t> vec,
Apply apply,
bool return_zero = false)
{
bool row_major = raft::is_row_major(data);
auto bcast_along_rows = apply == Apply::ALONG_ROWS;

idx_t vec_size = bcast_along_rows ? data.extent(1) : data.extent(0);

RAFT_EXPECTS(
vec.extent(0) == vec_size,
"If `bcast_along_rows==true`, vector size must equal number of columns in the matrix."
"If `bcast_along_rows==false`, vector size must equal number of rows in the matrix.");

matrix::detail::matrixVectorBinaryDivSkipZero(data.data_handle(),
vec.data_handle(),
data.extent(0),
data.extent(1),
row_major,
bcast_along_rows,
handle.get_stream(),
return_zero);
}

/**
* @brief add each row or column of matrix with vector
* @param[in] handle: raft handle for managing library resources
* @param[inout] data: input matrix, results are in-place
* @param[in] vec: input vector
* @param[in] apply whether the broadcast of vector needs to happen along
* the rows of the matrix or columns using enum class raft::linalg::Apply
*/
template <typename math_t, typename idx_t, typename layout_t>
void binary_add(const raft::handle_t& handle,
raft::device_matrix_view<math_t, idx_t, layout_t> data,
raft::device_vector_view<const math_t, idx_t> vec,
Apply apply)
{
bool row_major = raft::is_row_major(data);
auto bcast_along_rows = apply == Apply::ALONG_ROWS;

idx_t vec_size = bcast_along_rows ? data.extent(1) : data.extent(0);

RAFT_EXPECTS(
vec.extent(0) == vec_size,
"If `bcast_along_rows==true`, vector size must equal number of columns in the matrix."
"If `bcast_along_rows==false`, vector size must equal number of rows in the matrix.");

matrix::detail::matrixVectorBinaryAdd(data.data_handle(),
vec.data_handle(),
data.extent(0),
data.extent(1),
row_major,
bcast_along_rows,
handle.get_stream());
}

/**
* @brief subtract each row or column of matrix with vector
* @param[in] handle: raft handle for managing library resources
* @param[inout] data: input matrix, results are in-place
* @param[in] vec: input vector
* @param[in] apply whether the broadcast of vector needs to happen along
* the rows of the matrix or columns using enum class raft::linalg::Apply
*/
template <typename math_t, typename idx_t, typename layout_t>
void binary_sub(const raft::handle_t& handle,
raft::device_matrix_view<math_t, idx_t, layout_t> data,
raft::device_vector_view<const math_t, idx_t> vec,
Apply apply)
{
bool row_major = raft::is_row_major(data);
auto bcast_along_rows = apply == Apply::ALONG_ROWS;

idx_t vec_size = bcast_along_rows ? data.extent(1) : data.extent(0);

RAFT_EXPECTS(
vec.extent(0) == vec_size,
"If `bcast_along_rows==true`, vector size must equal number of columns in the matrix."
"If `bcast_along_rows==false`, vector size must equal number of rows in the matrix.");

matrix::detail::matrixVectorBinarySub(data.data_handle(),
vec.data_handle(),
data.extent(0),
data.extent(1),
row_major,
bcast_along_rows,
handle.get_stream());
}
} // namespace raft::linalg
40 changes: 40 additions & 0 deletions cpp/include/raft/matrix/argmax.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/matrix/detail/math.cuh>

namespace raft::matrix {

/**
* @brief Argmax: find the row idx with maximum value for each column
* @param[in] handle: raft handle
* @param[in] in: input matrix of size (n_rows, n_cols)
* @param[out] out: output vector of size n_cols
*/
template <typename math_t, typename idx_t, typename matrix_idx_t>
void argmax(const raft::handle_t& handle,
raft::device_matrix_view<const math_t, matrix_idx_t, row_major> in,
raft::device_vector_view<idx_t, matrix_idx_t> out)
{
RAFT_EXPECTS(out.extent(0) == in.extent(0),
"Size of output vector must equal number of rows in input matrix.");
detail::argmax(
in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream());
}
} // namespace raft::matrix
Loading

0 comments on commit b0a7064

Please sign in to comment.