Skip to content

Commit

Permalink
Replace mdarrays with rmm::device_uvector to work around a crash in cuml
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Jul 26, 2022
1 parent e6a815b commit 8b26750
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 64 deletions.
38 changes: 22 additions & 16 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ inline auto extend(const handle_t& handle,
IdxT n_rows,
rmm::cuda_stream_view stream) -> index<T, IdxT>
{
auto n_lists = orig_index.n_lists();
auto dim = orig_index.dim();
auto n_lists = orig_index.n_lists;
auto dim = orig_index.dim;
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::extend(%zu, %u)", size_t(n_rows), dim);

RAFT_EXPECTS(new_indices != nullptr || orig_index.size() == 0,
RAFT_EXPECTS(new_indices != nullptr || orig_index.size == 0,
"You must pass data indices when the index is non-empty.");

rmm::device_uvector<uint32_t> new_labels(n_rows, stream);
Expand All @@ -130,12 +130,12 @@ inline auto extend(const handle_t& handle,
orig_index.metric,
stream);

auto&& list_sizes = make_device_mdarray<uint32_t>(stream, n_lists);
auto&& list_offsets = make_device_mdarray<IdxT>(stream, n_lists + 1);
auto&& list_sizes = rmm::device_uvector<uint32_t>(n_lists, stream);
auto&& list_offsets = rmm::device_uvector<IdxT>(n_lists + 1, stream);
auto list_sizes_ptr = list_sizes.data();
auto list_offsets_ptr = list_offsets.data();

auto&& centers = make_device_mdarray<float>(stream, n_lists, dim);
auto&& centers = rmm::device_uvector<float>(size_t(n_lists) * size_t(dim), stream);
auto centers_ptr = centers.data();

// Calculate the centers and sizes on the new data, starting from the original values
Expand Down Expand Up @@ -164,11 +164,11 @@ inline auto extend(const handle_t& handle,
update_host(&index_size, list_offsets_ptr + n_lists, 1, stream);
handle.sync_stream(stream);

auto&& data = make_device_mdarray<T>(stream, index_size, dim);
auto&& indices = make_device_mdarray<IdxT>(stream, index_size);
auto&& data = rmm::device_uvector<T>(index_size * IdxT(dim), stream);
auto&& indices = rmm::device_uvector<IdxT>(index_size, stream);

// Populate index with the old data
if (orig_index.size() > 0) {
if (orig_index.size > 0) {
utils::block_copy(orig_index.list_offsets.data(),
list_offsets_ptr,
IdxT(n_lists),
Expand Down Expand Up @@ -206,10 +206,10 @@ inline auto extend(const handle_t& handle,

// Precompute the centers vector norms for L2Expanded distance
auto compute_norms = [&]() {
auto&& r = make_device_mdarray<float>(stream, n_lists);
auto&& r = rmm::device_uvector<float>(n_lists, stream);
utils::dots_along_rows(n_lists, dim, centers.data(), r.data(), stream);
RAFT_LOG_TRACE_VEC(r.data(), 20);
return r;
return std::move(r);
};
auto&& center_norms = orig_index.metric == raft::distance::DistanceType::L2Expanded
? std::optional(compute_norms())
Expand All @@ -219,6 +219,9 @@ inline auto extend(const handle_t& handle,
index<T, IdxT> new_index{{},
orig_index.veclen,
orig_index.metric,
index_size,
orig_index.dim,
orig_index.n_lists,
std::move(data),
std::move(indices),
std::move(list_sizes),
Expand Down Expand Up @@ -256,7 +259,7 @@ inline auto build(const handle_t& handle,
auto n_lists = static_cast<uint32_t>(params.n_lists);

// k-means cluster centers for the dataset [n_lists, dim]
auto&& centers = make_device_mdarray<float>(stream, n_lists, dim);
auto&& centers = rmm::device_uvector<float>(size_t(n_lists) * size_t(dim), stream);

// Predict labels of the whole dataset
kmeans::build_optimized_kmeans(handle,
Expand All @@ -270,17 +273,20 @@ inline auto build(const handle_t& handle,
params.metric,
stream);

auto&& data = make_device_mdarray<T>(stream, 0, dim);
auto&& indices = make_device_mdarray<IdxT>(stream, 0);
auto&& list_sizes = make_device_mdarray<uint32_t>(stream, n_lists);
auto&& list_offsets = make_device_mdarray<IdxT>(stream, n_lists + 1);
auto&& data = rmm::device_uvector<T>(0, stream);
auto&& indices = rmm::device_uvector<IdxT>(0, stream);
auto&& list_sizes = rmm::device_uvector<uint32_t>(n_lists, stream);
auto&& list_offsets = rmm::device_uvector<IdxT>(n_lists + 1, stream);
utils::memzero(list_sizes.data(), list_sizes.size(), stream);
utils::memzero(list_offsets.data(), list_offsets.size(), stream);

// assemble the index
index<T, IdxT> index{{},
veclen,
params.metric,
IdxT(0),
dim,
n_lists,
std::move(data),
std::move(indices),
std::move(list_sizes),
Expand Down
34 changes: 17 additions & 17 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ void launch_kernel(Lambda lambda,
interleaved_scan_kernel<Capacity, Veclen, Ascending, T, AccT, IdxT, Lambda>;
const int max_query_smem = 16384;
int query_smem_elems =
std::min<int>(max_query_smem / sizeof(T), Pow2<Veclen * WarpSize>::roundUp(index.dim()));
std::min<int>(max_query_smem / sizeof(T), Pow2<Veclen * WarpSize>::roundUp(index.dim));
int smem_size = query_smem_elems * sizeof(T);
constexpr int kSubwarpSize = std::min<int>(Capacity, WarpSize);
smem_size += raft::spatial::knn::detail::topk::calc_smem_size_for_block_wide<AccT, size_t>(
Expand Down Expand Up @@ -861,10 +861,10 @@ void launch_kernel(Lambda lambda,
index.list_offsets.data(),
n_probes,
k,
index.dim(),
index.dim,
neighbors,
distances);
queries += grid_dim_y * index.dim();
queries += grid_dim_y * index.dim;
neighbors += grid_dim_y * grid_dim_x * k;
distances += grid_dim_y * grid_dim_x * k;
}
Expand Down Expand Up @@ -1072,7 +1072,7 @@ void search_impl(const handle_t& handle,
// The norm of query
rmm::device_uvector<float> query_norm_dev(n_queries, stream, search_mr);
// The distance value of cluster(list) and queries
rmm::device_uvector<float> distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr);
rmm::device_uvector<float> distance_buffer_dev(n_queries * index.n_lists, stream, search_mr);
// The topk distance value of cluster(list) and queries
rmm::device_uvector<float> coarse_distances_dev(n_queries * n_probes, stream, search_mr);
// The topk index of cluster(list) and queries
Expand All @@ -1084,7 +1084,7 @@ void search_impl(const handle_t& handle,

size_t float_query_size;
if constexpr (std::is_integral_v<T>) {
float_query_size = n_queries * index.dim();
float_query_size = n_queries * index.dim;
} else {
float_query_size = 0;
}
Expand All @@ -1095,7 +1095,7 @@ void search_impl(const handle_t& handle,
converted_queries_ptr = const_cast<float*>(queries);
} else {
linalg::unaryOp(
converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping<float>{}, stream);
converted_queries_ptr, queries, n_queries * index.dim, utils::mapping<float>{}, stream);
}

float alpha = 1.0f;
Expand All @@ -1105,11 +1105,11 @@ void search_impl(const handle_t& handle,
alpha = -2.0f;
beta = 1.0f;
utils::dots_along_rows(
n_queries, index.dim(), converted_queries_ptr, query_norm_dev.data(), stream);
n_queries, index.dim, converted_queries_ptr, query_norm_dev.data(), stream);
utils::outer_add(query_norm_dev.data(),
n_queries,
index.center_norms->data(),
index.n_lists(),
index.n_lists,
distance_buffer_dev.data(),
stream);
RAFT_LOG_TRACE_VEC(index.center_norms->data(), 20);
Expand All @@ -1122,25 +1122,25 @@ void search_impl(const handle_t& handle,
linalg::gemm(handle,
true,
false,
index.n_lists(),
index.n_lists,
n_queries,
index.dim(),
index.dim,
&alpha,
index.centers.data(),
index.dim(),
index.dim,
converted_queries_ptr,
index.dim(),
index.dim,
&beta,
distance_buffer_dev.data(),
index.n_lists(),
index.n_lists,
stream);

RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), 20);
if (n_probes <= raft::spatial::knn::detail::topk::kMaxCapacity) {
topk::warp_sort_topk<AccT, uint32_t>(distance_buffer_dev.data(),
nullptr,
n_queries,
index.n_lists(),
index.n_lists,
n_probes,
coarse_distances_dev.data(),
coarse_indices_dev.data(),
Expand All @@ -1151,7 +1151,7 @@ void search_impl(const handle_t& handle,
topk::radix_topk<AccT, uint32_t, 11, 512>(distance_buffer_dev.data(),
nullptr,
n_queries,
index.n_lists(),
index.n_lists,
n_probes,
coarse_distances_dev.data(),
coarse_indices_dev.data(),
Expand Down Expand Up @@ -1249,11 +1249,11 @@ inline void search(const handle_t& handle,
rmm::mr::device_memory_resource* mr = nullptr)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim());
"ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim);

RAFT_EXPECTS(params.n_probes > 0,
"n_probes (number of clusters to probe in the search) must be positive.");
auto n_probes = std::min<uint32_t>(params.n_probes, index.n_lists());
auto n_probes = std::min<uint32_t>(params.n_probes, index.n_lists);

bool select_min;
switch (index.metric) {
Expand Down
49 changes: 18 additions & 31 deletions cpp/include/raft/spatial/knn/ivf_flat_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@

#include "common.hpp"

#include <raft/core/mdarray.hpp>
#include <raft/core/error.hpp>
#include <raft/distance/distance_type.hpp>
#include <raft/integer_utils.h>

#include <rmm/device_uvector.hpp>

#include <optional>

namespace raft::spatial::knn::ivf_flat {
Expand Down Expand Up @@ -62,6 +64,12 @@ struct index : knn::index {
const uint32_t veclen;
/** Distance metric used for clustering. */
const raft::distance::DistanceType metric;
/** Total length of the index. */
const IdxT size;
/** Dimensionality of the data. */
const uint32_t dim;
/** Number of clusters/inverted lists. */
const uint32_t n_lists;

/**
* Inverted list data [size, dim].
Expand All @@ -86,20 +94,20 @@ struct index : knn::index {
* x[16, 4], x[16, 5], x[17, 4], x[17, 5], ... x[30, 4], x[30, 5], - , - ,
*
*/
device_mdarray<T, extent_2d, row_major> data;
rmm::device_uvector<T> data;
/** Inverted list indices: ids of items in the source data [size] */
device_mdarray<IdxT, extent_1d, row_major> indices;
rmm::device_uvector<IdxT> indices;
/** Sizes of the lists (clusters) [n_lists] */
device_mdarray<uint32_t, extent_1d, row_major> list_sizes;
rmm::device_uvector<uint32_t> list_sizes;
/**
* Offsets into the lists [n_lists + 1].
* The last value contains the total length of the index.
*/
device_mdarray<IdxT, extent_1d, row_major> list_offsets;
rmm::device_uvector<IdxT> list_offsets;
/** k-means cluster centers corresponding to the lists [n_lists, dim] */
device_mdarray<float, extent_2d, row_major> centers;
rmm::device_uvector<float> centers;
/** (Optional) Precomputed norms of the `centers` w.r.t. the chosen distance metric [n_lists] */
std::optional<device_mdarray<float, extent_1d, row_major>> center_norms;
std::optional<rmm::device_uvector<float>> center_norms;

// Don't allow copying the index for performance reasons (try avoiding copying data)
index(const index&) = delete;
Expand All @@ -108,33 +116,12 @@ struct index : knn::index {
auto operator=(index&&) -> index& = default;
~index() = default;

/** Total length of the index. */
[[nodiscard]] constexpr inline auto size() const noexcept -> IdxT
{
return static_cast<uint32_t>(data.extent(0));
}
/** Dimensionality of the data. */
[[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t
{
return static_cast<uint32_t>(data.extent(1));
}
/** Number of clusters/inverted lists. */
[[nodiscard]] constexpr inline auto n_lists() const noexcept -> uint32_t
{
return static_cast<uint32_t>(centers.extent(0));
}

/** Throw an error if the index content is inconsistent. */
inline void check_consistency() const
{
RAFT_EXPECTS(dim() % veclen == 0, "dimensionality is not a multiple of the veclen");
RAFT_EXPECTS(data.extent(0) == indices.extent(0), "inconsistent index size");
RAFT_EXPECTS(data.extent(1) == centers.extent(1), "inconsistent data dimensionality");
RAFT_EXPECTS( //
(centers.extent(0) == list_sizes.extent(0)) && //
(centers.extent(0) + 1 == list_offsets.extent(0)) && //
(!center_norms.has_value() || centers.extent(0) == center_norms->extent(0)),
"inconsistent number of lists (clusters)");
RAFT_EXPECTS(dim % veclen == 0, "dimensionality is not a multiple of the veclen");
RAFT_EXPECTS(list_offsets.size() == list_sizes.size() + 1,
"inconsistent number of lists (clusters)");
RAFT_EXPECTS(reinterpret_cast<size_t>(data.data()) % (veclen * sizeof(T)) == 0,
"The data storage pointer is not aligned to the vector length");
}
Expand Down

0 comments on commit 8b26750

Please sign in to comment.