Skip to content

Commit

Permalink
Lazy-initialize the dataset_descriptor to avoid its overheads in the …
Browse files Browse the repository at this point in the history
…persistent kernel
  • Loading branch information
achirkin committed Sep 26, 2024
1 parent b138a07 commit 0bfb6be
Show file tree
Hide file tree
Showing 18 changed files with 155 additions and 130 deletions.
5 changes: 2 additions & 3 deletions cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -496,16 +496,15 @@ using descriptor_instances = instance_selector<
template <typename DataT, typename IndexT, typename DistanceT, typename DatasetT>
auto dataset_descriptor_init(const cagra::search_params& params,
const DatasetT& dataset,
cuvs::distance::DistanceType metric,
rmm::cuda_stream_view stream)
cuvs::distance::DistanceType metric)
-> dataset_descriptor_host<DataT, IndexT, DistanceT>
{
auto [init, priority] =
descriptor_instances::select<DataT, IndexT, DistanceT>(params, dataset, metric);
if (init == nullptr || priority < 0) {
RAFT_FAIL("No dataset descriptor instance compiled for this parameter combination.");
}
return init(params, dataset, metric, stream);
return init(params, dataset, metric);
}

} // namespace cuvs::neighbors::cagra::detail
62 changes: 45 additions & 17 deletions cpp/src/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <functional>
#include <memory>
#include <type_traits>
#include <variant>

