Skip to content

Commit

Permalink
Generate dataset of select_k times (#1497)
Browse files Browse the repository at this point in the history
This adds an optional flag (`--select_k_dataset`) to the MATRIX_BENCH that will turn on generating a grid search of benchmarks for different select_k algorithms.  Since this adds about 100x as many benchmarks to run as previous (90k vs 900), this is opt-in only right now. This will be used to learn a heuristic function in #1455

This also integrates the faiss block select top-k algorithm into this benchmarking, so that we can compare how it performs against the other select_k algorithms

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1497
  • Loading branch information
benfred authored May 11, 2023
1 parent cc4a76b commit 1d1c523
Show file tree
Hide file tree
Showing 11 changed files with 337 additions and 8 deletions.
4 changes: 4 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/selection_faiss_size_t_double.cu
src/neighbors/detail/selection_faiss_size_t_float.cu
src/neighbors/detail/selection_faiss_uint32_t_float.cu
src/neighbors/detail/selection_faiss_int64_t_double.cu
src/neighbors/detail/selection_faiss_int64_t_half.cu
src/neighbors/detail/selection_faiss_uint32_t_double.cu
src/neighbors/detail/selection_faiss_uint32_t_half.cu
src/neighbors/ivf_flat_build_float_int64_t.cu
src/neighbors/ivf_flat_build_int8_t_int64_t.cu
src/neighbors/ivf_flat_build_uint8_t_int64_t.cu
Expand Down
12 changes: 10 additions & 2 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,16 @@ if(BUILD_PRIMS_BENCH)
)

ConfigureBench(
NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu
bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY
NAME
MATRIX_BENCH
PATH
bench/prims/matrix/argmin.cu
bench/prims/matrix/gather.cu
bench/prims/matrix/select_k.cu
bench/prims/matrix/main.cpp
OPTIONAL
LIB
EXPLICIT_INSTANTIATE_ONLY
)

