-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create cub-based argmin primitive and replace
argmin_along_rows
in …
…ANN kmeans (#912) This PR follows up on [a suggestion](#821 (comment)) from @cjnolet. The new `argmin` primitive is up to 5x faster than `argmin_along_rows` for dimensions relevant to ANN kmeans, and removes code duplication. The reasons why it is faster are: - `argmin_along_rows` often misses on doing a sequential reduction before the tree reduction, especially as it uses large block sizes, as much as 1024. - CUB has a better reduction algorithm than the basic shared-mem reduction used in `argmin_along_rows`. - If we switch the `argmin` prim to using the `cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY` algorithm, we can get up to 30% further speedup! (I believe it's safe to use the commutative algorithm here since the offset is contained in the key-value pair so the reduction operation is commutative). The speedup that I have measured for IVF-Flat build with the `InnerProduct` metric is around 15%. Authors: - Louis Sugy (https://github.com/Nyrio) Approvers: - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) - Tamas Bela Feher (https://github.com/tfeher) URL: #912
- Loading branch information
Showing
12 changed files
with
300 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <common/benchmark.hpp> | ||
#include <raft/matrix/argmin.cuh> | ||
#include <raft/random/rng.cuh> | ||
|
||
#include <rmm/device_uvector.hpp> | ||
|
||
namespace raft::bench::linalg { | ||
|
||
template <typename IdxT> | ||
struct ArgminParams { | ||
IdxT rows, cols; | ||
}; | ||
|
||
template <typename T, typename OutT, typename IdxT> | ||
struct Argmin : public fixture { | ||
Argmin(const ArgminParams<IdxT>& p) : params(p) {} | ||
|
||
void allocate_data(const ::benchmark::State& state) override | ||
{ | ||
matrix = raft::make_device_matrix<T, IdxT>(handle, params.rows, params.cols); | ||
indices = raft::make_device_vector<OutT, IdxT>(handle, params.rows); | ||
|
||
raft::random::RngState rng{1234}; | ||
raft::random::uniform( | ||
rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream); | ||
handle.sync_stream(stream); | ||
} | ||
|
||
void run_benchmark(::benchmark::State& state) override | ||
{ | ||
loop_on_state(state, [this]() { | ||
auto matrix_const_view = raft::make_device_matrix_view<const T, IdxT, row_major>( | ||
matrix.data_handle(), matrix.extent(0), matrix.extent(1)); | ||
raft::matrix::argmin(handle, matrix_const_view, indices.view()); | ||
}); | ||
} | ||
|
||
private: | ||
ArgminParams<IdxT> params; | ||
raft::device_matrix<T, IdxT> matrix; | ||
raft::device_vector<OutT, IdxT> indices; | ||
}; // struct Argmin | ||
|
||
const std::vector<ArgminParams<int64_t>> argmin_inputs_i64{ | ||
{1000, 64}, {1000, 128}, {1000, 256}, {1000, 512}, {1000, 1024}, | ||
{10000, 64}, {10000, 128}, {10000, 256}, {10000, 512}, {10000, 1024}, | ||
{100000, 64}, {100000, 128}, {100000, 256}, {100000, 512}, {100000, 1024}, | ||
{1000000, 64}, {1000000, 128}, {1000000, 256}, {1000000, 512}, {1000000, 1024}, | ||
{10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024}, | ||
}; | ||
|
||
RAFT_BENCH_REGISTER((Argmin<float, uint32_t, int64_t>), "", argmin_inputs_i64); | ||
RAFT_BENCH_REGISTER((Argmin<double, uint32_t, int64_t>), "", argmin_inputs_i64); | ||
|
||
} // namespace raft::bench::linalg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <raft/core/device_mdspan.hpp> | ||
#include <raft/matrix/detail/math.cuh> | ||
|
||
namespace raft::matrix { | ||
|
||
/** | ||
* @brief Argmin: find the col idx with minimum value for each row | ||
* @param[in] handle: raft handle | ||
* @param[in] in: input matrix of size (n_rows, n_cols) | ||
* @param[out] out: output vector of size n_rows | ||
*/ | ||
template <typename math_t, typename idx_t, typename matrix_idx_t> | ||
void argmin(const raft::handle_t& handle, | ||
raft::device_matrix_view<const math_t, matrix_idx_t, row_major> in, | ||
raft::device_vector_view<idx_t, matrix_idx_t> out) | ||
{ | ||
RAFT_EXPECTS(out.extent(0) == in.extent(0), | ||
"Size of output vector must equal number of rows in input matrix."); | ||
detail::argmin( | ||
in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream()); | ||
} | ||
} // namespace raft::matrix |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.