Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MaskedL2NN #838

Merged
merged 57 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
a3b8587
contractions: Concentrate tile index calculations
ahendriksen Sep 2, 2022
99e65a5
pairwise_distance_base: Remove all ldgXY(0) calls
ahendriksen Sep 2, 2022
e6d5078
pairwise_distance_base: Move all logic into run loop
ahendriksen Sep 2, 2022
995d2ae
pairwise_distance_base: Fix typo
ahendriksen Oct 5, 2022
e6976c5
Implement reviewer feedback
ahendriksen Jan 23, 2023
4947dc8
Merge branch 'branch-23.02' into wip-move-contractions-tiling-logic
cjnolet Jan 25, 2023
8385f2f
Add sparseL2NN initial implementation
ahendriksen Sep 22, 2022
69687d6
Add sparseL2NN benchmarks
ahendriksen Sep 22, 2022
e74b5f4
Rename files sparse_* => masked_*
ahendriksen Jan 23, 2023
7bf5801
Rename functions sparse_* => masked_*
ahendriksen Jan 23, 2023
96ff79d
Remove workspace parameter
ahendriksen Jan 23, 2023
be56826
Use mdspan parameters
ahendriksen Jan 23, 2023
07c2f53
Remove uvectors from benchmark
ahendriksen Jan 23, 2023
2f23481
Use parameter struct in public API
ahendriksen Jan 23, 2023
7df1fd2
Clean up minor nits in tests and benchmarks
ahendriksen Jan 23, 2023
b590684
Update copyright years
ahendriksen Jan 23, 2023
9fcc24d
Rename masked_l2_nn => masked_nn
ahendriksen Jan 23, 2023
953b33d
cmake: rename masked_l2_nn => masked_nn
ahendriksen Jan 23, 2023
724e874
Clang-format
ahendriksen Jan 23, 2023
4ea41c4
Fix docs for masked NN
ahendriksen Jan 23, 2023
529764e
Reduce iterations of deterministic test
ahendriksen Jan 24, 2023
f4db5e5
Add SDDM comparison
ahendriksen Jan 24, 2023
b0954e8
Fix const and add extents checks
ahendriksen Jan 24, 2023
2e3ca44
Refactor maskedL2NN test
ahendriksen Jan 24, 2023
db6288c
wording: grouped -> processed
ahendriksen Jan 24, 2023
773ab2f
Docstring changes and removal of extraneous function
ahendriksen Jan 24, 2023
bae2bcb
Move sqrt from template to run-time parameter
ahendriksen Jan 24, 2023
0629446
Remove extraneous memcpy::async stuff
ahendriksen Jan 24, 2023
33787ce
Implement half the reviewer feedback on mnn_base
ahendriksen Jan 24, 2023
6243c8a
Formatting
ahendriksen Jan 24, 2023
4dc0c37
Fix copyright years
ahendriksen Jan 24, 2023
b84a5b4
Reword comment
ahendriksen Jan 25, 2023
7fb7cc9
test: Remove redundant comparison operator
ahendriksen Jan 25, 2023
fe31a13
Add sphinx docsfor maskedL2NNImpl
ahendriksen Jan 25, 2023
3bf204f
Document thread_adj
ahendriksen Jan 25, 2023
08448d8
Rename tile_end_n -> group_end_n
ahendriksen Jan 25, 2023
ed178ca
Formatting
ahendriksen Jan 25, 2023
80cb3b0
Add compress_to_bits kernel wrapper
ahendriksen Jan 25, 2023
546854d
Add compress_to_bits docs and test
ahendriksen Jan 25, 2023
125ae98
masked_nn: Fix test param printer
ahendriksen Jan 25, 2023
6e77653
Use mdspan for compress_to_bits
ahendriksen Jan 25, 2023
8925146
Make benchmark more informative
ahendriksen Jan 25, 2023
3e86024
Merge branch 'branch-23.02' into enh-sparse-l2-nn
cjnolet Jan 25, 2023
ba6491a
Merge branch 'branch-23.02' into wip-move-contractions-tiling-logic
cjnolet Jan 25, 2023
e52b0f9
Forcing sccache reinit.
cjnolet Jan 26, 2023
34eb76a
Merge branch 'branch-23.02' into wip-move-contractions-tiling-logic
cjnolet Jan 26, 2023
85c6294
Breaking specializations for refine into individual files
cjnolet Jan 26, 2023
0fad842
Checking in
cjnolet Jan 26, 2023
f7788af
Including just the refine specialization
cjnolet Jan 26, 2023
e626101
Merge branch 'branch-23.02' into wip-move-contractions-tiling-logic
cjnolet Jan 26, 2023
9e7b729
Proper import of speicalizations
cjnolet Jan 26, 2023
9e4b5f3
Merge branch 'wip-move-contractions-tiling-logic' of github.com:ahend…
cjnolet Jan 26, 2023
060e62c
Remove SCCACHE_RECACHE from build.sh
cjnolet Jan 26, 2023
2870b67
Merge branch 'wip-move-contractions-tiling-logic' into enh-sparse-l2-nn
cjnolet Jan 26, 2023
2b0c02b
Small compilation error remains
cjnolet Jan 26, 2023
1e83640
Take device_resources instead of handle
ahendriksen Jan 27, 2023
862e8f6
Merge remote-tracking branch 'rapids/branch-23.02' into enh-sparse-l2-nn
ahendriksen Jan 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ if(BUILD_BENCH)
bench/distance/distance_l1.cu
bench/distance/distance_unexp_l2.cu
bench/distance/fused_l2_nn.cu
bench/distance/masked_nn.cu
bench/distance/kernels.cu
bench/main.cpp
OPTIONAL
Expand Down
267 changes: 267 additions & 0 deletions cpp/bench/distance/masked_nn.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cstdint>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

