Skip to content

Commit

Permalink
Merge pull request rapidsai#1858 from rapidsai/branch-23.10
Browse files Browse the repository at this point in the history
Forward-merge branch-23.10 to branch-23.12
  • Loading branch information
GPUtester authored Sep 27, 2023
2 parents ba1460c + 25858c5 commit 2f08bdd
Show file tree
Hide file tree
Showing 13 changed files with 446 additions and 52 deletions.
32 changes: 20 additions & 12 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ include(rapids-find)

option(BUILD_CPU_ONLY "Build CPU only components. Applies to RAFT ANN benchmarks currently" OFF)

# workaround for rapids_cuda_init_architectures not working for arch detection with enable_language(CUDA)
# workaround for rapids_cuda_init_architectures not working for arch detection with
# enable_language(CUDA)
set(lang_list "CXX")

if(NOT BUILD_CPU_ONLY)
Expand Down Expand Up @@ -286,7 +287,8 @@ endif()
set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled)

if(RAFT_COMPILE_LIBRARY)
add_library(raft_objs OBJECT
add_library(
raft_objs OBJECT
src/core/logger.cpp
src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu
Expand Down Expand Up @@ -331,6 +333,7 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu
src/neighbors/brute_force_knn_int_float_int.cu
src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu
src/neighbors/brute_force_knn_index_float.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu
Expand Down Expand Up @@ -452,18 +455,21 @@ if(RAFT_COMPILE_LIBRARY)
src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu
src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu
src/util/memory_pool.cpp
)
)
set_target_properties(
raft_objs
PROPERTIES CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD 17
CUDA_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON)
POSITION_INDEPENDENT_CODE ON
)

target_compile_definitions(raft_objs PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY")
target_compile_options(raft_objs PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
"$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>")
target_compile_options(
raft_objs PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
"$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>"
)

add_library(raft_lib SHARED $<TARGET_OBJECTS:raft_objs>)
add_library(raft_lib_static STATIC $<TARGET_OBJECTS:raft_objs>)
Expand All @@ -477,13 +483,15 @@ if(RAFT_COMPILE_LIBRARY)
)

