forked from rapidsai/raft
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds the sparseL2NN functionality. This enables faster computing pairwise distances by making use of sparsity in the problem: the computation of distances between point pairs can be skipped. The sparsity between arrays of points X and Y is expressed as follows: - X is split into rows (points) - Y is split into contiguous groups of points (i.e. all points in a group are adjacent) - A boolean adjacency matrix indicates for each row of X and each group in Y whether to compute the distance. To speed up computation, the adjacency matrix is compressed into a bitfield. To ensure competitive speeds, the caller must make sure that consecutive rows in X are adjacent to the same groups in Y (as much as possible) to enable efficient skipping in the kernel. Some work is still TODO: - Flesh out documentation - Discuss / remove allocation of intermediate array - Optimize for skinny matrices by using a different KernelPolicy. Authors: - Allard Hendriksen (https://github.com/ahendriksen) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai#838
- Loading branch information
Allard Hendriksen
authored
Jan 27, 2023
1 parent
c58d00a
commit 2fb5c06
Showing
14 changed files
with
1,923 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,267 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <cstdint> | ||
#include <fstream> | ||
#include <iostream> | ||
#include <sstream> | ||
#include <string> | ||
|
||
#include <common/benchmark.hpp> | ||
#include <limits> | ||
#include <raft/core/device_mdarray.hpp> | ||
#include <raft/core/device_mdspan.hpp> | ||
#include <raft/distance/masked_nn.cuh> | ||
#include <raft/handle.hpp> | ||
#include <raft/linalg/norm.cuh> | ||
#include <raft/random/rng.cuh> | ||
#include <raft/util/cudart_utils.hpp> | ||
|
||
#if defined RAFT_NN_COMPILED | ||
#include <raft/spatial/knn/specializations.hpp> | ||
#endif | ||
|
||
namespace raft::bench::distance::masked_nn { | ||
|
||
// Introduce various sparsity patterns | ||
enum AdjacencyPattern { | ||
checkerboard = 0, | ||
checkerboard_4 = 1, | ||
checkerboard_64 = 2, | ||
all_true = 3, | ||
all_false = 4 | ||
}; | ||
|
||
struct Params { | ||
int m, n, k, num_groups; | ||
AdjacencyPattern pattern; | ||
}; // struct Params | ||
|
||
__global__ void init_adj(AdjacencyPattern pattern, | ||
int n, | ||
raft::device_matrix_view<bool, int, raft::layout_c_contiguous> adj, | ||
raft::device_vector_view<int, int, raft::layout_c_contiguous> group_idxs) | ||
{ | ||
int m = adj.extent(0); | ||
int num_groups = adj.extent(1); | ||
|
||
for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; | ||
idx_m += blockDim.y * gridDim.y) { | ||
for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; | ||
idx_g += blockDim.x * gridDim.x) { | ||
switch (pattern) { | ||
case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; | ||
case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; | ||
case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break; | ||
case all_true: adj(idx_m, idx_g) = true; break; | ||
case all_false: adj(idx_m, idx_g) = false; break; | ||
default: assert(false && "unknown pattern"); | ||
} | ||
} | ||
} | ||
// Each group is of size n / num_groups. | ||
// | ||
// - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive | ||
// scan of the group lengths) | ||
// | ||
// - The first group always starts at index zero, so we do not store it. | ||
// | ||
// - The group_idxs[num_groups - 1] should always equal n. | ||
|
||
if (blockIdx.y == 0 && threadIdx.y == 0) { | ||
const int g_stride = blockDim.x * gridDim.x; | ||
for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { | ||
group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); | ||
} | ||
group_idxs(num_groups - 1) = n; | ||
} | ||
} | ||
|
||
template <typename T> | ||
struct masked_l2_nn : public fixture { | ||
using DataT = T; | ||
using IdxT = int; | ||
using OutT = raft::KeyValuePair<IdxT, DataT>; | ||
using RedOpT = raft::distance::MinAndDistanceReduceOp<int, DataT>; | ||
using PairRedOpT = raft::distance::KVPMinReduce<int, DataT>; | ||
using ParamT = raft::distance::MaskedL2NNParams<RedOpT, PairRedOpT>; | ||
|
||
// Parameters | ||
Params params; | ||
// Data | ||
raft::device_vector<OutT, IdxT> out; | ||
raft::device_matrix<T, IdxT> x, y; | ||
raft::device_vector<DataT, IdxT> xn, yn; | ||
raft::device_matrix<bool, IdxT> adj; | ||
raft::device_vector<IdxT, IdxT> group_idxs; | ||
|
||
masked_l2_nn(const Params& p) | ||
: params(p), | ||
out{raft::make_device_vector<OutT, IdxT>(handle, p.m)}, | ||
x{raft::make_device_matrix<DataT, IdxT>(handle, p.m, p.k)}, | ||
y{raft::make_device_matrix<DataT, IdxT>(handle, p.n, p.k)}, | ||
xn{raft::make_device_vector<DataT, IdxT>(handle, p.m)}, | ||
yn{raft::make_device_vector<DataT, IdxT>(handle, p.n)}, | ||
adj{raft::make_device_matrix<bool, IdxT>(handle, p.m, p.num_groups)}, | ||
group_idxs{raft::make_device_vector<IdxT, IdxT>(handle, p.num_groups)} | ||
{ | ||
raft::random::RngState r(123456ULL); | ||
|
||
uniform(handle, r, x.data_handle(), p.m * p.k, T(-1.0), T(1.0)); | ||
uniform(handle, r, y.data_handle(), p.n * p.k, T(-1.0), T(1.0)); | ||
raft::linalg::rowNorm( | ||
xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream); | ||
raft::linalg::rowNorm( | ||
yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream); | ||
raft::distance::initialize<T, raft::KeyValuePair<int, T>, int>( | ||
handle, out.data_handle(), p.m, std::numeric_limits<T>::max(), RedOpT{}); | ||
|
||
dim3 block(32, 32); | ||
dim3 grid(10, 10); | ||
init_adj<<<grid, block, 0, stream>>>(p.pattern, p.n, adj.view(), group_idxs.view()); | ||
RAFT_CUDA_TRY(cudaGetLastError()); | ||
} | ||
|
||
void run_benchmark(::benchmark::State& state) override | ||
{ | ||
bool init_out = true; | ||
bool sqrt = false; | ||
ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; | ||
|
||
loop_on_state(state, [this, masked_l2_params]() { | ||
// It is sufficient to only benchmark the L2-squared metric | ||
raft::distance::maskedL2NN<DataT, OutT, IdxT>(handle, | ||
masked_l2_params, | ||
x.view(), | ||
y.view(), | ||
xn.view(), | ||
yn.view(), | ||
adj.view(), | ||
group_idxs.view(), | ||
out.view()); | ||
}); | ||
|
||
// Virtual flop count if no skipping had occurred. | ||
size_t virtual_flops = size_t(2) * size_t(params.m) * size_t(params.n) * size_t(params.k); | ||
|
||
int64_t read_elts = params.n * params.k + params.m * params.k; | ||
int64_t write_elts = params.m; | ||
|
||
// Virtual min flops is the number of flops that would have been executed if | ||
// the algorithm had actually skipped each computation that it could have | ||
// skipped. | ||
size_t virtual_min_flops = 0; | ||
switch (params.pattern) { | ||
case checkerboard: | ||
case checkerboard_4: | ||
case checkerboard_64: virtual_min_flops = virtual_flops / 2; break; | ||
case all_true: virtual_min_flops = virtual_flops; break; | ||
case all_false: virtual_min_flops = 0; break; | ||
default: assert(false && "unknown pattern"); | ||
} | ||
|
||
// VFLOP/s is the "virtual" flop count that would have executed if there was | ||
// no adjacency pattern. This is useful for comparing to fusedL2NN | ||
state.counters["VFLOP/s"] = benchmark::Counter(virtual_flops, | ||
benchmark::Counter::kIsIterationInvariantRate, | ||
benchmark::Counter::OneK::kIs1000); | ||
// Virtual min flops is the number of flops that would have been executed if | ||
// the algorithm had actually skipped each computation that it could have | ||
// skipped. | ||
state.counters["VminFLOP/s"] = benchmark::Counter(virtual_min_flops, | ||
benchmark::Counter::kIsIterationInvariantRate, | ||
benchmark::Counter::OneK::kIs1000); | ||
|
||
state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(OutT), | ||
benchmark::Counter::kIsIterationInvariantRate, | ||
benchmark::Counter::OneK::kIs1000); | ||
state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(DataT), | ||
benchmark::Counter::kIsIterationInvariantRate, | ||
benchmark::Counter::OneK::kIs1000); | ||
|
||
state.counters["m"] = benchmark::Counter(params.m); | ||
state.counters["n"] = benchmark::Counter(params.n); | ||
state.counters["k"] = benchmark::Counter(params.k); | ||
state.counters["num_groups"] = benchmark::Counter(params.num_groups); | ||
state.counters["group size"] = benchmark::Counter(params.n / params.num_groups); | ||
state.counters["Pat"] = benchmark::Counter(static_cast<int>(params.pattern)); | ||
|
||
state.counters["SM count"] = raft::getMultiProcessorCount(); | ||
} | ||
}; // struct MaskedL2NN | ||
|
||
const std::vector<Params> masked_l2_nn_input_vecs = { | ||
// Very fat matrices... | ||
{32, 16384, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{64, 16384, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{128, 16384, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{256, 16384, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{512, 16384, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{1024, 16384, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 32, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 64, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 128, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 256, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 512, 16384, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 1024, 16384, 32, AdjacencyPattern::checkerboard}, | ||
|
||
// Representative matrices... | ||
{16384, 16384, 32, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 16384, 64, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 16384, 128, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 16384, 256, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 16384, 512, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 16384, 1024, 32, AdjacencyPattern::checkerboard}, | ||
{16384, 16384, 16384, 32, AdjacencyPattern::checkerboard}, | ||
|
||
{16384, 16384, 32, 32, AdjacencyPattern::checkerboard_4}, | ||
{16384, 16384, 64, 32, AdjacencyPattern::checkerboard_4}, | ||
{16384, 16384, 128, 32, AdjacencyPattern::checkerboard_4}, | ||
{16384, 16384, 256, 32, AdjacencyPattern::checkerboard_4}, | ||
{16384, 16384, 512, 32, AdjacencyPattern::checkerboard_4}, | ||
{16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_4}, | ||
{16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_4}, | ||
|
||
{16384, 16384, 32, 32, AdjacencyPattern::checkerboard_64}, | ||
{16384, 16384, 64, 32, AdjacencyPattern::checkerboard_64}, | ||
{16384, 16384, 128, 32, AdjacencyPattern::checkerboard_64}, | ||
{16384, 16384, 256, 32, AdjacencyPattern::checkerboard_64}, | ||
{16384, 16384, 512, 32, AdjacencyPattern::checkerboard_64}, | ||
{16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_64}, | ||
{16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_64}, | ||
|
||
{16384, 16384, 32, 32, AdjacencyPattern::all_true}, | ||
{16384, 16384, 64, 32, AdjacencyPattern::all_true}, | ||
{16384, 16384, 128, 32, AdjacencyPattern::all_true}, | ||
{16384, 16384, 256, 32, AdjacencyPattern::all_true}, | ||
{16384, 16384, 512, 32, AdjacencyPattern::all_true}, | ||
{16384, 16384, 1024, 32, AdjacencyPattern::all_true}, | ||
{16384, 16384, 16384, 32, AdjacencyPattern::all_true}, | ||
|
||
{16384, 16384, 32, 32, AdjacencyPattern::all_false}, | ||
{16384, 16384, 64, 32, AdjacencyPattern::all_false}, | ||
{16384, 16384, 128, 32, AdjacencyPattern::all_false}, | ||
{16384, 16384, 256, 32, AdjacencyPattern::all_false}, | ||
{16384, 16384, 512, 32, AdjacencyPattern::all_false}, | ||
{16384, 16384, 1024, 32, AdjacencyPattern::all_false}, | ||
{16384, 16384, 16384, 32, AdjacencyPattern::all_false}, | ||
}; | ||
|
||
RAFT_BENCH_REGISTER(masked_l2_nn<float>, "", masked_l2_nn_input_vecs); | ||
// We don't benchmark double to keep compile times in check when not using the | ||
// distance library. | ||
|
||
} // namespace raft::bench::distance::masked_nn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include <raft/core/handle.hpp> | ||
#include <raft/util/cuda_utils.cuh> | ||
#include <raft/util/device_atomics.cuh> | ||
|
||
namespace raft::distance::detail { | ||
|
||
/** | ||
* @brief Compress 2D boolean matrix to bitfield | ||
* | ||
* Utility kernel for maskedL2NN. | ||
* | ||
* @tparam T | ||
* | ||
* @parameter[in] in An `m x n` boolean matrix. Row major. | ||
* @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of | ||
* type T, where T is of size `bits_per_elem` bits. | ||
* Note: the division (`/`) is a ceilDiv. | ||
*/ | ||
template <typename T = uint64_t, typename = std::enable_if_t<std::is_integral<T>::value>> | ||
__global__ void compress_to_bits_kernel( | ||
raft::device_matrix_view<const bool, int, raft::layout_c_contiguous> in, | ||
raft::device_matrix_view<T, int, raft::layout_c_contiguous> out) | ||
{ | ||
constexpr int bits_per_element = 8 * sizeof(T); | ||
constexpr int tile_dim_m = bits_per_element; | ||
constexpr int nthreads = 128; | ||
constexpr int tile_dim_n = nthreads; // read 128 bools at once = 1 sector | ||
|
||
// Tile in shared memory is transposed | ||
__shared__ bool smem[tile_dim_n][tile_dim_m]; | ||
|
||
const int num_tiles_per_m = raft::ceildiv(in.extent(0), tile_dim_m); | ||
const int num_tiles_per_n = raft::ceildiv(in.extent(1), tile_dim_n); | ||
|
||
for (int lin_tile_idx = blockIdx.x; true; lin_tile_idx += gridDim.x) { | ||
const int tile_idx_n = tile_dim_n * (lin_tile_idx % num_tiles_per_n); | ||
const int tile_idx_m = tile_dim_m * (lin_tile_idx / num_tiles_per_n); | ||
|
||
if (in.extent(0) <= tile_idx_m) { break; } | ||
// Fill shared memory tile | ||
bool reg_buf[tile_dim_m]; | ||
#pragma unroll | ||
for (int i = 0; i < tile_dim_m; ++i) { | ||
const int in_m = tile_idx_m + i; | ||
const int in_n = tile_idx_n + threadIdx.x; | ||
bool in_bounds = in_m < in.extent(0) && in_n < in.extent(1); | ||
reg_buf[i] = in_bounds ? in(in_m, in_n) : false; | ||
smem[threadIdx.x][i] = reg_buf[i]; | ||
} | ||
__syncthreads(); | ||
|
||
// Drain memory tile into single output element out_elem. | ||
T out_elem{0}; | ||
#pragma unroll | ||
for (int j = 0; j < tile_dim_n; ++j) { | ||
if (smem[threadIdx.x][j]) { out_elem |= T(1) << j; } | ||
} | ||
__syncthreads(); | ||
|
||
// Write output. | ||
int out_m = tile_idx_m / bits_per_element; | ||
int out_n = tile_idx_n + threadIdx.x; | ||
|
||
if (out_m < out.extent(0) && out_n < out.extent(1)) { out(out_m, out_n) = out_elem; } | ||
} | ||
} | ||
|
||
/** | ||
* @brief Compress 2D boolean matrix to bitfield | ||
* | ||
* Utility kernel for maskedL2NN. | ||
* | ||
* @tparam T | ||
* | ||
* @parameter[in] in An `m x n` boolean matrix. Row major. | ||
* @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of | ||
* type T, where T is of size `bits_per_elem` bits. | ||
* Note: the division (`/`) is a ceilDiv. | ||
*/ | ||
template <typename T = uint64_t, typename = std::enable_if_t<std::is_integral<T>::value>> | ||
void compress_to_bits(raft::device_resources const& handle, | ||
raft::device_matrix_view<const bool, int, raft::layout_c_contiguous> in, | ||
raft::device_matrix_view<T, int, raft::layout_c_contiguous> out) | ||
{ | ||
auto stream = handle.get_stream(); | ||
constexpr int bits_per_element = 8 * sizeof(T); | ||
|
||
RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0), | ||
"Number of output rows must be ceildiv(input rows, bits_per_elem)"); | ||
RAFT_EXPECTS(in.extent(1) == out.extent(1), "Number of output columns must equal input columns."); | ||
|
||
const int num_SMs = raft::getMultiProcessorCount(); | ||
int blocks_per_sm = 0; | ||
constexpr int num_threads = 128; | ||
constexpr int dyn_smem_size = 0; | ||
RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( | ||
&blocks_per_sm, compress_to_bits_kernel<T>, num_threads, dyn_smem_size)); | ||
|
||
dim3 grid(num_SMs * blocks_per_sm); | ||
dim3 block(128); | ||
compress_to_bits_kernel<<<grid, block, 0, stream>>>(in, out); | ||
RAFT_CUDA_TRY(cudaGetLastError()); | ||
} | ||
|
||
}; // namespace raft::distance::detail |
Oops, something went wrong.