#include <common/benchmark.hpp>
#include <limits>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/distance/masked_nn.cuh>
#include <raft/handle.hpp>
#include <raft/linalg/norm.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cudart_utils.hpp>

#if defined RAFT_NN_COMPILED
#include <raft/spatial/knn/specializations.hpp>
#endif

namespace raft::bench::distance::masked_nn {

// Introduce various sparsity patterns
enum AdjacencyPattern {
checkerboard = 0,
checkerboard_4 = 1,
checkerboard_64 = 2,
all_true = 3,
all_false = 4
};

struct Params {
int m, n, k, num_groups;
AdjacencyPattern pattern;
}; // struct Params

__global__ void init_adj(AdjacencyPattern pattern,
int n,
raft::device_matrix_view<bool, int, raft::layout_c_contiguous> adj,
raft::device_vector_view<int, int, raft::layout_c_contiguous> group_idxs)
{
int m = adj.extent(0);
int num_groups = adj.extent(1);

for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m;
idx_m += blockDim.y * gridDim.y) {
for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups;
idx_g += blockDim.x * gridDim.x) {
switch (pattern) {
case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break;
case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break;
case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break;
case all_true: adj(idx_m, idx_g) = true; break;
case all_false: adj(idx_m, idx_g) = false; break;
default: assert(false && "unknown pattern");
}
}
}
// Each group is of size n / num_groups.
//
// - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive
// scan of the group lengths)
//
// - The first group always starts at index zero, so we do not store it.
//
// - The group_idxs[num_groups - 1] should always equal n.

if (blockIdx.y == 0 && threadIdx.y == 0) {
const int g_stride = blockDim.x * gridDim.x;
for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) {
group_idxs(idx_g) = (idx_g + 1) * (n / num_groups);
}
group_idxs(num_groups - 1) = n;
}
}