ConfigureBench(
Expand Down
41 changes: 41 additions & 0 deletions cpp/bench/prims/matrix/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <benchmark/benchmark.h>
#include <cstring>

namespace raft::matrix {
void add_select_k_dataset_benchmarks();
}

int main(int argc, char** argv)
{
// if we're passed a 'select_k_dataset' flag, add in extra benchmarks
for (int i = 1; i < argc; ++i) {
if (strcmp(argv[i], "--select_k_dataset") == 0) {
raft::matrix::add_select_k_dataset_benchmarks();

// pop off the cmdline argument from argc/argv
for (int j = i; j < argc - 1; ++j)
argv[j] = argv[j + 1];
argc--;
break;
}
}
benchmark::Initialize(&argc, argv);
if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1;
benchmark::RunSpecifiedBenchmarks();
}
93 changes: 89 additions & 4 deletions cpp/bench/prims/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include <type_traits>

namespace raft::matrix {

using namespace raft::bench; // NOLINT

template <typename KeyT, typename IdxT, select::Algo Algo>
Expand Down Expand Up @@ -72,17 +71,16 @@ struct selection : public fixture {

void run_benchmark(::benchmark::State& state) override // NOLINT
{
device_resources handle{stream};
try {
std::ostringstream label_stream;
label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k;
if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; }
state.SetLabel(label_stream.str());
loop_on_state(state, [this, &handle]() {
loop_on_state(state, [this]() {
select::select_k_impl<KeyT, IdxT>(handle,
Algo,
in_dists_.data(),
in_ids_.data(),
params_.use_index_input ? in_ids_.data() : NULL,
params_.batch_size,
params_.len,
params_.k,
Expand Down Expand Up @@ -182,4 +180,91 @@ SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT

// For learning a heuristic of which selection algorithm to use, we
// have a couple of additional constraints when generating the dataset:
// 1. We want these benchmarks to be optionally enabled from the commandline -
// there are thousands of them, and the run-time is non-trivial. This should be opt-in only
// 2. We test out larger k values - that won't work for all algorithms. This requires filtering
// the input parameters per algorithm.
// This makes the code to generate this dataset different from the code above to
// register other benchmarks
#define SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, A, input) \
{ \
using SelectK = selection<KeyT, IdxT, select::Algo::A>; \
std::stringstream name; \
name << "SelectKDataset/" << #KeyT "/" #IdxT "/" #A << "/" << input.batch_size << "/" \
<< input.len << "/" << input.k << "/" << input.use_index_input; \
auto* b = ::benchmark::internal::RegisterBenchmarkInternal( \
new raft::bench::internal::Fixture<SelectK, select::params>(name.str(), input)); \
b->UseManualTime(); \
b->Unit(benchmark::kMillisecond); \
}

const static size_t MAX_MEMORY = 16 * 1024 * 1024 * 1024ULL;

// registers the input for all algorithms
#define SELECTION_REGISTER_INPUT(KeyT, IdxT, input) \
{ \
size_t mem = input.batch_size * input.len * (sizeof(KeyT) + sizeof(IdxT)); \
if (mem < MAX_MEMORY) { \
SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kRadix8bits, input) \
SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kRadix11bits, input) \
SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kRadix11bitsExtraPass, input) \
if (input.k <= raft::matrix::detail::select::warpsort::kMaxCapacity) { \
SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kWarpImmediate, input) \
SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kWarpFiltered, input) \
SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kWarpDistributed, input) \
SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kWarpDistributedShm, input) \
} \
if (input.k <= raft::neighbors::detail::kFaissMaxK<IdxT, KeyT>()) { \
SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kFaissBlockSelect, input) \
} \
} \
}

void add_select_k_dataset_benchmarks()
{
// define a uniform grid
std::vector<select::params> inputs;

size_t grid_increment = 1;
std::vector<int> k_vals;
for (size_t k = 0; k < 13; k += grid_increment) {
k_vals.push_back(1 << k);
}
// Add in values just past the limit for warp/faiss select
k_vals.push_back(257);
k_vals.push_back(2049);

const static bool select_min = true;
const static bool use_ids = false;

for (size_t row = 0; row < 13; row += grid_increment) {
for (size_t col = 10; col < 28; col += grid_increment) {
for (auto k : k_vals) {
inputs.push_back(
select::params{size_t(1 << row), size_t(1 << col), k, select_min, use_ids});
}
}
}

// also add in some random values
std::default_random_engine rng(42);
std::uniform_real_distribution<> row_dist(0, 13);
std::uniform_real_distribution<> col_dist(10, 28);
std::uniform_real_distribution<> k_dist(0, 13);
for (size_t i = 0; i < 1024; ++i) {
auto row = static_cast<size_t>(pow(2, row_dist(rng)));
auto col = static_cast<size_t>(pow(2, col_dist(rng)));
auto k = static_cast<int>(pow(2, k_dist(rng)));
inputs.push_back(select::params{row, col, k, select_min, use_ids});
}

for (auto& input : inputs) {
SELECTION_REGISTER_INPUT(double, int64_t, input);
SELECTION_REGISTER_INPUT(double, uint32_t, input);
SELECTION_REGISTER_INPUT(float, int64_t, input);
SELECTION_REGISTER_INPUT(float, uint32_t, input);
}
}
} // namespace raft::matrix
6 changes: 6 additions & 0 deletions cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cstddef> // size_t
#include <cstdint> // uint32_t
#include <cuda_fp16.h> // __half
#include <raft/neighbors/detail/selection_faiss_helpers.cuh> // kFaissMaxK
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

Expand Down Expand Up @@ -58,4 +59,9 @@ instantiate_raft_neighbors_detail_select_k(size_t, double);
instantiate_raft_neighbors_detail_select_k(int, double);
instantiate_raft_neighbors_detail_select_k(size_t, float);

instantiate_raft_neighbors_detail_select_k(uint32_t, double);
instantiate_raft_neighbors_detail_select_k(int64_t, double);
instantiate_raft_neighbors_detail_select_k(uint32_t, __half);
instantiate_raft_neighbors_detail_select_k(int64_t, __half);

#undef instantiate_raft_neighbors_detail_select_k
9 changes: 7 additions & 2 deletions cpp/internal/raft_internal/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/matrix/detail/select_radix.cuh>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/matrix/select_k.cuh>
#include <raft/neighbors/detail/selection_faiss.cuh>

