Skip to content

Commit

Permalink
Replace faiss bfKnn (#1202)
Browse files Browse the repository at this point in the history
Replace faiss bfKnn with code that leverages our pairwise_distance api and select_k api - by tiling over the inputs.

This lets us remove faiss as a dependency

Closes #798 
Closes #1159

Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Ray Douglass (https://github.com/raydouglass)

URL: #1202
  • Loading branch information
benfred authored Mar 17, 2023
1 parent e661035 commit 1a3c028
Show file tree
Hide file tree
Showing 40 changed files with 1,226 additions and 775 deletions.
2 changes: 1 addition & 1 deletion ci/checks/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
re.compile(r"setup[.]cfg$"),
re.compile(r"meta[.]yaml$")
]
ExemptFiles = ["cpp/include/raft/spatial/knn/detail/faiss_select/"]
ExemptFiles = ["cpp/include/raft/neighbors/detail/faiss_select/"]

# this will break starting at year 10000, which is probably OK :)
CheckSimple = re.compile(
Expand Down
4 changes: 3 additions & 1 deletion cpp/include/raft/core/resource/cublas_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ inline cublasHandle_t get_cublas_handle(resources const& res)
cudaStream_t stream = get_cuda_stream(res);
res.add_resource_factory(std::make_shared<cublas_resource_factory>(stream));
}
return *res.get_resource<cublasHandle_t>(resource_type::CUBLAS_HANDLE);
auto ret = *res.get_resource<cublasHandle_t>(resource_type::CUBLAS_HANDLE);
RAFT_CUBLAS_TRY(cublasSetStream(ret, get_cuda_stream(res)));
return ret;
};

/**
Expand Down
22 changes: 21 additions & 1 deletion cpp/include/raft/distance/distance_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -66,6 +66,26 @@ enum DistanceType : unsigned short {
Precomputed = 100
};

/**
* Whether minimal distance corresponds to similar elements (using the given metric).
*/
inline bool is_min_close(DistanceType metric)
{
bool select_min;
switch (metric) {
case DistanceType::InnerProduct:
case DistanceType::CosineExpanded:
case DistanceType::CorrelationExpanded:
// Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger
// similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362
// {perform_k_selection})
select_min = false;
break;
default: select_min = true;
}
return select_min;
}

namespace kernels {
enum KernelType { LINEAR, POLYNOMIAL, RBF, TANH };

Expand Down
3 changes: 3 additions & 0 deletions cpp/include/raft/linalg/detail/transpose.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ void transpose(raft::device_resources const& handle,
cudaStream_t stream)
{
cublasHandle_t cublas_h = handle.get_cublas_handle();
RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream));