template <typename T>
struct masked_l2_nn : public fixture {
using DataT = T;
using IdxT = int;
using OutT = raft::KeyValuePair<IdxT, DataT>;
using RedOpT = raft::distance::MinAndDistanceReduceOp<int, DataT>;
using PairRedOpT = raft::distance::KVPMinReduce<int, DataT>;
using ParamT = raft::distance::MaskedL2NNParams<RedOpT, PairRedOpT>;

// Parameters
Params params;
// Data
raft::device_vector<OutT, IdxT> out;
raft::device_matrix<T, IdxT> x, y;
raft::device_vector<DataT, IdxT> xn, yn;
raft::device_matrix<bool, IdxT> adj;
raft::device_vector<IdxT, IdxT> group_idxs;

masked_l2_nn(const Params& p)
: params(p),
out{raft::make_device_vector<OutT, IdxT>(handle, p.m)},
x{raft::make_device_matrix<DataT, IdxT>(handle, p.m, p.k)},
y{raft::make_device_matrix<DataT, IdxT>(handle, p.n, p.k)},
xn{raft::make_device_vector<DataT, IdxT>(handle, p.m)},
yn{raft::make_device_vector<DataT, IdxT>(handle, p.n)},
adj{raft::make_device_matrix<bool, IdxT>(handle, p.m, p.num_groups)},
group_idxs{raft::make_device_vector<IdxT, IdxT>(handle, p.num_groups)}
{
raft::random::RngState r(123456ULL);

uniform(handle, r, x.data_handle(), p.m * p.k, T(-1.0), T(1.0));
uniform(handle, r, y.data_handle(), p.n * p.k, T(-1.0), T(1.0));
raft::linalg::rowNorm(
xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream);
raft::linalg::rowNorm(
yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream);
raft::distance::initialize<T, raft::KeyValuePair<int, T>, int>(
handle, out.data_handle(), p.m, std::numeric_limits<T>::max(), RedOpT{});

dim3 block(32, 32);
dim3 grid(10, 10);
init_adj<<<grid, block, 0, stream>>>(p.pattern, p.n, adj.view(), group_idxs.view());
RAFT_CUDA_TRY(cudaGetLastError());
}

void run_benchmark(::benchmark::State& state) override
{
bool init_out = true;
bool sqrt = false;
ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out};

loop_on_state(state, [this, masked_l2_params]() {
// It is sufficient to only benchmark the L2-squared metric
raft::distance::maskedL2NN<DataT, OutT, IdxT>(handle,
masked_l2_params,
x.view(),
y.view(),
xn.view(),
yn.view(),
adj.view(),
group_idxs.view(),
out.view());
});

// Virtual flop count if no skipping had occurred.
size_t virtual_flops = size_t(2) * size_t(params.m) * size_t(params.n) * size_t(params.k);

int64_t read_elts = params.n * params.k + params.m * params.k;
int64_t write_elts = params.m;

// Virtual min flops is the number of flops that would have been executed if
// the algorithm had actually skipped each computation that it could have
// skipped.
size_t virtual_min_flops = 0;
switch (params.pattern) {
case checkerboard:
case checkerboard_4:
case checkerboard_64: virtual_min_flops = virtual_flops / 2; break;
case all_true: virtual_min_flops = virtual_flops; break;
case all_false: virtual_min_flops = 0; break;
default: assert(false && "unknown pattern");
}

// VFLOP/s is the "virtual" flop count that would have executed if there was
// no adjacency pattern. This is useful for comparing to fusedL2NN
state.counters["VFLOP/s"] = benchmark::Counter(virtual_flops,
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);
// Virtual min flops is the number of flops that would have been executed if
// the algorithm had actually skipped each computation that it could have
// skipped.
state.counters["VminFLOP/s"] = benchmark::Counter(virtual_min_flops,
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);

state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(OutT),
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);
state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(DataT),
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);

state.counters["m"] = benchmark::Counter(params.m);
state.counters["n"] = benchmark::Counter(params.n);
state.counters["k"] = benchmark::Counter(params.k);
state.counters["num_groups"] = benchmark::Counter(params.num_groups);
state.counters["group size"] = benchmark::Counter(params.n / params.num_groups);
state.counters["Pat"] = benchmark::Counter(static_cast<int>(params.pattern));

state.counters["SM count"] = raft::getMultiProcessorCount();
}
}; // struct MaskedL2NN

