Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Update KNN #171

Merged
merged 4 commits into from
Mar 17, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 189 additions & 58 deletions cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,27 +166,140 @@ inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK,
inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
}

inline faiss::MetricType build_faiss_metric(distance::DistanceType metric) {
inline faiss::MetricType build_faiss_metric(
raft::distance::DistanceType metric) {
switch (metric) {
case distance::DistanceType::L2Unexpanded:
case raft::distance::DistanceType::CosineExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::CorrelationExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::L2Expanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2Unexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2SqrtExpanded:
return faiss::MetricType::METRIC_L2;
case distance::DistanceType::L1:
case raft::distance::DistanceType::L2SqrtUnexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L1:
return faiss::MetricType::METRIC_L1;
case distance::DistanceType::Linf:
return faiss::MetricType::METRIC_Linf;
case distance::DistanceType::LpUnexpanded:
case raft::distance::DistanceType::InnerProduct:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::LpUnexpanded:
return faiss::MetricType::METRIC_Lp;
case distance::DistanceType::Canberra:
case raft::distance::DistanceType::Linf:
return faiss::MetricType::METRIC_Linf;
case raft::distance::DistanceType::Canberra:
return faiss::MetricType::METRIC_Canberra;
case distance::DistanceType::BrayCurtis:
case raft::distance::DistanceType::BrayCurtis:
return faiss::MetricType::METRIC_BrayCurtis;
case distance::DistanceType::JensenShannon:
case raft::distance::DistanceType::JensenShannon:
return faiss::MetricType::METRIC_JensenShannon;
default:
return faiss::MetricType::METRIC_INNER_PRODUCT;
THROW("MetricType not supported: %d", metric);
}
}

/**
 * Compute the Haversine (great circle arc) distance between two
 * 2-dimensional points x = (x1, x2) and y = (y1, y2).
 *
 * The first coordinate of each point is the one passed to cos(), i.e. it is
 * treated as the latitude; the second coordinate is treated as the longitude.
 * No sphere radius is applied, so the result is the central angle between
 * the two points.
 *
 * @tparam value_t floating-point type of the coordinates and of the result
 * @param[in] x1 first coordinate (latitude) of point x
 * @param[in] y1 first coordinate (latitude) of point y
 * @param[in] x2 second coordinate (longitude) of point x
 * @param[in] y2 second coordinate (longitude) of point y
 * @return 2 * asin(sqrt(h)) where h is the haversine of the central angle
 */
template <typename value_t>
DI value_t compute_haversine(value_t x1, value_t y1, value_t x2, value_t y2) {
  // Keep the 0.5 factor in value_t: a bare `0.5` literal is a double and
  // would promote the sin() argument (and therefore the sin() evaluation)
  // to double precision even when value_t is float.
  const value_t half = static_cast<value_t>(0.5);

  value_t sin_0 = sin(half * (x1 - y1));
  value_t sin_1 = sin(half * (x2 - y2));
  value_t rdist = sin_0 * sin_0 + cos(x1) * cos(y1) * sin_1 * sin_1;

  return 2 * asin(sqrt(rdist));
}

/**
 * Brute-force k-nearest-neighbors kernel for the Haversine distance.
 *
 * Launch layout: one thread block per query row (blockIdx.x selects both the
 * query point and the output row) and `tpb` threads per block; the loop
 * stride and the shared-memory sizing both use `tpb`, so the launch block
 * size has to match it. Every block scans all index rows, pushes the
 * distances into a faiss::gpu::BlockSelect heap backed by statically
 * allocated shared memory (kNumWarps * warp_q keys and as many values), and
 * then writes the first k entries of that shared memory to the output row.
 *
 * Both index and query are read as 2 consecutive values per row.
 *
 * @tparam value_idx data type of indices
 * @tparam value_t data type of values and distances
 * @tparam warp_q forwarded to faiss::gpu::BlockSelect and used to size the
 *                shared-memory queues (presumably an upper bound on k --
 *                confirm against the faiss BlockSelect documentation)
 * @tparam thread_q forwarded to faiss::gpu::BlockSelect (presumably the
 *                  per-thread queue length -- confirm against faiss)
 * @tparam tpb number of threads per block the kernel is launched with
 * @param[out] out_inds output indices (k entries per query row)
 * @param[out] out_dists output distances (k entries per query row)
 * @param[in] index index array (2 values per row)
 * @param[in] query query array (2 values per row)
 * @param[in] n_index_rows number of rows in index array
 * @param[in] k number of closest neighbors to return
 */
template <typename value_idx, typename value_t, int warp_q = 1024,
          int thread_q = 8, int tpb = 128>
__global__ void haversine_knn_kernel(value_idx *out_inds, value_t *out_dists,
                                     const value_t *index, const value_t *query,
                                     size_t n_index_rows, int k) {
  constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize;

  __shared__ value_t smemK[kNumWarps * warp_q];
  __shared__ value_idx smemV[kNumWarps * warp_q];

  faiss::gpu::BlockSelect<value_t, value_idx, false,
                          faiss::gpu::Comparator<value_t>, warp_q, thread_q,
                          tpb>
    heap(faiss::gpu::Limits<value_t>::getMax(), -1, smemK, smemV, k);

  // Largest multiple of the warp size that does not exceed n_index_rows.
  // Kept in size_t so that a large row count is not narrowed to int.
  const size_t limit =
    faiss::gpu::utils::roundDown(n_index_rows, faiss::gpu::kWarpSize);

  // The grid is sized to the number of query rows: each block owns one query.
  // Offsets are computed in 64 bits to avoid 32-bit wrap-around.
  const value_t *query_ptr = query + (static_cast<size_t>(blockIdx.x) * 2);
  value_t x1 = query_ptr[0];
  value_t x2 = query_ptr[1];

  size_t i = threadIdx.x;

  // Main part: rows below `limit`, visited with a stride of tpb threads.
  for (; i < limit; i += tpb) {
    const value_t *idx_ptr = index + (i * 2);
    value_t y1 = idx_ptr[0];
    value_t y2 = idx_ptr[1];

    value_t dist = compute_haversine(x1, y1, x2, y2);

    heap.add(dist, static_cast<value_idx>(i));
  }

  // Handle last remainder fraction of a warp of elements. Only some lanes
  // get here, so the element goes to the thread queue only (presumably
  // because add() needs the whole warp to participate -- confirm against
  // faiss BlockSelect).
  if (i < n_index_rows) {
    const value_t *idx_ptr = index + (i * 2);
    value_t y1 = idx_ptr[0];
    value_t y2 = idx_ptr[1];

    value_t dist = compute_haversine(x1, y1, x2, y2);

    heap.addThreadQ(dist, static_cast<value_idx>(i));
  }

  heap.reduce();

  // Copy the first k shared-memory entries into this query's output row.
  const size_t out_offset = static_cast<size_t>(blockIdx.x) * k;
  for (int j = threadIdx.x; j < k; j += tpb) {
    out_dists[out_offset + j] = smemK[j];
    out_inds[out_offset + j] = smemV[j];
  }
}

/**
 * Compute the k-nearest neighbors using the Haversine
 * (great circle arc) distance. Input is assumed to have
 * 2 dimensions (latitude, longitude) in radians.
 *
 * Launches haversine_knn_kernel on `stream` with one block of 128 threads
 * per query row. The function neither synchronizes the stream nor checks
 * for launch errors; that is left to the caller. When there are no query
 * rows nothing is launched.
 *
 * @tparam value_idx data type of indices
 * @tparam value_t data type of values and distances
 * @param[out] out_inds output indices array on device (size n_query_rows * k)
 * @param[out] out_dists output dists array on device (size n_query_rows * k)
 * @param[in] index input index array on device (size n_index_rows * 2)
 * @param[in] query input query array on device (size n_query_rows * 2)
 * @param[in] n_index_rows number of rows in index array
 * @param[in] n_query_rows number of rows in query array
 * @param[in] k number of closest neighbors to return
 * @param[in] stream stream to order kernel launch
 */
template <typename value_idx, typename value_t>
void haversine_knn(value_idx *out_inds, value_t *out_dists,
                   const value_t *index, const value_t *query,
                   size_t n_index_rows, size_t n_query_rows, int k,
                   cudaStream_t stream) {
  // A grid with zero blocks is an invalid launch configuration, and with no
  // queries there is no output to produce.
  if (n_query_rows == 0) return;

  // The block size (128) must match the kernel's default `tpb` template
  // argument, which is what template deduction selects here.
  haversine_knn_kernel<<<n_query_rows, 128, 0, stream>>>(
    out_inds, out_dists, index, query, n_index_rows, k);
}

/**
* Search the kNN for the k-nearest neighbors of a set of query vectors
* @param[in] input vector of device memory array pointers to search
Expand All @@ -209,33 +322,33 @@ inline faiss::MetricType build_faiss_metric(distance::DistanceType metric) {
* @param[in] rowMajorQuery is the query array in row-major layout?
* @param[in] translations translation ids for indices when index rows represent
* non-contiguous partitions
* @param[in] metric corresponds to the FAISS::metricType enum (default is euclidean)
* @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded)
* @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm
* @param[in] expanded_form whether or not lp variants should be reduced w/ lp-root
*/
template <typename IntType = int>
void brute_force_knn_impl(
std::vector<float *> &input, std::vector<int> &sizes, IntType D,
float *search_items, IntType n, int64_t *res_I, float *res_D, IntType k,
std::shared_ptr<raft::mr::device::allocator> allocator,
cudaStream_t userStream, cudaStream_t *internalStreams = nullptr,
int n_int_streams = 0, bool rowMajorIndex = true, bool rowMajorQuery = true,
std::vector<int64_t> *translations = nullptr,
distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
float metricArg = 2.0, bool expanded_form = false) {
void brute_force_knn_impl(std::vector<float *> &input, std::vector<int> &sizes,
IntType D, float *search_items, IntType n,
int64_t *res_I, float *res_D, IntType k,
std::shared_ptr<deviceAllocator> allocator,
cudaStream_t userStream,
cudaStream_t *internalStreams = nullptr,
int n_int_streams = 0, bool rowMajorIndex = true,
bool rowMajorQuery = true,
std::vector<int64_t> *translations = nullptr,
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Expanded,
float metricArg = 0) {
ASSERT(input.size() == sizes.size(),
"input and sizes vectors should be the same size");

faiss::MetricType m = detail::build_faiss_metric(metric);

std::vector<int64_t> *id_ranges;
if (translations == nullptr) {
// If we don't have explicit translations
// for offsets of the indices, build them
// from the local partitions
id_ranges = new std::vector<int64_t>();
int64_t total_n = 0;
for (size_t i = 0; i < input.size(); i++) {
for (int i = 0; i < input.size(); i++) {
id_ranges->push_back(total_n);
total_n += sizes[i];
}
Expand All @@ -252,7 +365,7 @@ void brute_force_knn_impl(

std::vector<std::unique_ptr<MetricProcessor<float>>> metric_processors(
input.size());
for (size_t i = 0; i < input.size(); i++) {
for (int i = 0; i < input.size(); i++) {
metric_processors[i] = create_processor<float>(
metric, sizes[i], D, k, rowMajorQuery, userStream, allocator);
metric_processors[i]->preprocess(input[i]);
Expand Down Expand Up @@ -283,35 +396,52 @@ void brute_force_knn_impl(
// Sync user stream only if using other streams to parallelize query
if (n_int_streams > 0) CUDA_CHECK(cudaStreamSynchronize(userStream));

for (size_t i = 0; i < input.size(); i++) {
faiss::gpu::StandardGpuResources gpu_res;
for (int i = 0; i < input.size(); i++) {
float *out_d_ptr = out_D + (i * k * n);
int64_t *out_i_ptr = out_I + (i * k * n);

cudaStream_t stream =
raft::select_stream(userStream, internalStreams, n_int_streams, i);

gpu_res.noTempMemory();
gpu_res.setDefaultStream(device, stream);

faiss::gpu::GpuDistanceParams args;
args.metric = m;
args.metricArg = metricArg;
args.k = k;
args.dims = D;
args.vectors = input[i];
args.vectorsRowMajor = rowMajorIndex;
args.numVectors = sizes[i];
args.queries = search_items;
args.queriesRowMajor = rowMajorQuery;
args.numQueries = n;
args.outDistances = out_D + (i * k * n);
args.outIndices = out_I + (i * k * n);

/**
* @todo: Until FAISS supports pluggable allocation strategies,
* we will not reap the benefits of the pool allocator for
* avoiding device-wide synchronizations from cudaMalloc/cudaFree
*/
bfKnn(&gpu_res, args);
switch (metric) {
case raft::distance::DistanceType::Haversine:

ASSERT(D == 2,
"Haversine distance requires 2 dimensions "
"(latitude / longitude).");

haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n,
k, stream);
break;
default:
faiss::MetricType m = build_faiss_metric(metric);

faiss::gpu::StandardGpuResources gpu_res;

gpu_res.noTempMemory();
gpu_res.setDefaultStream(device, stream);

faiss::gpu::GpuDistanceParams args;
args.metric = m;
args.metricArg = metricArg;
args.k = k;
args.dims = D;
args.vectors = input[i];
args.vectorsRowMajor = rowMajorIndex;
args.numVectors = sizes[i];
args.queries = search_items;
args.queriesRowMajor = rowMajorQuery;
args.numQueries = n;
args.outDistances = out_d_ptr;
args.outIndices = out_i_ptr;

/**
* @todo: Until FAISS supports pluggable allocation strategies,
* we will not reap the benefits of the pool allocator for
* avoiding device-wide synchronizations from cudaMalloc/cudaFree
*/
bfKnn(&gpu_res, args);
}

CUDA_CHECK(cudaPeekAtLastError());
}
Expand All @@ -326,32 +456,33 @@ void brute_force_knn_impl(
if (input.size() > 1 || translations != nullptr) {
// This is necessary for proper index translations. If there are
// no translations or partitions to combine, it can be skipped.
detail::knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k,
userStream, trans.data());
knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream,
trans.data());
}

// Perform necessary post-processing
if ((m == faiss::MetricType::METRIC_L2 ||
m == faiss::MetricType::METRIC_Lp) &&
!expanded_form) {
if (metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::LpUnexpanded) {
/**
* post-processing
*/
float p = 0.5; // standard l2
if (m == faiss::MetricType::METRIC_Lp) p = 1.0 / metricArg;
if (metric == raft::distance::DistanceType::LpUnexpanded)
p = 1.0 / metricArg;
raft::linalg::unaryOp<float>(
res_D, res_D, n * k,
[p] __device__(float input) { return powf(input, p); }, userStream);
}

query_metric_processor->revert(search_items);
query_metric_processor->postprocess(out_D);
for (size_t i = 0; i < input.size(); i++) {
for (int i = 0; i < input.size(); i++) {
metric_processors[i]->revert(input[i]);
}

if (translations == nullptr) delete id_ranges;
}
};

} // namespace detail
} // namespace knn
Expand Down
12 changes: 6 additions & 6 deletions cpp/include/raft/spatial/knn/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@ inline void brute_force_knn(
float *res_D, int k, bool rowMajorIndex = false, bool rowMajorQuery = false,
std::vector<int64_t> *translations = nullptr,
distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
float metric_arg = 2.0f, bool expanded = false) {
float metric_arg = 2.0f) {
ASSERT(input.size() == sizes.size(),
"input and sizes vectors must be the same size");

std::vector<cudaStream_t> int_streams = handle.get_internal_streams();

detail::brute_force_knn_impl(
input, sizes, D, search_items, n, res_I, res_D, k,
handle.get_device_allocator(), handle.get_stream(), int_streams.data(),
handle.get_num_internal_streams(), rowMajorIndex, rowMajorQuery,
translations, metric, metric_arg, expanded);
detail::brute_force_knn_impl(input, sizes, D, search_items, n, res_I, res_D,
k, handle.get_device_allocator(),
handle.get_stream(), int_streams.data(),
handle.get_num_internal_streams(), rowMajorIndex,
rowMajorQuery, translations, metric, metric_arg);
}

} // namespace knn
Expand Down