Commit

fix doc bug
mdoijade committed Jan 31, 2024
1 parent 6d69589 commit 904fdd6
Showing 3 changed files with 19 additions and 33 deletions.
16 changes: 3 additions & 13 deletions cpp/bench/ann/CMakeLists.txt
@@ -225,9 +225,7 @@ endfunction()
 
 if(RAFT_ANN_BENCH_USE_HNSWLIB)
   ConfigureAnnBench(
-    NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp
-    LINKS
-    hnswlib::hnswlib
+    NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp LINKS hnswlib::hnswlib
   )
 
 endif()
@@ -276,12 +274,7 @@ endif()
 
 if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
   ConfigureAnnBench(
-    NAME
-    RAFT_CAGRA_HNSWLIB
-    PATH
-    bench/ann/src/raft/raft_cagra_hnswlib.cu
-    LINKS
-    raft::compiled
+    NAME RAFT_CAGRA_HNSWLIB PATH bench/ann/src/raft/raft_cagra_hnswlib.cu LINKS raft::compiled
     hnswlib::hnswlib
   )
 endif()
@@ -336,10 +329,7 @@ endif()
 
 if(RAFT_ANN_BENCH_USE_GGNN)
   include(cmake/thirdparty/get_glog.cmake)
-  ConfigureAnnBench(
-    NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu
-    LINKS glog::glog ggnn::ggnn
-  )
+  ConfigureAnnBench(NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu LINKS glog::glog ggnn::ggnn)
 endif()
 
 # ##################################################################################################
@@ -138,15 +138,6 @@ void cutlassFusedDistanceNN(const DataT* x,
   int totalTiles = columnTiles * rowTiles;
   int thread_blocks =
     rowTiles < full_wave ? (totalTiles < full_wave ? totalTiles : full_wave) : rowTiles;
-  printf(
-    "totalTiles = %d full_wave = %d thread_blocks = %d rowTiles = %d mmaShapeM = %d mmaShapeN = "
-    "%d\n",
-    totalTiles,
-    full_wave,
-    thread_blocks,
-    rowTiles,
-    mmaShapeM,
-    mmaShapeN);
 
   typename fusedDistanceNN::Arguments arguments{
     problem_size,
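The printf removed above was debug output for the launch-grid sizing kept in the context lines: the kernel launches at most one full wave of thread blocks unless a single row of tiles already exceeds that. A small standalone illustration of that heuristic follows; the tile counts and the wrapper main() are made up for this note and are not part of the commit.

// Illustration only: the grid-sizing expression from the context lines above,
// evaluated with made-up tile counts. thread_blocks is capped at one full wave
// of the GPU unless a single row of tiles is already larger than that.
#include <cstdio>

int main()
{
  int rowTiles    = 8;                       // hypothetical number of tile rows
  int columnTiles = 96;                      // hypothetical number of tile columns
  int full_wave   = 432;                     // hypothetical blocks in one full GPU wave
  int totalTiles  = columnTiles * rowTiles;  // 768 tiles in total

  int thread_blocks =
    rowTiles < full_wave ? (totalTiles < full_wave ? totalTiles : full_wave) : rowTiles;

  std::printf("thread_blocks = %d\n", thread_blocks);  // prints 432: capped at one wave
  return 0;
}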
27 changes: 16 additions & 11 deletions cpp/include/raft/distance/fused_distance_nn-inl.cuh
@@ -44,14 +44,15 @@ namespace distance {
  * traffic on this intermediate buffer, otherwise needed during the reduction
  * phase for 1-NN.
  *
- * @tparam DataT data type
- * @tparam OutT output type to either store 1-NN indices and their minimum
- * distances or store only the min distances. Accordingly, one
- * has to pass an appropriate `ReduceOpT`
- * @tparam IdxT indexing arithmetic type
- * @tparam ReduceOpT A struct to perform the final needed reduction operation
- * and also to initialize the output array elements with the
- * appropriate initial value needed for reduction.
+ * @tparam DataT data type
+ * @tparam OutT output type to either store 1-NN indices and their minimum
+ * distances or store only the min distances. Accordingly, one
+ * has to pass an appropriate `ReduceOpT`
+ * @tparam IdxT indexing arithmetic type
+ * @tparam ReduceOpT A struct to perform the final needed reduction operation
+ * and also to initialize the output array elements with the
+ * appropriate initial value needed for reduction.
+ * @tparam KVPReduceOpT A struct providing functions for key-value pair comparison.
  *
  * @param[out] min will contain the reduced output (Length = `m`)
  * (on device)
@@ -66,10 +67,13 @@ namespace distance {
  * @param[in] k gemm k
  * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device)
  * @param[in] redOp reduction operator in the epilogue
- * @param[in] pairRedOp reduction operation on key value pairs
+ * @param[in] pairRedOp reduction operation on key value pairs
  * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt
  * @param[in] initOutBuffer whether to initialize the output buffer before the
  * main kernel launch
+ * @param[in] isRowMajor whether the input/output is row or column major.
+ * @param[in] metric Distance metric to be used (supports L2, cosine)
+ * @param[in] metric_arg power argument for distances like Minkowski (not supported for now)
  * @param[in] stream cuda stream
  */
 template <typename DataT, typename OutT, typename IdxT, typename ReduceOpT, typename KVPReduceOpT>
@@ -248,8 +252,6 @@ void fusedDistanceNN(OutT* min,
  *
  * fusedDistanceNN cannot be compiled in the distance library due to the lambda
  * operators, so this wrapper covers the most common case (minimum).
- * This should be preferred to the more generic API when possible, in order to
- * reduce compilation times for users of the shared library.
  *
  * @tparam DataT data type
  * @tparam OutT output type to either store 1-NN indices and their minimum
@@ -271,6 +273,9 @@ void fusedDistanceNN(OutT* min,
  * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt
  * @param[in] initOutBuffer whether to initialize the output buffer before the
  * main kernel launch
+ * @param[in] isRowMajor whether the input/output is row or column major.
+ * @param[in] metric Distance metric to be used (supports L2, cosine)
+ * @param[in] metric_arg power argument for distances like Minkowski (not supported for now)
  * @param[in] stream cuda stream
  */
 template <typename DataT, typename OutT, typename IdxT>
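For reference, a minimal sketch of calling the min-reduce wrapper whose documentation this commit extends. The wrapper name fusedDistanceNNMinReduce, the public header paths, the raft::KeyValuePair output type, and the DistanceType enum value are assumptions based on RAFT conventions and are not shown in this diff; the parameter order and the workspace size follow the @param documentation above.

// Minimal usage sketch (assumptions noted above, not taken from this commit):
// compute the 1-nearest neighbor in `y` for every row of `x` using the fused
// distance + min-reduction path documented in fused_distance_nn-inl.cuh.
#include <raft/core/kvp.hpp>                    // raft::KeyValuePair (assumed header)
#include <raft/distance/distance_types.hpp>     // raft::distance::DistanceType (assumed header)
#include <raft/distance/fused_distance_nn.cuh>  // assumed public wrapper of fused_distance_nn-inl.cuh

#include <cuda_runtime.h>

// x: m x k queries, y: n x k index points, xn/yn: precomputed squared L2 row norms.
// All pointers are device pointers; `min` receives one (index, distance) pair per query row.
void nearest_neighbor_1nn(raft::KeyValuePair<int, float>* min,
                          const float* x,
                          const float* y,
                          const float* xn,
                          const float* yn,
                          int m,
                          int n,
                          int k,
                          cudaStream_t stream)
{
  // Temp workspace of sizeof(int) * m bytes, as stated in the @param[in] workspace doc.
  void* workspace = nullptr;
  cudaMallocAsync(&workspace, sizeof(int) * m, stream);

  raft::distance::fusedDistanceNNMinReduce<float, raft::KeyValuePair<int, float>, int>(
    min, x, y, xn, yn, m, n, k, workspace,
    /*sqrt=*/false,                             // keep squared L2 in the output distances
    /*initOutBuffer=*/true,                     // initialize `min` before the main kernel launch
    /*isRowMajor=*/true,                        // inputs/outputs are row major
    raft::distance::DistanceType::L2Expanded,   // metric: L2 (cosine is also supported)
    /*metric_arg=*/2.0f,                        // Minkowski power argument; unused for L2
    stream);

  cudaFreeAsync(workspace, stream);
}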