int out_n_rows = n_cols;
int out_n_cols = n_rows;
Expand Down Expand Up @@ -90,6 +91,7 @@ void transpose_row_major_impl(
auto out_n_cols = in.extent(0);
T constexpr kOne = 1;
T constexpr kZero = 0;

CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
CUBLAS_OP_T,
CUBLAS_OP_N,
Expand All @@ -116,6 +118,7 @@ void transpose_col_major_impl(
auto out_n_cols = in.extent(0);
T constexpr kOne = 1;
T constexpr kZero = 0;

CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
CUBLAS_OP_T,
CUBLAS_OP_N,
Expand Down
51 changes: 25 additions & 26 deletions cpp/include/raft/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@

#include <raft/core/device_mdspan.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/detail/knn_brute_force.cuh>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/spatial/knn/detail/knn_brute_force_faiss.cuh>
#include <raft/spatial/knn/detail/selection_faiss.cuh>

namespace raft::neighbors::brute_force {

Expand Down Expand Up @@ -96,15 +95,15 @@ inline void knn_merge_parts(
"Number of columns in output indices and distances matrices must be equal to k");

auto n_parts = in_keys.extent(0) / n_samples;
spatial::knn::detail::knn_merge_parts(in_keys.data_handle(),
in_values.data_handle(),
out_keys.data_handle(),
out_values.data_handle(),
n_samples,
n_parts,
in_keys.extent(1),
handle.get_stream(),
translations.value_or(nullptr));
detail::knn_merge_parts(in_keys.data_handle(),
in_values.data_handle(),
out_keys.data_handle(),
out_values.data_handle(),
n_samples,
n_parts,
in_keys.extent(1),
handle.get_stream(),
translations.value_or(nullptr));
}

/**
Expand Down Expand Up @@ -181,21 +180,21 @@ void knn(raft::device_resources const& handle,

std::vector<idx_t>* trans_arg = global_id_offset.has_value() ? &trans : nullptr;

raft::spatial::knn::detail::brute_force_knn_impl(handle,
inputs,
sizes,
static_cast<value_int>(index[0].extent(1)),
// TODO: This is unfortunate. Need to fix.
const_cast<value_t*>(search.data_handle()),
static_cast<value_int>(search.extent(0)),
indices.data_handle(),
distances.data_handle(),
k,
rowMajorIndex,
rowMajorQuery,
trans_arg,
metric,
metric_arg.value_or(2.0f));
raft::neighbors::detail::brute_force_knn_impl(handle,
inputs,
sizes,
static_cast<value_int>(index[0].extent(1)),
// TODO: This is unfortunate. Need to fix.
const_cast<value_t*>(search.data_handle()),
static_cast<value_int>(search.extent(0)),
indices.data_handle(),
distances.data_handle(),
k,
rowMajorIndex,
rowMajorQuery,
trans_arg,
metric,
metric_arg.value_or(2.0f));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <cuda.h>
#include <cuda_fp16.h>

namespace raft::spatial::knn::detail::faiss_select {
namespace raft::neighbors::detail::faiss_select {

template <typename T>
struct Comparator {
Expand All @@ -26,4 +26,4 @@ struct Comparator<half> {
__device__ static inline bool gt(half a, half b) { return __hgt(a, b); }
};

} // namespace raft::spatial::knn::detail::faiss_select
} // namespace raft::neighbors::detail::faiss_select
52 changes: 52 additions & 0 deletions cpp/include/raft/neighbors/detail/faiss_select/DistanceUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file thirdparty/LICENSES/LICENSE.faiss
*/

#pragma once

namespace raft::neighbors::detail::faiss_select {
// If the inner size (dim) of the vectors is small, we want a larger query tile
// size, like 1024
inline void chooseTileSize(size_t numQueries,
size_t numCentroids,
size_t dim,
size_t elementSize,
size_t totalMem,
size_t& tileRows,
size_t& tileCols)
{
// The matrix multiplication should be large enough to be efficient, but if
// it is too large, we seem to lose efficiency as opposed to
// double-streaming. Each tile size here defines 1/2 of the memory use due
// to double streaming. We ignore available temporary memory, as that is
// adjusted independently by the user and can thus meet these requirements
// (or not). For <= 4 GB GPUs, prefer 512 MB of usage. For <= 8 GB GPUs,
// prefer 768 MB of usage. Otherwise, prefer 1 GB of usage.
size_t targetUsage = 0;

if (totalMem <= ((size_t)4) * 1024 * 1024 * 1024) {
targetUsage = 512 * 1024 * 1024;
} else if (totalMem <= ((size_t)8) * 1024 * 1024 * 1024) {
targetUsage = 768 * 1024 * 1024;
} else {
targetUsage = 1024 * 1024 * 1024;
}

targetUsage /= 2 * elementSize;

// 512 seems to be a batch size sweetspot for float32.
// If we are on float16, increase to 512.
// If the k size (vec dim) of the matrix multiplication is small (<= 32),
// increase to 1024.
size_t preferredTileRows = 512;
if (dim <= 32) { preferredTileRows = 1024; }

tileRows = std::min(preferredTileRows, numQueries);

// tileCols is the remainder size
tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
}
} // namespace raft::neighbors::detail::faiss_select
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#pragma once

#include <cuda.h>
#include <raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh>
#include <raft/spatial/knn/detail/faiss_select/StaticUtils.h>
#include <raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh>
#include <raft/neighbors/detail/faiss_select/StaticUtils.h>

namespace raft::spatial::knn::detail::faiss_select {
namespace raft::neighbors::detail::faiss_select {

// Merge pairs of lists smaller than blockDim.x (NumThreads)
template <int NumThreads,
Expand Down Expand Up @@ -274,4 +274,4 @@ inline __device__ void blockMerge(K* listK, V* listV)
BlockMerge<NumThreads, K, V, N, L, Dir, Comp, kSmallerThanBlock, FullMerge>::merge(listK, listV);
}

} // namespace raft::spatial::knn::detail::faiss_select
} // namespace raft::neighbors::detail::faiss_select
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#pragma once

namespace raft::spatial::knn::detail::faiss_select {
namespace raft::neighbors::detail::faiss_select {

template <typename T>
inline __device__ void swap(bool swap, T& x, T& y)
Expand All @@ -22,4 +22,4 @@ inline __device__ void assign(bool assign, T& x, T y)
{
x = assign ? y : x;
}
} // namespace raft::spatial::knn::detail::faiss_select
} // namespace raft::neighbors::detail::faiss_select
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

#pragma once

#include <raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh>
#include <raft/spatial/knn/detail/faiss_select/StaticUtils.h>
#include <raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh>
#include <raft/neighbors/detail/faiss_select/StaticUtils.h>

#include <raft/util/cuda_utils.cuh>

namespace raft::spatial::knn::detail::faiss_select {
namespace raft::neighbors::detail::faiss_select {

//
// This file contains functions to:
Expand Down Expand Up @@ -518,4 +518,4 @@ inline __device__ void warpSortAnyRegisters(K k[N], V v[N])
BitonicSortStep<K, V, N, Dir, Comp>::sort(k, v);
}

} // namespace raft::spatial::knn::detail::faiss_select
} // namespace raft::neighbors::detail::faiss_select
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

