Skip to content

Commit

Permalink
CAGRA pad dataset for 128bit vectorized load (#1505)
Browse files Browse the repository at this point in the history
This PR adds padding to the dataset (if necessary) to make reading any of its rows compatible with 128bit vectorized loads. This change also enables handling arbitrary number of input features (before this PR each row had to be at least 64bit aligned, which constrained the acceptable number of input features).

Fixes #1458.

With this change, it is sufficient to keep a single "load type" specialization for the search kernels, which shall cut the binary size by half (#1459).

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - tsuki (https://github.com/enp1s0)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1505
  • Loading branch information
tfeher authored Jun 9, 2023
1 parent 567dfd7 commit 6ec78e9
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 170 deletions.
22 changes: 14 additions & 8 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -237,21 +237,28 @@ index<T, IdxT> build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<IdxT>, row_major, Accessor> dataset)
{
size_t degree = params.intermediate_graph_degree;
if (degree >= static_cast<size_t>(dataset.extent(0))) {
size_t intermediate_degree = params.intermediate_graph_degree;
size_t graph_degree = params.graph_degree;
if (intermediate_degree >= static_cast<size_t>(dataset.extent(0))) {
RAFT_LOG_WARN(
"Intermediate graph degree cannot be larger than dataset size, reducing it to %lu",
dataset.extent(0));
degree = dataset.extent(0) - 1;
intermediate_degree = dataset.extent(0) - 1;
}
if (intermediate_degree < graph_degree) {
RAFT_LOG_WARN(
"Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing "
"graph_degree.",
graph_degree,
intermediate_degree);
graph_degree = intermediate_degree;
}
RAFT_EXPECTS(degree >= params.graph_degree,
"Intermediate graph degree cannot be smaller than final graph degree");

auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), degree);
auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), intermediate_degree);

build_knn_graph(res, dataset, knn_graph.view());

auto cagra_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), params.graph_degree);
auto cagra_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), graph_degree);

prune<IdxT>(res, knn_graph.view(), cagra_graph.view());

