Skip to content

Commit

Permalink
Merge branch 'bug-ext-fix-knn-bench-compile' of github.com:teju85/raf…
Browse files Browse the repository at this point in the history
…t into bug-ext-fix-knn-bench-compile
  • Loading branch information
cjnolet committed Oct 2, 2022
2 parents 6471940 + 84e727e commit a7113c4
Show file tree
Hide file tree
Showing 27 changed files with 1,333 additions and 179 deletions.
7 changes: 7 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,13 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
src/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu
src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
src/distance/specializations/detail/russel_rao_double_double_double_int.cu
src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu
src/distance/specializations/detail/russel_rao_float_float_float_int.cu
src/distance/specializations/fused_l2_nn_double_int.cu
src/distance/specializations/fused_l2_nn_double_int64.cu
src/distance/specializations/fused_l2_nn_float_int.cu
src/distance/specializations/fused_l2_nn_float_int64.cu
src/random/specializations/rmat_rectangular_generator_int_double.cu
src/random/specializations/rmat_rectangular_generator_int64_double.cu
src/random/specializations/rmat_rectangular_generator_int_float.cu
Expand Down
2 changes: 2 additions & 0 deletions cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ set(RAFT_CPP_BENCH_TARGET "bench_raft")

# (please keep the filenames in alphabetical order)
add_executable(${RAFT_CPP_BENCH_TARGET}
bench/cluster/kmeans_balanced.cu
bench/cluster/kmeans.cu
bench/distance/distance_cosine.cu
bench/distance/distance_exp_l2.cu
bench/distance/distance_l1.cu
Expand Down
115 changes: 115 additions & 0 deletions cpp/bench/cluster/kmeans.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>
#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/kmeans_params.hpp>

#if defined RAFT_DISTANCE_COMPILED && defined RAFT_NN_COMPILED
#include <raft/cluster/specializations.cuh>
#endif

namespace raft::bench::cluster {

struct KMeansBenchParams {
DatasetParams data;
BlobsParams blobs;
raft::cluster::KMeansParams kmeans;
};

template <typename T, typename IndexT = int>
struct KMeans : public BlobsFixture<T, IndexT> {
KMeans(const KMeansBenchParams& p) : BlobsFixture<T, IndexT>(p.data, p.blobs), params(p) {}

void run_benchmark(::benchmark::State& state) override
{
raft::device_matrix_view<const T, IndexT> X_view = this->X.view();
std::optional<raft::device_vector_view<const T, IndexT>> opt_weights_view = std::nullopt;
std::optional<raft::device_matrix_view<T, IndexT>> centroids_view =
std::make_optional<raft::device_matrix_view<T, IndexT>>(centroids.view());
raft::device_vector_view<IndexT, IndexT> labels_view = labels.view();
raft::host_scalar_view<T> inertia_view = raft::make_host_scalar_view<T>(&inertia);
raft::host_scalar_view<IndexT> n_iter_view = raft::make_host_scalar_view<IndexT>(&n_iter);

this->loop_on_state(state, [&]() {
raft::cluster::kmeans_fit_predict<T, IndexT>(this->handle,
params.kmeans,
X_view,
opt_weights_view,
centroids_view,
labels_view,
inertia_view,
n_iter_view);
});
}

void allocate_temp_buffers(const ::benchmark::State& state) override
{
centroids =
raft::make_device_matrix<T, IndexT>(this->handle, params.kmeans.n_clusters, params.data.cols);
labels = raft::make_device_vector<IndexT, IndexT>(this->handle, params.data.rows);
}

private:
KMeansBenchParams params;
raft::device_matrix<T, IndexT> centroids;
raft::device_vector<IndexT, IndexT> labels;
T inertia;
IndexT n_iter;
}; // struct KMeans

std::vector<KMeansBenchParams> getKMeansInputs()
{
std::vector<KMeansBenchParams> out;
KMeansBenchParams p;
p.data.row_major = true;
p.blobs.cluster_std = 1.0;
p.blobs.shuffle = false;
p.blobs.center_box_min = -10.0;
p.blobs.center_box_max = 10.0;
p.blobs.seed = 12345ULL;
p.kmeans.init = raft::cluster::KMeansParams::KMeansPlusPlus;
p.kmeans.max_iter = 300;
p.kmeans.tol = 1e-4;
p.kmeans.verbosity = RAFT_LEVEL_INFO;
p.kmeans.metric = raft::distance::DistanceType::L2Expanded;
p.kmeans.inertia_check = true;
std::vector<std::tuple<int, int, int>> row_cols_k = {
{1000000, 20, 1000},
{3000000, 50, 20},
{10000000, 50, 5},
};
for (auto& rck : row_cols_k) {
p.data.rows = std::get<0>(rck);
p.data.cols = std::get<1>(rck);
p.blobs.n_clusters = std::get<2>(rck);
p.kmeans.n_clusters = std::get<2>(rck);
for (auto bs_shift : std::vector<int>({16, 18})) {
p.kmeans.batch_samples = 1 << bs_shift;
out.push_back(p);
}
}
return out;
}

// note(lsugy): commenting out int64_t because the templates are not compiled in the distance
// library, resulting in long compilation times.
RAFT_BENCH_REGISTER((KMeans<float, int>), "", getKMeansInputs());
RAFT_BENCH_REGISTER((KMeans<double, int>), "", getKMeansInputs());
// RAFT_BENCH_REGISTER((KMeans<float, int64_t>), "", getKMeansInputs());
// RAFT_BENCH_REGISTER((KMeans<double, int64_t>), "", getKMeansInputs());

} // namespace raft::bench::cluster
110 changes: 110 additions & 0 deletions cpp/bench/cluster/kmeans_balanced.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>
#include <raft/random/rng.cuh>
#include <raft/spatial/knn/detail/ann_kmeans_balanced.cuh>

#if defined RAFT_DISTANCE_COMPILED && defined RAFT_NN_COMPILED
#include <raft/cluster/specializations.cuh>
#endif

namespace raft::bench::cluster {

struct KMeansBalancedBenchParams {
DatasetParams data;
uint32_t max_iter;
uint32_t n_lists;
raft::distance::DistanceType metric;
};

template <typename T, typename IndexT = int>
struct KMeansBalanced : public fixture {
KMeansBalanced(const KMeansBalancedBenchParams& p) : params(p) {}

void run_benchmark(::benchmark::State& state) override
{
this->loop_on_state(state, [this]() {
raft::spatial::knn::detail::kmeans::build_hierarchical<T>(this->handle,
this->params.max_iter,
(uint32_t)this->params.data.cols,
this->X.data_handle(),
this->params.data.rows,
this->centroids.data_handle(),
this->params.n_lists,
this->params.metric,
this->handle.get_stream());
});
}

void allocate_data(const ::benchmark::State& state) override
{
X = raft::make_device_matrix<T, IndexT>(handle, params.data.rows, params.data.cols);

raft::random::RngState rng{1234};
constexpr T kRangeMax = std::is_integral_v<T> ? std::numeric_limits<T>::max() : T(1);
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(
rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream);
} else {
raft::random::uniform(
rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream);
}
handle.sync_stream(stream);
}

void allocate_temp_buffers(const ::benchmark::State& state) override
{
centroids =
raft::make_device_matrix<float, IndexT>(this->handle, params.n_lists, params.data.cols);
}

private:
KMeansBalancedBenchParams params;
raft::device_matrix<T, IndexT> X;
raft::device_matrix<float, IndexT> centroids;
}; // struct KMeansBalanced

std::vector<KMeansBalancedBenchParams> getKMeansBalancedInputs()
{
std::vector<KMeansBalancedBenchParams> out;
KMeansBalancedBenchParams p;
p.data.row_major = true;
p.max_iter = 20;
p.metric = raft::distance::DistanceType::L2Expanded;
std::vector<std::pair<int, int>> row_cols = {
{100000, 128}, {1000000, 128}, {10000000, 128},
// The following dataset sizes are too large for most GPUs.
// {100000000, 128},
};
for (auto& rc : row_cols) {
p.data.rows = rc.first;
p.data.cols = rc.second;
for (auto n_lists : std::vector<int>({1000, 10000, 100000})) {
p.n_lists = n_lists;
out.push_back(p);
}
}
return out;
}

// Note: the datasets sizes are too large for 32-bit index types.
RAFT_BENCH_REGISTER((KMeansBalanced<float, int64_t>), "", getKMeansBalancedInputs());
RAFT_BENCH_REGISTER((KMeansBalanced<int8_t, int64_t>), "", getKMeansBalancedInputs());
RAFT_BENCH_REGISTER((KMeansBalanced<uint8_t, int64_t>), "", getKMeansBalancedInputs());

} // namespace raft::bench::cluster
79 changes: 75 additions & 4 deletions cpp/bench/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@