#pragma once

#include <raft/spatial/knn/detail/faiss_select/Comparators.cuh>
#include <raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh>
#include <raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh>
#include <raft/neighbors/detail/faiss_select/Comparators.cuh>
#include <raft/neighbors/detail/faiss_select/MergeNetworkBlock.cuh>
#include <raft/neighbors/detail/faiss_select/MergeNetworkWarp.cuh>

#include <raft/core/kvp.hpp>
#include <raft/util/cuda_utils.cuh>

namespace raft::spatial::knn::detail::faiss_select {
namespace raft::neighbors::detail::faiss_select {

// Specialization for block-wide monotonic merges producing a merge sort
// since what we really want is a constexpr loop expansion
Expand Down Expand Up @@ -552,4 +552,4 @@ struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
V threadV;
};

} // namespace raft::spatial::knn::detail::faiss_select
} // namespace raft::neighbors::detail::faiss_select
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#define __device__
#endif

namespace raft::spatial::knn::detail::faiss_select::utils {
namespace raft::neighbors::detail::faiss_select::utils {

template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v)
Expand Down Expand Up @@ -45,4 +45,4 @@ static_assert(nextHighestPowerOf2(1536000000u) == 2147483648u, "nextHighestPower
static_assert(nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL,
"nextHighestPowerOf2");

} // namespace raft::spatial::knn::detail::faiss_select::utils
} // namespace raft::neighbors::detail::faiss_select::utils
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

#pragma once

#include <raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh>
#include <raft/spatial/knn/detail/faiss_select/Select.cuh>
#include <raft/neighbors/detail/faiss_select/MergeNetworkUtils.cuh>
#include <raft/neighbors/detail/faiss_select/Select.cuh>

// TODO: Need to think further about the impact (and new boundaries created) on the registers
// because this will change the max k that can be processed. One solution might be to break
// up k into multiple batches for larger k.

namespace raft::spatial::knn::detail::faiss_select {
namespace raft::neighbors::detail::faiss_select {

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
Expand Down Expand Up @@ -221,4 +221,4 @@ struct KeyValueBlockSelect {
int kMinus1;
};

} // namespace raft::spatial::knn::detail::faiss_select
} // namespace raft::neighbors::detail::faiss_select
22 changes: 1 addition & 21 deletions cpp/include/raft/neighbors/detail/ivf_flat_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1244,26 +1244,6 @@ void search_impl(raft::device_resources const& handle,
}
}

/**
* Whether minimal distance corresponds to similar elements (using the given metric).
*/
inline bool is_min_close(distance::DistanceType metric)
{
bool select_min;
switch (metric) {
case raft::distance::DistanceType::InnerProduct:
case raft::distance::DistanceType::CosineExpanded:
case raft::distance::DistanceType::CorrelationExpanded:
// Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger
// similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362
// {perform_k_selection})
select_min = false;
break;
default: select_min = true;
}
return select_min;
}

/** See raft::neighbors::ivf_flat::search docs */
template <typename T, typename IdxT>
inline void search(raft::device_resources const& handle,
Expand Down Expand Up @@ -1295,7 +1275,7 @@ inline void search(raft::device_resources const& handle,
n_queries,
k,
n_probes,
is_min_close(index.metric()),
raft::distance::is_min_close(index.metric()),
neighbors,
distances,
mr);
Expand Down
Loading

0 comments on commit 1a3c028

Please sign in to comment.