Expand Down Expand Up @@ -290,7 +297,6 @@ void search(raft::resources const& res,

RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");

RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

Expand Down
48 changes: 37 additions & 11 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/util/integer_utils.hpp>
#include <raft/util/pow2_utils.cuh>

#include <memory>
#include <optional>
Expand Down Expand Up @@ -82,8 +83,6 @@ struct search_params : ann::search_params {
/** Lower limit of search iterations. */
size_t min_iterations = 0;

/** Bit length for reading the dataset vectors. 0, 64 or 128. Auto selection when 0. */
size_t load_bit_length = 0;
/** Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. */
size_t thread_block_size = 0;
/** Hashmap type. Auto selection when AUTO. */
Expand Down Expand Up @@ -113,6 +112,7 @@ static_assert(std::is_aggregate_v<search_params>);
*/
template <typename T, typename IdxT>
struct index : ann::index {
using AlignDim = raft::Pow2<16 / sizeof(T)>;
static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
"IdxT must be able to represent all values of uint32_t");

Expand All @@ -124,12 +124,15 @@ struct index : ann::index {
}

// /** Total length of the index. */
[[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return dataset_.extent(0); }
[[nodiscard]] constexpr inline auto size() const noexcept -> IdxT
{
return dataset_view_.extent(0);
}

/** Dimensionality of the data. */
[[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t
{
return dataset_.extent(1);
return dataset_view_.extent(1);
}
/** Graph degree */
[[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t
Expand All @@ -138,9 +141,10 @@ struct index : ann::index {
}

/** Dataset [size, dim] */
[[nodiscard]] inline auto dataset() const noexcept -> device_matrix_view<const T, IdxT, row_major>
[[nodiscard]] inline auto dataset() const noexcept
-> device_matrix_view<const T, IdxT, layout_stride>
{
return dataset_.view();
return dataset_view_;
}

/** neighborhood graph [size, graph-degree] */
Expand Down Expand Up @@ -179,15 +183,36 @@ struct index : ann::index {
mdspan<IdxT, matrix_extent<IdxT>, row_major, graph_accessor> knn_graph)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, IdxT>(res, dataset.extent(0), dataset.extent(1))),
dataset_(
make_device_matrix<T, IdxT>(res, dataset.extent(0), AlignDim::roundUp(dataset.extent(1)))),
graph_(make_device_matrix<IdxT, IdxT>(res, knn_graph.extent(0), knn_graph.extent(1)))
{
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"Dataset and knn_graph must have equal number of rows");
raft::copy(dataset_.data_handle(),
dataset.data_handle(),
dataset.size(),
resource::get_cuda_stream(res));
if (dataset_.extent(1) == dataset.extent(1)) {
raft::copy(dataset_.data_handle(),
dataset.data_handle(),
dataset.size(),
resource::get_cuda_stream(res));
} else {
// copy with padding
RAFT_CUDA_TRY(cudaMemsetAsync(
dataset_.data_handle(), 0, dataset_.size() * sizeof(T), resource::get_cuda_stream(res)));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(dataset_.data_handle(),
sizeof(T) * dataset_.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.extent(1),
sizeof(T) * dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
}
dataset_view_ = make_device_strided_matrix_view<T, IdxT>(
dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1));
RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu",
static_cast<size_t>(dataset_view_.extent(0)),
static_cast<size_t>(dataset_view_.extent(1)),
static_cast<size_t>(dataset_view_.stride(0)));
raft::copy(graph_.data_handle(),
knn_graph.data_handle(),
knn_graph.size(),
Expand All @@ -199,6 +224,7 @@ struct index : ann::index {
raft::distance::DistanceType metric_;
raft::device_matrix<T, IdxT, row_major> dataset_;
raft::device_matrix<IdxT, IdxT, row_major> graph_;
raft::device_matrix_view<T, IdxT, layout_stride> dataset_view_;
};

/** @} */
Expand Down
4 changes: 0 additions & 4 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
RAFT_EXPECTS(
dataset.extent(1) * sizeof(DataT) % 8 == 0,
"Dataset rows are expected to have at least 8 bytes alignment. Try padding feature dims.");

RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");

Expand Down
9 changes: 5 additions & 4 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
#include <rmm/cuda_stream_view.hpp>

#include "factory.cuh"
#include "search_multi_cta.cuh"
#include "search_multi_kernel.cuh"
#include "search_plan.cuh"
#include "search_single_cta.cuh"

Expand Down Expand Up @@ -92,8 +90,11 @@ void search_main(raft::resources const& res,
: nullptr;
uint32_t* _num_executed_iterations = nullptr;

auto dataset_internal = raft::make_device_matrix_view<const T, internal_IdxT, row_major>(
index.dataset().data_handle(), index.dataset().extent(0), index.dataset().extent(1));
auto dataset_internal = make_device_strided_matrix_view<const T, internal_IdxT, row_major>(
index.dataset().data_handle(),
index.dataset().extent(0),
index.dataset().extent(1),
index.dataset().stride(0));
auto graph_internal =
raft::make_device_matrix_view<const internal_IdxT, internal_IdxT, row_major>(
reinterpret_cast<const internal_IdxT*>(index.graph().data_handle()),
Expand Down
20 changes: 16 additions & 4 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
namespace raft::neighbors::experimental::cagra::detail {

// Serialization version 1.
constexpr int serialization_version = 1;
constexpr int serialization_version = 2;

// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error
// message.
Expand All @@ -36,7 +36,8 @@ struct check_index_layout {
"paste in the new size and consider updating the serialization logic");
};

template struct check_index_layout<sizeof(index<double, std::uint64_t>), 136>;
constexpr size_t expected_size = 176;
template struct check_index_layout<sizeof(index<double, std::uint64_t>), expected_size>;

/**
* Save the index to file.
Expand All @@ -59,7 +60,19 @@ void serialize(raft::resources const& res, std::ostream& os, const index<T, IdxT
serialize_scalar(res, os, index_.dim());
serialize_scalar(res, os, index_.graph_degree());
serialize_scalar(res, os, index_.metric());
serialize_mdspan(res, os, index_.dataset());
auto dataset = index_.dataset();
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, IdxT>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
sizeof(T) * host_dataset.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.stride(0),
sizeof(T) * host_dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
serialize_mdspan(res, os, host_dataset.view());
serialize_mdspan(res, os, index_.graph());
}

Expand Down Expand Up @@ -100,7 +113,6 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>

auto dataset = raft::make_host_matrix<T, IdxT>(n_rows, dim);
auto graph = raft::make_host_matrix<IdxT, IdxT>(n_rows, graph_degree);

deserialize_mdspan(res, is, dataset.view());
deserialize_mdspan(res, is, graph.view());

Expand Down
6 changes: 4 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim]
const std::size_t dataset_dim,
const std::size_t dataset_size,
const std::size_t dataset_ld,
const std::size_t num_pickup,
const unsigned num_distilation,
const uint64_t rand_xor_mask,
Expand Down Expand Up @@ -93,7 +94,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
for (uint32_t e = 0; e < nelem; e++) {
const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= dataset_dim) break;
dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_dim * seed_index)))[0];
dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_ld * seed_index)))[0];
}
#pragma unroll
for (uint32_t e = 0; e < nelem; e++) {
Expand Down Expand Up @@ -146,6 +147,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in
// [dataset_dim, dataset_size]
const DATA_T* const dataset_ptr,
const std::size_t dataset_dim,
const std::size_t dataset_ld,
// [knn_k, dataset_size]
const INDEX_T* const knn_graph,
const std::uint32_t knn_k,
Expand Down Expand Up @@ -215,7 +217,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in
for (unsigned e = 0; e < nelem; e++) {
const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= dataset_dim) break;
dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_dim * child_id)))[0];
dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_ld * child_id)))[0];
}
#pragma unroll
for (unsigned e = 0; e < nelem; e++) {
Expand Down
Loading

0 comments on commit 6ec78e9

Please sign in to comment.