#include <memory>

#include <raft/core/detail/macros.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/cudart_utils.h>
#include <raft/interruptible.hpp>
#include <raft/random/make_blobs.cuh>

#include <benchmark/benchmark.h>

Expand Down Expand Up @@ -121,6 +124,10 @@ class fixture {
// every benchmark should be overriding this
virtual void run_benchmark(::benchmark::State& state) = 0;
virtual void generate_metrics(::benchmark::State& state) {}
virtual void allocate_data(const ::benchmark::State& state) {}
virtual void deallocate_data(const ::benchmark::State& state) {}
virtual void allocate_temp_buffers(const ::benchmark::State& state) {}
virtual void deallocate_temp_buffers(const ::benchmark::State& state) {}

protected:
/** The helper that writes zeroes to some buffer in GPU memory to flush the L2 cache. */
Expand All @@ -144,6 +151,58 @@ class fixture {
}
};

/** Indicates the dataset size. */
struct DatasetParams {
size_t rows;
size_t cols;
bool row_major;
};

/** Holds params needed to generate blobs dataset */
struct BlobsParams {
int n_clusters;
double cluster_std;
bool shuffle;
double center_box_min, center_box_max;
uint64_t seed;
};

/** Fixture for cluster benchmarks using make_blobs */
template <typename T, typename IndexT = int>
class BlobsFixture : public fixture {
public:
BlobsFixture(const DatasetParams dp, const BlobsParams bp) : data_params(dp), blobs_params(bp) {}

virtual void run_benchmark(::benchmark::State& state) = 0;

void allocate_data(const ::benchmark::State& state) override
{
auto labels_ref = raft::make_device_vector<IndexT, IndexT>(this->handle, data_params.rows);
X = raft::make_device_matrix<T, IndexT>(this->handle, data_params.rows, data_params.cols);

raft::random::make_blobs<T, IndexT>(X.data_handle(),
labels_ref.data_handle(),
(IndexT)data_params.rows,
(IndexT)data_params.cols,
(IndexT)blobs_params.n_clusters,
stream,
data_params.row_major,
nullptr,
nullptr,
(T)blobs_params.cluster_std,
blobs_params.shuffle,
(T)blobs_params.center_box_min,
(T)blobs_params.center_box_max,
blobs_params.seed);
this->handle.sync_stream(stream);
}

protected:
DatasetParams data_params;
BlobsParams blobs_params;
raft::device_matrix<T, IndexT> X;
};