namespace cuvs::neighbors::cagra::detail {

Expand Down Expand Up @@ -222,31 +223,61 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t {
* The host struct manages the lifetime of the associated device pointer and a couple parameters
* affecting the search kernel launch config.
*
* [Note: lazy initialization]
* Initialization of the descriptor involves allocating device memory and launching a kernel.
* This can interfere with other workloads (such as the persistent kernel) and generally adds
* overhead. To mitigate this, no CUDA API is called when the host descriptor is constructed;
* instead, the initialization is postponed until the device pointer is first requested.
*
*/
template <typename DataT, typename IndexT, typename DistanceT>
struct dataset_descriptor_host {
using dev_descriptor_t = dataset_descriptor_base_t<DataT, IndexT, DistanceT>;
using dev_descriptor_t = dataset_descriptor_base_t<DataT, IndexT, DistanceT>;
using dd_ptr_t = std::shared_ptr<dev_descriptor_t>;
using init_f =
std::tuple<std::function<void(dev_descriptor_t*, rmm::cuda_stream_view stream)>, size_t>;
uint32_t smem_ws_size_in_bytes = 0;
uint32_t team_size = 0;

template <typename DescriptorImpl>
dataset_descriptor_host(const DescriptorImpl& dd_host, rmm::cuda_stream_view stream)
: dev_ptr_{[stream]() {
dev_descriptor_t* p;
RAFT_CUDA_TRY(cudaMallocAsync(&p, sizeof(DescriptorImpl), stream));
return p;
}(),
[stream](dev_descriptor_t* p) { RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(p, stream)); }},
template <typename DescriptorImpl, typename InitF>
dataset_descriptor_host(const DescriptorImpl& dd_host, InitF init)
: value_{std::make_tuple(init, sizeof(DescriptorImpl))},
smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()},
team_size{dd_host.team_size()}
{
}

[[nodiscard]] auto dev_ptr() const -> const dev_descriptor_t* { return dev_ptr_.get(); }
[[nodiscard]] auto dev_ptr() -> dev_descriptor_t* { return dev_ptr_.get(); }
/**
 * Get the device-side descriptor pointer (const access).
 *
 * If the descriptor is still held in its deferred form (an `init_f`), it is evaluated
 * right here via `eval` using the passed stream, and the result replaces the deferred
 * form in `value_`; later calls return the stored pointer directly.
 *
 * NOTE(review): `value_` is `mutable` and is reassigned here without any locking visible
 * in this struct; presumably concurrent first-time calls must be avoided - confirm.
 *
 * @param stream CUDA stream forwarded to `eval` when the initialization is triggered.
 * @return pointer to the device descriptor.
 */
[[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) const -> const dev_descriptor_t*
{
  if (auto* pending = std::get_if<init_f>(&value_); pending != nullptr) {
    // `eval` takes the initializer by value, so it is copied before `value_` is overwritten.
    value_ = eval(*pending, stream);
  }
  return std::get<dd_ptr_t>(value_).get();
}
/**
 * Get the device-side descriptor pointer (mutable access).
 *
 * If the descriptor is still held in its deferred form (an `init_f`), it is evaluated
 * via `eval` using the passed stream and the result is stored in `value_` before the
 * pointer is returned.
 *
 * @param stream CUDA stream forwarded to `eval` when the initialization is triggered.
 * @return pointer to the device descriptor.
 */
[[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) -> dev_descriptor_t*
{
  if (auto* pending = std::get_if<init_f>(&value_); pending != nullptr) {
    // `eval` takes the initializer by value, so it is copied before `value_` is overwritten.
    value_ = eval(*pending, stream);
  }
  return std::get<dd_ptr_t>(value_).get();
}

private:
std::unique_ptr<dev_descriptor_t, std::function<void(dev_descriptor_t*)>> dev_ptr_;
mutable std::variant<dd_ptr_t, init_f> value_;

/**
 * Materialize the device descriptor from its deferred initializer.
 *
 * Allocates the number of bytes stored in `init` with `cudaMallocAsync` on `stream`,
 * wraps the allocation in a shared pointer whose deleter calls `cudaFreeAsync` on that
 * same stream, and then invokes the stored initializer on the new pointer.
 *
 * NOTE(review): the deleter keeps a copy of `stream`; presumably the stream has to stay
 * valid until the last owner of the returned pointer is gone - confirm.
 *
 * @param init   tuple of (initializer callable, allocation size in bytes).
 * @param stream CUDA stream used for the allocation, the initializer call and the free.
 * @return owning pointer to the initialized device descriptor.
 */
static auto eval(init_f init, rmm::cuda_stream_view stream) -> dd_ptr_t
{
  using raft::RAFT_NAME;
  auto& initializer    = std::get<0>(init);
  const size_t n_bytes = std::get<1>(init);

  // Stream-ordered allocation of the raw device storage.
  dev_descriptor_t* raw = nullptr;
  RAFT_CUDA_TRY(cudaMallocAsync(&raw, n_bytes, stream));

  // Ownership is handed to the shared pointer before the initializer runs, so the
  // storage is released through the deleter if the initializer throws.
  auto release = [stream](dev_descriptor_t* p) {
    RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(p, stream));
  };
  dd_ptr_t result{raw, release};

  initializer(result.get(), stream);
  return result;
}
};

/**
Expand All @@ -257,11 +288,8 @@ struct dataset_descriptor_host {
*
*/
template <typename DataT, typename IndexT, typename DistanceT, typename DatasetT>
using init_desc_type =
dataset_descriptor_host<DataT, IndexT, DistanceT> (*)(const cagra::search_params&,
const DatasetT&,
cuvs::distance::DistanceType,
rmm::cuda_stream_view);
using init_desc_type = dataset_descriptor_host<DataT, IndexT, DistanceT> (*)(
const cagra::search_params&, const DatasetT&, cuvs::distance::DistanceType);

/**
* @brief Descriptor instance specification.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,14 @@
template <typename DataT, typename IndexT, typename DistanceT, typename DatasetT>
auto dataset_descriptor_init(const cagra::search_params& params,
const DatasetT& dataset,
cuvs::distance::DistanceType metric,
rmm::cuda_stream_view stream)
cuvs::distance::DistanceType metric)
-> dataset_descriptor_host<DataT, IndexT, DistanceT>
{{
auto [init, priority] = descriptor_instances::select<DataT, IndexT, DistanceT>(params, dataset, metric);
if (init == nullptr || priority < 0) {{
RAFT_FAIL("No dataset descriptor instance compiled for this parameter combination.");
}}
return init(params, dataset, metric, stream);
return init(params, dataset, metric);
}}
'''
f.write(template.format(includes=includes, content=contents))
Expand Down
30 changes: 13 additions & 17 deletions cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -252,28 +252,24 @@ template <cuvs::distance::DistanceType Metric,
typename DistanceT>
dataset_descriptor_host<DataT, IndexT, DistanceT>
standard_descriptor_spec<Metric, TeamSize, DatasetBlockDim, DataT, IndexT, DistanceT>::init_(
const cagra::search_params& params,
const DataT* ptr,
IndexT size,
uint32_t dim,
uint32_t ld,
rmm::cuda_stream_view stream)
const cagra::search_params& params, const DataT* ptr, IndexT size, uint32_t dim, uint32_t ld)
{
using desc_type =
standard_dataset_descriptor_t<Metric, TeamSize, DatasetBlockDim, DataT, IndexT, DistanceT>;
using base_type = typename desc_type::base_type;
desc_type dd_host{nullptr, nullptr, ptr, size, dim, ld};
host_type result{dd_host, stream};

standard_dataset_descriptor_init_kernel<Metric,
TeamSize,
DatasetBlockDim,
DataT,
IndexT,
DistanceT>
<<<1, 1, 0, stream>>>(result.dev_ptr(), ptr, size, dim, desc_type::ld(dd_host.args));
RAFT_CUDA_TRY(cudaPeekAtLastError());
return result;
return host_type{dd_host,
[=](dataset_descriptor_base_t<DataT, IndexT, DistanceT>* dev_ptr,
rmm::cuda_stream_view stream) {
standard_dataset_descriptor_init_kernel<Metric,
TeamSize,
DatasetBlockDim,
DataT,
IndexT,
DistanceT>
<<<1, 1, 0, stream>>>(dev_ptr, ptr, size, dim, ld);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}};
}

} // namespace cuvs::neighbors::cagra::detail
14 changes: 4 additions & 10 deletions cpp/src/neighbors/detail/cagra/compute_distance_standard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,13 @@ struct standard_descriptor_spec : public instance_spec<DataT, IndexT, DistanceT>
template <typename DatasetT>
static auto init(const cagra::search_params& params,
const DatasetT& dataset,
cuvs::distance::DistanceType metric,
rmm::cuda_stream_view stream) -> host_type
cuvs::distance::DistanceType metric) -> host_type
{
return init_(params,
dataset.view().data_handle(),
IndexT(dataset.n_rows()),
dataset.dim(),
dataset.stride(),
stream);
dataset.stride());
}

template <typename DatasetT>
Expand All @@ -69,12 +67,8 @@ struct standard_descriptor_spec : public instance_spec<DataT, IndexT, DistanceT>
}

private:
static dataset_descriptor_host<DataT, IndexT, DistanceT> init_(const cagra::search_params& params,
const DataT* ptr,
IndexT size,
uint32_t dim,
uint32_t ld,
rmm::cuda_stream_view stream);
static dataset_descriptor_host<DataT, IndexT, DistanceT> init_(
const cagra::search_params& params, const DataT* ptr, IndexT size, uint32_t dim, uint32_t ld);
};

} // namespace cuvs::neighbors::cagra::detail
42 changes: 22 additions & 20 deletions cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,7 @@ vpq_descriptor_spec<Metric,
const CodebookT* vq_code_book_ptr,
const CodebookT* pq_code_book_ptr,
IndexT size,
uint32_t dim,
rmm::cuda_stream_view stream)
uint32_t dim)
{
using desc_type = cagra_q_dataset_descriptor_t<Metric,
TeamSize,
Expand All @@ -443,24 +442,27 @@ vpq_descriptor_spec<Metric,
pq_code_book_ptr,
size,
dim};
host_type result{dd_host, stream};
vpq_dataset_descriptor_init_kernel<Metric,
TeamSize,
DatasetBlockDim,
PqBits,
PqLen,
CodebookT,
DataT,
IndexT,
DistanceT><<<1, 1, 0, stream>>>(result.dev_ptr(),
encoded_dataset_ptr,
encoded_dataset_dim,
vq_code_book_ptr,
pq_code_book_ptr,
size,
dim);
RAFT_CUDA_TRY(cudaPeekAtLastError());
return result;
return host_type{dd_host,
[=](dataset_descriptor_base_t<DataT, IndexT, DistanceT>* dev_ptr,
rmm::cuda_stream_view stream) {
vpq_dataset_descriptor_init_kernel<Metric,
TeamSize,
DatasetBlockDim,
PqBits,
PqLen,
CodebookT,
DataT,
IndexT,
DistanceT>
<<<1, 1, 0, stream>>>(dev_ptr,
encoded_dataset_ptr,
encoded_dataset_dim,
vq_code_book_ptr,
pq_code_book_ptr,
size,
dim);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}};
}

} // namespace cuvs::neighbors::cagra::detail
9 changes: 3 additions & 6 deletions cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,15 @@ struct vpq_descriptor_spec : public instance_spec<DataT, IndexT, DistanceT> {
template <typename DatasetT>
static auto init(const cagra::search_params& params,
const DatasetT& dataset,
cuvs::distance::DistanceType metric,
rmm::cuda_stream_view stream) -> host_type
cuvs::distance::DistanceType metric) -> host_type
{
return init_(params,
dataset.data.data_handle(),
dataset.encoded_row_length(),
dataset.vq_code_book.data_handle(),
dataset.pq_code_book.data_handle(),
IndexT(dataset.n_rows()),
dataset.dim(),
stream);
dataset.dim());
}

template <typename DatasetT>
Expand All @@ -93,8 +91,7 @@ struct vpq_descriptor_spec : public instance_spec<DataT, IndexT, DistanceT> {
const CodebookT* vq_code_book_ptr,
const CodebookT* pq_code_book_ptr,
IndexT size,
uint32_t dim,
rmm::cuda_stream_view stream);
uint32_t dim);
};

} // namespace cuvs::neighbors::cagra::detail
4 changes: 2 additions & 2 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ auto dataset_descriptor_init_with_cache(const raft::resources& res,
->value;
std::shared_ptr<desc_t> desc{nullptr};
if (!cache.get(key, &desc)) {
desc = std::make_shared<desc_t>(std::move(dataset_descriptor_init<DataT, IndexT, DistanceT>(
params, dataset, metric, raft::resource::get_cuda_stream(res))));
desc = std::make_shared<desc_t>(
std::move(dataset_descriptor_init<DataT, IndexT, DistanceT>(params, dataset, metric)));
cache.set(key, desc);
}
return *desc;
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
SAMPLE_FILTER_T sample_filter)
{
cudaStream_t stream = raft::resource::get_cuda_stream(res);
select_and_run(dataset_desc.dev_ptr(),
select_and_run(dataset_desc,
graph,
intermediate_indices.data(),
intermediate_distances.data(),
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search {

#define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \
template void select_and_run<DataT, IndexT, DistanceT, SampleFilterT>( \
const dataset_descriptor_base_t<DataT, IndexT, DistanceT>* dataset_desc, \
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc, \
raft::device_matrix_view<const IndexT, int64_t, raft::row_major> graph, \
IndexT* topk_indices_ptr, \
DistanceT* topk_distances_ptr, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ struct search_kernel_config {
};

template <typename DataT, typename IndexT, typename DistanceT, typename SampleFilterT>
void select_and_run(const dataset_descriptor_base_t<DataT, IndexT, DistanceT>* dataset_desc,
void select_and_run(const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
raft::device_matrix_view<const IndexT, int64_t, raft::row_major> graph,
IndexT* topk_indices_ptr, // [num_queries, topk]
DistanceT* topk_distances_ptr, // [num_queries, topk]
Expand Down Expand Up @@ -455,7 +455,7 @@ void select_and_run(const dataset_descriptor_base_t<DataT, IndexT, DistanceT>* d

kernel<<<grid_dims, block_dims, smem_size, stream>>>(topk_indices_ptr,
topk_distances_ptr,
dataset_desc,
dataset_desc.dev_ptr(stream),
queries_ptr,
graph.data_handle(),
graph.extent(1),
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
namespace cuvs::neighbors::cagra::detail::multi_cta_search {

template <typename DataT, typename IndexT, typename DistanceT, typename SampleFilterT>
void select_and_run(const dataset_descriptor_base_t<DataT, IndexT, DistanceT>* dataset_desc,
void select_and_run(const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
raft::device_matrix_view<const IndexT, int64_t, raft::row_major> graph,
IndexT* topk_indices_ptr, // [num_queries, topk]
DistanceT* topk_distances_ptr, // [num_queries, topk]
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void random_pickup(const dataset_descriptor_host<DataT, IndexT, DistanceT>& data
num_queries);

random_pickup_kernel<<<grid_size, block_size, dataset_desc.smem_ws_size_in_bytes, cuda_stream>>>(
dataset_desc.dev_ptr(),
dataset_desc.dev_ptr(cuda_stream),
queries_ptr,
num_pickup,
num_distilation,
Expand Down Expand Up @@ -410,7 +410,7 @@ void compute_distance_to_child_nodes(
parent_distance_ptr,
lds,
search_width,
dataset_desc.dev_ptr(),
dataset_desc.dev_ptr(cuda_stream),
neighbor_graph_ptr,
graph_degree,
query_ptr,
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ struct search_plan_impl : public search_plan_impl_base {
const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
std::uint32_t* const num_executed_iterations, // [num_queries]
uint32_t topk,
SAMPLE_FILTER_T sample_filter) {};
SAMPLE_FILTER_T sample_filter){};

void adjust_search_params()
{
Expand Down
4 changes: 1 addition & 3 deletions cpp/src/neighbors/detail/cagra/search_single_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
#include <raft/util/cuda_rt_essentials.hpp>
#include <raft/util/cudart_utils.hpp> // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp

#include <rmm/device_uvector.hpp>

#include <algorithm>
#include <cassert>
#include <iostream>
Expand Down Expand Up @@ -218,7 +216,7 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T> {
SAMPLE_FILTER_T sample_filter)
{
cudaStream_t stream = raft::resource::get_cuda_stream(res);
select_and_run(dataset_desc.dev_ptr(),
select_and_run(dataset_desc,
graph,
result_indices_ptr,
result_distances_ptr,
Expand Down
Loading

0 comments on commit 0bfb6be

Please sign in to comment.