foreach(target raft_lib raft_lib_static raft_objs)
target_link_libraries(${target} PUBLIC
raft::raft
${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this
# will just be cublas
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>)
target_link_libraries(
${target}
PUBLIC raft::raft
${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this
# will just be cublas
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
)

#So consumers know when using libraft.so/libraft.a
# So consumers know when using libraft.so/libraft.a
target_compile_definitions(${target} PUBLIC "RAFT_COMPILED")
# ensure CUDA symbols aren't relocated to the middle of the debug build binaries
target_link_options(${target} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld")
Expand Down
39 changes: 38 additions & 1 deletion cpp/include/raft/neighbors/brute_force-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/core/resources.hpp> // raft::resources
#include <raft/distance/distance_types.hpp> // raft::distance::DistanceType
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <raft/neighbors/brute_force_types.hpp>
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

Expand All @@ -38,6 +39,19 @@ inline void knn_merge_parts(
size_t n_samples,
std::optional<raft::device_vector_view<idx_t, idx_t>> translations = std::nullopt) RAFT_EXPLICIT;

template <typename T, typename Accessor>
index<T> build(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
T metric_arg = 0.0) RAFT_EXPLICIT;

template <typename T, typename IdxT>
void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances) RAFT_EXPLICIT;

template <typename idx_t,
typename value_t,
typename matrix_idx,
Expand Down Expand Up @@ -93,6 +107,29 @@ instantiate_raft_neighbors_brute_force_knn(

#undef instantiate_raft_neighbors_brute_force_knn

namespace raft::neighbors::brute_force {

extern template void search<float, int>(
raft::resources const& res,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

extern template void search<float, int64_t>(
raft::resources const& res,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
raft::device_matrix_view<const float, int64_t, row_major> dataset,
raft::distance::DistanceType metric,
float metric_arg);
} // namespace raft::neighbors::brute_force

#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \
value_t, idx_t, idx_layout, query_layout) \
extern template void raft::neighbors::brute_force::fused_l2_knn( \
Expand Down
98 changes: 97 additions & 1 deletion cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/brute_force_types.hpp>
#include <raft/neighbors/detail/knn_brute_force.cuh>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>

Expand Down Expand Up @@ -280,6 +281,101 @@ void fused_l2_knn(raft::resources const& handle,
metric);
}

/** @} */ // end group brute_force_knn
/**
* @brief Build the index from the dataset for efficient search.
*
* @tparam T data element type
*
* @param[in] res
* @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
* @param[in] metric: distance metric to use. Euclidean (L2) is used by default
* @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This
* is ignored if the metric_type is not Minkowski.
*
* @return the constructed brute force index
*/
template <typename T, typename Accessor>
index<T> build(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
T metric_arg = 0.0)
{
// certain distance metrics can benefit by pre-calculating the norms for the index dataset
// which lets us avoid calculating these at query time
std::optional<device_vector<T, int64_t>> norms;
if (metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::CosineExpanded) {
norms = make_device_vector<T, int64_t>(res, dataset.extent(0));
// cosine needs the l2norm, where as l2 distances needs the squared norm
if (metric == raft::distance::DistanceType::CosineExpanded) {
raft::linalg::norm(res,
dataset,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS,
raft::sqrt_op{});
} else {
raft::linalg::norm(res,
dataset,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS);
}
}

return index<T>(res, dataset, std::move(norms), metric, metric_arg);
}

/**
* @brief Brute Force search using the constructed index.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] idx brute force index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances)
{
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs");
RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1),
"Number of columns in queries must match brute force index");

auto k = neighbors.extent(1);
auto d = idx.dataset().extent(1);

std::vector<T*> dataset = {const_cast<T*>(idx.dataset().data_handle())};
std::vector<int64_t> sizes = {idx.dataset().extent(0)};
std::vector<T*> norms;
if (idx.has_norms()) { norms.push_back(const_cast<T*>(idx.norms().data_handle())); }

detail::brute_force_knn_impl<int64_t, IdxT, T>(res,
dataset,
sizes,
d,
const_cast<T*>(queries.data_handle()),
queries.extent(0),
neighbors.data_handle(),
distances.data_handle(),
k,
true,
true,
nullptr,
idx.metric(),
idx.metric_arg(),
raft::identity_op(),
norms.size() ? &norms : nullptr);
}
/** @} */ // end group brute_force_knn
} // namespace raft::neighbors::brute_force
144 changes: 144 additions & 0 deletions cpp/include/raft/neighbors/brute_force_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "ann_types.hpp"
#include <raft/core/resource/cuda_stream.hpp>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>

#include <raft/core/logger.hpp>

namespace raft::neighbors::brute_force {
/**
* @addtogroup brute_force
* @{
*/

/**
* @brief Brute Force index.
*
* The index stores the dataset and norms for the dataset in device memory.
*
* @tparam T data element type
*/
template <typename T>
struct index : ann::index {
public:
/** Distance metric used for retrieval */
[[nodiscard]] constexpr inline raft::distance::DistanceType metric() const noexcept
{
return metric_;
}

/** Total length of the index (number of vectors). */
[[nodiscard]] constexpr inline int64_t size() const noexcept { return dataset_view_.extent(0); }

/** Dimensionality of the data. */
[[nodiscard]] constexpr inline uint32_t dim() const noexcept { return dataset_view_.extent(1); }

/** Dataset [size, dim] */
[[nodiscard]] inline auto dataset() const noexcept
-> device_matrix_view<const T, int64_t, row_major>
{
return dataset_view_;
}

/** Dataset norms */
[[nodiscard]] inline auto norms() const -> device_vector_view<const T, int64_t, row_major>
{
return make_const_mdspan(norms_.value().view());
}

/** Whether ot not this index has dataset norms */
[[nodiscard]] inline bool has_norms() const noexcept { return norms_.has_value(); }

[[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; }

// Don't allow copying the index for performance reasons (try avoiding copying data)
index(const index&) = delete;
index(index&&) = default;
auto operator=(const index&) -> index& = delete;
auto operator=(index&&) -> index& = default;
~index() = default;

/** Construct a brute force index from dataset
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
* the dataset, providing a speed benefit over doing this at query time.
* If the dataset is already in GPU memory, then this class stores a non-owning reference to
* the dataset. If the dataset is in host memory, it will be copied to the device and the
* index will own the device memory.
*/
template <typename data_accessor>
index(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset,
std::optional<raft::device_vector<T, int64_t>>&& norms,
raft::distance::DistanceType metric,
T metric_arg = 0.0)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
norms_(std::move(norms)),
metric_arg_(metric_arg)
{
update_dataset(res, dataset);
resource::sync_stream(res);
}

private:
/**
* Replace the dataset with a new dataset.
*/
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
dataset_view_ = dataset;
}

/**
* Replace the dataset with a new dataset.
*
* We create a copy of the dataset on the device. The index manages the lifetime of this copy.
*/
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
{
dataset_ = make_device_matrix<T, int64_t>(dataset.extents(0), dataset.extents(1));
raft::copy(dataset_.data_handle(),
dataset.data_handle(),
dataset.size(),
resource::get_cuda_stream(res));
dataset_view_ = make_const_mdspan(dataset_.view());
}

raft::distance::DistanceType metric_;
raft::device_matrix<T, int64_t, row_major> dataset_;
std::optional<raft::device_vector<T, int64_t>> norms_;
raft::device_matrix_view<const T, int64_t, row_major> dataset_view_;
T metric_arg_;
};

/** @} */

} // namespace raft::neighbors::brute_force
Loading

0 comments on commit 2f08bdd

Please sign in to comment.