namespace raft::matrix::select {

Expand Down Expand Up @@ -52,7 +53,8 @@ enum class Algo {
kWarpImmediate,
kWarpFiltered,
kWarpDistributed,
kWarpDistributedShm
kWarpDistributedShm,
kFaissBlockSelect
};

inline auto operator<<(std::ostream& os, const Algo& algo) -> std::ostream&
Expand All @@ -67,6 +69,7 @@ inline auto operator<<(std::ostream& os, const Algo& algo) -> std::ostream&
case Algo::kWarpFiltered: return os << "kWarpFiltered";
case Algo::kWarpDistributed: return os << "kWarpDistributed";
case Algo::kWarpDistributedShm: return os << "kWarpDistributedShm";
case Algo::kFaissBlockSelect: return os << "kFaissBlockSelect";
default: return os << "unknown enum value";
}
}
Expand Down Expand Up @@ -154,7 +157,9 @@ void select_k_impl(const device_resources& handle,
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
case Algo::kFaissBlockSelect:
return neighbors::detail::select_k(
in, in_idx, batch_size, len, out, out_idx, select_min, k, stream);
}
}

} // namespace raft::matrix::select
4 changes: 4 additions & 0 deletions cpp/src/neighbors/detail/selection_faiss_00_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@

types = dict(
uint32_t_float=("uint32_t", "float"),
uint32_t_double=("uint32_t", "double"),
uint32_t_half=("uint32_t", "half"),
int64_t_double=("int64_t", "double"),
int64_t_half=("int64_t", "half"),
int32_t_float=("int32_t", "float"),
long_float=("long", "float"),
size_t_double=("size_t", "double"),
Expand Down
44 changes: 44 additions & 0 deletions cpp/src/neighbors/detail/selection_faiss_int64_t_double.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is generated by selection_faiss_00_generate.py
*
* Make changes there and run in this directory:
*
* > python selection_faiss_00_generate.py
*
*/

#include <cstddef> // size_t
#include <cstdint> // uint32_t
#include <raft/neighbors/detail/selection_faiss-inl.cuh>

#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \
template void raft::neighbors::detail::select_k(const key_t* inK, \
const payload_t* inV, \
size_t n_rows, \
size_t n_cols, \
key_t* outK, \
payload_t* outV, \
bool select_min, \
int k, \
cudaStream_t stream)

instantiate_raft_neighbors_detail_select_k(int64_t, double);

#undef instantiate_raft_neighbors_detail_select_k
44 changes: 44 additions & 0 deletions cpp/src/neighbors/detail/selection_faiss_int64_t_half.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is generated by selection_faiss_00_generate.py
*
* Make changes there and run in this directory:
*
* > python selection_faiss_00_generate.py
*
*/

#include <cstddef> // size_t
#include <cstdint> // uint32_t
#include <raft/neighbors/detail/selection_faiss-inl.cuh>

#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \
template void raft::neighbors::detail::select_k(const key_t* inK, \
const payload_t* inV, \
size_t n_rows, \
size_t n_cols, \
key_t* outK, \
payload_t* outV, \
bool select_min, \
int k, \
cudaStream_t stream)

instantiate_raft_neighbors_detail_select_k(int64_t, half);

#undef instantiate_raft_neighbors_detail_select_k
44 changes: 44 additions & 0 deletions cpp/src/neighbors/detail/selection_faiss_uint32_t_double.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is generated by selection_faiss_00_generate.py
*
* Make changes there and run in this directory:
*
* > python selection_faiss_00_generate.py
*
*/

#include <cstddef> // size_t
#include <cstdint> // uint32_t
#include <raft/neighbors/detail/selection_faiss-inl.cuh>

#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \
template void raft::neighbors::detail::select_k(const key_t* inK, \
const payload_t* inV, \
size_t n_rows, \
size_t n_cols, \
key_t* outK, \
payload_t* outV, \
bool select_min, \
int k, \
cudaStream_t stream)

instantiate_raft_neighbors_detail_select_k(uint32_t, double);

#undef instantiate_raft_neighbors_detail_select_k
Loading

0 comments on commit 1d1c523

Please sign in to comment.