const std::vector<Params> masked_l2_nn_input_vecs = {
// Very fat matrices...
{32, 16384, 16384, 32, AdjacencyPattern::checkerboard},
{64, 16384, 16384, 32, AdjacencyPattern::checkerboard},
{128, 16384, 16384, 32, AdjacencyPattern::checkerboard},
{256, 16384, 16384, 32, AdjacencyPattern::checkerboard},
{512, 16384, 16384, 32, AdjacencyPattern::checkerboard},
{1024, 16384, 16384, 32, AdjacencyPattern::checkerboard},
{16384, 32, 16384, 32, AdjacencyPattern::checkerboard},
{16384, 64, 16384, 32, AdjacencyPattern::checkerboard},
{16384, 128, 16384, 32, AdjacencyPattern::checkerboard},
{16384, 256, 16384, 32, AdjacencyPattern::checkerboard},
{16384, 512, 16384, 32, AdjacencyPattern::checkerboard},
{16384, 1024, 16384, 32, AdjacencyPattern::checkerboard},

// Representative matrices...
{16384, 16384, 32, 32, AdjacencyPattern::checkerboard},
{16384, 16384, 64, 32, AdjacencyPattern::checkerboard},
{16384, 16384, 128, 32, AdjacencyPattern::checkerboard},
{16384, 16384, 256, 32, AdjacencyPattern::checkerboard},
{16384, 16384, 512, 32, AdjacencyPattern::checkerboard},
{16384, 16384, 1024, 32, AdjacencyPattern::checkerboard},
{16384, 16384, 16384, 32, AdjacencyPattern::checkerboard},

{16384, 16384, 32, 32, AdjacencyPattern::checkerboard_4},
{16384, 16384, 64, 32, AdjacencyPattern::checkerboard_4},
{16384, 16384, 128, 32, AdjacencyPattern::checkerboard_4},
{16384, 16384, 256, 32, AdjacencyPattern::checkerboard_4},
{16384, 16384, 512, 32, AdjacencyPattern::checkerboard_4},
{16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_4},
{16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_4},

{16384, 16384, 32, 32, AdjacencyPattern::checkerboard_64},
{16384, 16384, 64, 32, AdjacencyPattern::checkerboard_64},
{16384, 16384, 128, 32, AdjacencyPattern::checkerboard_64},
{16384, 16384, 256, 32, AdjacencyPattern::checkerboard_64},
{16384, 16384, 512, 32, AdjacencyPattern::checkerboard_64},
{16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_64},
{16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_64},

{16384, 16384, 32, 32, AdjacencyPattern::all_true},
{16384, 16384, 64, 32, AdjacencyPattern::all_true},
{16384, 16384, 128, 32, AdjacencyPattern::all_true},
{16384, 16384, 256, 32, AdjacencyPattern::all_true},
{16384, 16384, 512, 32, AdjacencyPattern::all_true},
{16384, 16384, 1024, 32, AdjacencyPattern::all_true},
{16384, 16384, 16384, 32, AdjacencyPattern::all_true},

{16384, 16384, 32, 32, AdjacencyPattern::all_false},
{16384, 16384, 64, 32, AdjacencyPattern::all_false},
{16384, 16384, 128, 32, AdjacencyPattern::all_false},
{16384, 16384, 256, 32, AdjacencyPattern::all_false},
{16384, 16384, 512, 32, AdjacencyPattern::all_false},
{16384, 16384, 1024, 32, AdjacencyPattern::all_false},
{16384, 16384, 16384, 32, AdjacencyPattern::all_false},
};

RAFT_BENCH_REGISTER(masked_l2_nn<float>, "", masked_l2_nn_input_vecs);
// We don't benchmark double to keep compile times in check when not using the
// distance library.

} // namespace raft::bench::distance::masked_nn
122 changes: 122 additions & 0 deletions cpp/include/raft/distance/detail/compress_to_bits.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/handle.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/device_atomics.cuh>