namespace internal {

template <typename Class, typename... Params>
Expand All @@ -162,8 +221,17 @@ class Fixture : public ::benchmark::Fixture {
{
fixture_ =
std::apply([](const Params&... ps) { return std::make_unique<Class>(ps...); }, params_);
fixture_->allocate_data(state);
fixture_->allocate_temp_buffers(state);
}

void TearDown(const State& state) override
{
fixture_->deallocate_temp_buffers(state);
fixture_->deallocate_data(state);
fixture_.reset();
}
void TearDown(const State& state) override { fixture_.reset(); }

void SetUp(State& st) override { SetUp(const_cast<const State&>(st)); }
void TearDown(State& st) override { TearDown(const_cast<const State&>(st)); }

Expand Down Expand Up @@ -248,6 +316,10 @@ struct registrar {

}; // namespace internal

#define RAFT_BENCH_REGISTER_INTERNAL(TestClass, ...) \
static raft::bench::internal::registrar<TestClass> BENCHMARK_PRIVATE_NAME(registrar)( \
RAFT_STRINGIFY(TestClass), __VA_ARGS__)

/**
* This is the entry point macro for all benchmarks. This needs to be called
* for the set of benchmarks to be registered so that the main harness inside
Expand All @@ -262,8 +334,7 @@ struct registrar {
* empty string
* @param params... zero or more lists of params upon which to benchmark.
*/
#define RAFT_BENCH_REGISTER(TestClass, ...) \
static raft::bench::internal::registrar<TestClass> BENCHMARK_PRIVATE_NAME(registrar)( \
#TestClass, __VA_ARGS__)
#define RAFT_BENCH_REGISTER(TestClass, ...) \
RAFT_BENCH_REGISTER_INTERNAL(RAFT_DEPAREN(TestClass), __VA_ARGS__)

} // namespace raft::bench
6 changes: 6 additions & 0 deletions cpp/bench/spatial/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
#include <raft/random/rng.cuh>

#include <raft/spatial/knn/ivf_flat.cuh>
#include <raft/spatial/knn/knn.cuh>

#if defined RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#endif

#if defined RAFT_NN_COMPILED
#include <raft/spatial/knn/specializations.cuh>
#endif
Expand Down
Loading

0 comments on commit a7113c4

Please sign in to comment.