Skip to content

Commit

Permalink
Changing RAFT_EXPLICT_* to CUVS_EXPLITI_* (#141)
Browse files Browse the repository at this point in the history
Changing this for now. We'll work on removing this macro in a future release.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Divye Gala (https://github.com/divyegala)

URL: #141
  • Loading branch information
cjnolet authored May 23, 2024
1 parent ac85fa6 commit 3515f44
Show file tree
Hide file tree
Showing 25 changed files with 2,353 additions and 31 deletions.
9 changes: 4 additions & 5 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,8 @@ if(NOT BUILD_CPU_ONLY)
# Keep cuVS as lightweight as possible. Only CUDA libs and rmm should be used in global target.
target_link_libraries(
cuvs
PUBLIC rmm::rmm raft::raft
PRIVATE nvidia::cutlass::cutlass ${CUVS_CTK_MATH_DEPENDENCIES}
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
PUBLIC rmm::rmm raft::raft ${CUVS_CTK_MATH_DEPENDENCIES}
PRIVATE nvidia::cutlass::cutlass $<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
)
endif()

Expand Down Expand Up @@ -573,8 +572,8 @@ if(BUILD_C_LIBRARY)

target_link_libraries(
cuvs_c
PUBLIC cuvs::cuvs
PRIVATE raft::raft ${CUVS_CTK_MATH_DEPENDENCIES}
PUBLIC cuvs::cuvs ${CUVS_CTK_MATH_DEPENDENCIES}
PRIVATE raft::raft
)

# ensure CUDA symbols aren't relocated to the middle of the debug build binaries
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/cluster/detail/mst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

#pragma once

#include "../../sparse/neighbors/cross_component_nn.cuh"
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/sparse/neighbors/cross_component_nn.cuh>
#include <raft/sparse/op/sort.cuh>
#include <raft/sparse/solver/mst.cuh>
#include <raft/util/cuda_utils.cuh>
Expand Down Expand Up @@ -86,7 +86,7 @@ void connect_knn_graph(
static constexpr size_t default_row_batch_size = 4096;
static constexpr size_t default_col_batch_size = 16;

raft::sparse::neighbors::cross_component_nn<value_idx, value_t>(handle,
cuvs::sparse::neighbors::cross_component_nn<value_idx, value_t>(handle,
connected_edges,
X,
color,
Expand Down Expand Up @@ -166,14 +166,14 @@ void build_sorted_mst(
handle, indptr, indices, pw_dists, (value_idx)m, nnz, color, stream, false, true);

int iters = 1;
int n_components = raft::sparse::neighbors::get_n_components(color, m, stream);
int n_components = cuvs::sparse::neighbors::get_n_components(color, m, stream);

while (n_components > 1 && iters < max_iter) {
connect_knn_graph<value_idx, value_t>(handle, X, mst_coo, m, n, color, reduction_op);

iters++;

n_components = raft::sparse::neighbors::get_n_components(color, m, stream);
n_components = cuvs::sparse::neighbors::get_n_components(color, m, stream);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/cluster/detail/single_linkage.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void single_linkage(raft::resources const& handle,
* 2. Construct MST, sorted by weights
*/
rmm::device_uvector<value_idx> color(m, stream);
raft::sparse::neighbors::FixConnectivitiesRedOp<value_idx, value_t> op(m);
cuvs::sparse::neighbors::FixConnectivitiesRedOp<value_idx, value_t> op(m);
detail::build_sorted_mst<value_idx, value_t>(handle,
X,
indptr.data(),
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/distance/detail/compress_to_bits.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/util/cuda_utils.cuh>
Expand Down Expand Up @@ -100,7 +101,7 @@ void compress_to_bits(raft::resources const& handle,
raft::device_matrix_view<const bool, int, raft::layout_c_contiguous> in,
raft::device_matrix_view<T, int, raft::layout_c_contiguous> out)
{
auto stream = resource::get_cuda_stream(handle);
auto stream = raft::resource::get_cuda_stream(handle);
constexpr int bits_per_element = 8 * sizeof(T);

RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0),
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/detail/masked_distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/
#pragma once
#include "../pairwise_distance_base.cuh"
#include "pairwise_distance_base.cuh"
#include <raft/linalg/contractions.cuh>
#include <raft/util/cuda_utils.cuh>

Expand Down
12 changes: 6 additions & 6 deletions cpp/src/distance/detail/masked_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

#pragma once

#include "../compress_to_bits.cuh"
#include "../fused_distance_nn/fused_l2_nn.cuh"
#include "../masked_distance_base.cuh"
#include "compress_to_bits.cuh"
#include "fused_distance_nn/fused_l2_nn.cuh"
#include "masked_distance_base.cuh"
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/linalg/contractions.cuh>
Expand Down Expand Up @@ -251,14 +251,14 @@ void masked_l2_nn_impl(raft::resources const& handle,
bool sqrt,
bool initOutBuffer)
{
typedef typename linalg::Policy4x4<DataT, 1>::Policy P;
typedef typename raft::linalg::Policy4x4<DataT, 1>::Policy P;

static_assert(P::Mblk == 64, "masked_l2_nn_impl only supports a policy with 64 rows per block.");

// Get stream and workspace memory resource
rmm::mr::device_memory_resource* ws_mr =
dynamic_cast<rmm::mr::device_memory_resource*>(resource::get_workspace_resource(handle));
auto stream = resource::get_cuda_stream(handle);
dynamic_cast<rmm::mr::device_memory_resource*>(raft::resource::get_workspace_resource(handle));
auto stream = raft::resource::get_cuda_stream(handle);

// Acquire temporary buffers and initialize to zero:
// 1) Adjacency matrix bitfield
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY
#ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY

namespace cuvs::distance::detail {

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/detail/pairwise_matrix/dispatch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
#pragma once

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#ifndef CUVS_EXPLICIT_INSTANTIATE_ONLY
#include "dispatch-inl.cuh"
#endif

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/distance-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

#include <rmm/device_uvector.hpp> // rmm::device_uvector

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY
#ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY

namespace cuvs {
namespace distance {
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
#pragma once

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#ifndef CUVS_EXPLICIT_INSTANTIATE_ONLY
#include "distance-inl.cuh"
#endif

Expand Down
201 changes: 201 additions & 0 deletions cpp/src/distance/masked_nn.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __MASKED_L2_NN_H
#define __MASKED_L2_NN_H

#pragma once

#include "detail/masked_nn.cuh"
#include "fused_distance_nn_helpers.cuh"
#include <raft/core/handle.hpp>
#include <raft/util/cuda_utils.cuh>

#include <stdint.h>

#include <limits>

namespace cuvs {
namespace distance {
/**
* \defgroup masked_nn Masked 1-nearest neighbors
* @{
*/

/**
* @brief Parameter struct for masked_l2_nn function
*
* @tparam ReduceOpT Type of reduction operator in the epilogue.
* @tparam KVPReduceOpT Type of Reduction operation on key value pairs.
*
* Usage example:
* @code{.cpp}
* #include <cuvs/distance/masked_nn.cuh>
*
* using IdxT = int;
* using DataT = float;
* using RedOpT = cuvs::distance::MinAndDistanceReduceOp<IdxT, DataT>;
* using PairRedOpT = cuvs::distance::KVPMinReduce<IdxT, DataT>;
* using ParamT = cuvs::distance::masked_l2_nn_params<RedOpT, PairRedOpT>;
*
* bool init_out = true;
* bool sqrt = false;
*
* ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out};
* @endcode
*
* Prescribes how to reduce a distance to an intermediate type (`redOp`), and
* how to reduce two intermediate types (`pairRedOp`). Typically, a distance is
* mapped to an (index, value) pair and (index, value) pair with the lowest
* value (distance) is selected.
*
* In addition, prescribes whether to compute the square root of the distance
* (`sqrt`) and whether to initialize the output buffer (`initOutBuffer`).
*/
template <typename ReduceOpT, typename KVPReduceOpT>
struct masked_l2_nn_params {
/** Reduction operator in the epilogue */
ReduceOpT redOp;
/** Reduction operation on key value pairs */
KVPReduceOpT pairRedOp;
/** Whether the output `minDist` should contain L2-sqrt */
bool sqrt;
/** Whether to initialize the output buffer before the main kernel launch */
bool initOutBuffer;
};

/**
* @brief Masked L2 distance and 1-nearest-neighbor computation in a single call.
*
* This function enables faster computation of nearest neighbors if the
* computation of distances between certain point pairs can be skipped.
*
* We use an adjacency matrix that describes which distances to calculate. The
* points in `y` are divided into groups, and the adjacency matrix indicates
* whether to compute distances between points in `x` and groups in `y`. In other
* words, if `adj[i,k]` is true then distance between point `x_i`, and points in
* `group_k` will be calculated.
*
* **Performance considerations**
*
* The points in `x` are processed in tiles of `M` points (`M` is currently 64,
* but may change in the future). As a result, the largest compute time
* reduction occurs if all `M` points can skip a group. If only part of the `M`
* points can skip a group, then at most a minor compute time reduction and a
* modest energy use reduction can be expected.
*
* The points in `y` are also grouped into tiles of `N` points (`N` is currently
* 64, but may change in the future). As a result, group sizes should be larger
* than `N` to avoid wasting computational resources. If the group sizes are
* evenly divisible by `N`, then the computation is most efficient, although for
* larger group sizes this effect is minor.
*
*
* **Comparison to SDDM**
*
* [SDDMM](https://ieeexplore.ieee.org/document/8638042) (sampled dense-dense
* matrix multiplication) is a matrix-matrix multiplication where only part of
* the output is computed. Compared to masked_l2_nn, there are a few differences:
*
* - The output of masked_l2_nn is a single vector (of nearest neighbors) and not
* a sparse matrix.
*
* - The sampling in masked_l2_nn is expressed through intermediate "groups"
rather than a CSR format.
*
* @tparam DataT data type
* @tparam OutT output type to either store 1-NN indices and their minimum
* distances or store only the min distances. Accordingly, one
* has to pass an appropriate `ReduceOpT`
* @tparam IdxT indexing arithmetic type
* @tparam ReduceOpT A struct to perform the final needed reduction operation
* and also to initialize the output array elements with the
* appropriate initial value needed for reduction.
*
* @param handle RAFT handle for managing expensive resources
* @param params Parameter struct specifying the reduction operations.
* @param[in] x First matrix. Row major. Dim = `m x k`.
* (on device).
* @param[in] y Second matrix. Row major. Dim = `n x k`.
* (on device).
* @param[in] x_norm L2 squared norm of `x`. Length = `m`. (on device).
* @param[in] y_norm L2 squared norm of `y`. Length = `n`. (on device)
* @param[in] adj A boolean adjacency matrix indicating for each
* row of `x` and each group in `y` whether to compute the
* distance. Dim = `m x num_groups`.
* @param[in] group_idxs An array containing the *end* indices of each group
* in `y`. The value of group_idxs[j] indicates the
* start of group j + 1, i.e., it is the inclusive
* scan of the group lengths. The first group is
* always assumed to start at index 0 and the last
* group typically ends at index `n`. Length =
* `num_groups`.
* @param[out] out will contain the reduced output (Length = `m`)
* (on device)
*/
template <typename DataT, typename OutT, typename IdxT, typename ReduceOpT, typename KVPReduceOpT>
void masked_l2_nn(raft::resources const& handle,
cuvs::distance::masked_l2_nn_params<ReduceOpT, KVPReduceOpT> params,
raft::device_matrix_view<const DataT, IdxT, raft::layout_c_contiguous> x,
raft::device_matrix_view<const DataT, IdxT, raft::layout_c_contiguous> y,
raft::device_vector_view<const DataT, IdxT, raft::layout_c_contiguous> x_norm,
raft::device_vector_view<const DataT, IdxT, raft::layout_c_contiguous> y_norm,
raft::device_matrix_view<const bool, IdxT, raft::layout_c_contiguous> adj,
raft::device_vector_view<const IdxT, IdxT, raft::layout_c_contiguous> group_idxs,
raft::device_vector_view<OutT, IdxT, raft::layout_c_contiguous> out)
{
IdxT m = x.extent(0);
IdxT n = y.extent(0);
IdxT k = x.extent(1);
IdxT num_groups = group_idxs.extent(0);

// Match k dimension of x, y
RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal.");
// Match x, x_norm and y, y_norm
RAFT_EXPECTS(m == x_norm.extent(0), "Length of `x_norm` must match input `x`.");
RAFT_EXPECTS(n == y_norm.extent(0), "Length of `y_norm` must match input `y` ");
// Match adj to x and group_idxs
RAFT_EXPECTS(m == adj.extent(0), "#rows in `adj` must match input `x`.");
RAFT_EXPECTS(num_groups == adj.extent(1), "#cols in `adj` must match length of `group_idxs`.");
// NOTE: We do not check if all indices in group_idxs actually points *inside* y.

// If there is no work to be done, return immediately.
if (m == 0 || n == 0 || k == 0 || num_groups == 0) { return; }

detail::masked_l2_nn_impl<DataT, OutT, IdxT, ReduceOpT>(handle,
out.data_handle(),
x.data_handle(),
y.data_handle(),
x_norm.data_handle(),
y_norm.data_handle(),
adj.data_handle(),
group_idxs.data_handle(),
num_groups,
m,
n,
k,
params.redOp,
params.pairRedOp,
params.sqrt,
params.initOutBuffer);
}

/** @} */

} // namespace distance
} // namespace cuvs

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace cuvs::neighbors::cagra::detail {
namespace multi_cta_search {

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY
#ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
#pragma once

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#ifndef CUVS_EXPLICIT_INSTANTIATE_ONLY
#include "search_multi_cta_kernel-inl.cuh"
#endif

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
namespace cuvs::neighbors::cagra::detail {
namespace single_cta_search {

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY
#ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
#pragma once

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#ifndef CUVS_EXPLICIT_INSTANTIATE_ONLY
#include "search_single_cta_kernel-inl.cuh"
#endif

Expand Down
Loading

0 comments on commit 3515f44

Please sign in to comment.