namespace raft::distance::detail {

/**
* @brief Compress 2D boolean matrix to bitfield
*
* Utility kernel for maskedL2NN.
*
* @tparam T
*
* @parameter[in] in An `m x n` boolean matrix. Row major.
* @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of
* type T, where T is of size `bits_per_elem` bits.
* Note: the division (`/`) is a ceilDiv.
*/
template <typename T = uint64_t, typename = std::enable_if_t<std::is_integral<T>::value>>
__global__ void compress_to_bits_kernel(
raft::device_matrix_view<const bool, int, raft::layout_c_contiguous> in,
raft::device_matrix_view<T, int, raft::layout_c_contiguous> out)
{
constexpr int bits_per_element = 8 * sizeof(T);
constexpr int tile_dim_m = bits_per_element;
constexpr int nthreads = 128;
constexpr int tile_dim_n = nthreads; // read 128 bools at once = 1 sector

// Tile in shared memory is transposed
__shared__ bool smem[tile_dim_n][tile_dim_m];

const int num_tiles_per_m = raft::ceildiv(in.extent(0), tile_dim_m);
const int num_tiles_per_n = raft::ceildiv(in.extent(1), tile_dim_n);

for (int lin_tile_idx = blockIdx.x; true; lin_tile_idx += gridDim.x) {
const int tile_idx_n = tile_dim_n * (lin_tile_idx % num_tiles_per_n);
const int tile_idx_m = tile_dim_m * (lin_tile_idx / num_tiles_per_n);

if (in.extent(0) <= tile_idx_m) { break; }
// Fill shared memory tile
bool reg_buf[tile_dim_m];
#pragma unroll
for (int i = 0; i < tile_dim_m; ++i) {
const int in_m = tile_idx_m + i;
const int in_n = tile_idx_n + threadIdx.x;
bool in_bounds = in_m < in.extent(0) && in_n < in.extent(1);
reg_buf[i] = in_bounds ? in(in_m, in_n) : false;
smem[threadIdx.x][i] = reg_buf[i];
}
__syncthreads();

// Drain memory tile into single output element out_elem.
T out_elem{0};
#pragma unroll
for (int j = 0; j < tile_dim_n; ++j) {
if (smem[threadIdx.x][j]) { out_elem |= T(1) << j; }
}
__syncthreads();

// Write output.
int out_m = tile_idx_m / bits_per_element;
int out_n = tile_idx_n + threadIdx.x;

if (out_m < out.extent(0) && out_n < out.extent(1)) { out(out_m, out_n) = out_elem; }
}
}

/**
* @brief Compress 2D boolean matrix to bitfield
*
* Utility kernel for maskedL2NN.
*
* @tparam T
*
* @parameter[in] in An `m x n` boolean matrix. Row major.
* @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of
* type T, where T is of size `bits_per_elem` bits.
* Note: the division (`/`) is a ceilDiv.
*/
template <typename T = uint64_t, typename = std::enable_if_t<std::is_integral<T>::value>>
void compress_to_bits(const raft::handle_t& handle,
raft::device_matrix_view<const bool, int, raft::layout_c_contiguous> in,
raft::device_matrix_view<T, int, raft::layout_c_contiguous> out)
{
auto stream = handle.get_stream();
constexpr int bits_per_element = 8 * sizeof(T);

RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0),
"Number of output rows must be ceildiv(input rows, bits_per_elem)");
RAFT_EXPECTS(in.extent(1) == out.extent(1), "Number of output columns must equal input columns.");

const int num_SMs = raft::getMultiProcessorCount();
int blocks_per_sm = 0;
constexpr int num_threads = 128;
constexpr int dyn_smem_size = 0;
RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&blocks_per_sm, compress_to_bits_kernel<T>, num_threads, dyn_smem_size));

dim3 grid(num_SMs * blocks_per_sm);
dim3 block(128);
compress_to_bits_kernel<<<grid, block, 0, stream>>>(in, out);
RAFT_CUDA_TRY(cudaGetLastError());
}

}; // namespace raft::distance::detail
Loading