Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Updating sparse prims based on recent changes #166

Merged
merged 6 commits into from
Mar 4, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion cpp/include/raft/sparse/distance/distance.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
#pragma once

#include <raft/cudart_utils.h>
#include <unordered_set>

#include <raft/linalg/distance_type.h>
#include <raft/sparse/cusparse_wrappers.h>
Expand All @@ -42,6 +43,20 @@ namespace raft {
namespace sparse {
namespace distance {

static const std::unordered_set<raft::distance::DistanceType> supportedDistance{
raft::distance::DistanceType::L2Expanded,
raft::distance::DistanceType::L2Unexpanded,
raft::distance::DistanceType::L2SqrtExpanded,
raft::distance::DistanceType::L2SqrtUnexpanded,
raft::distance::DistanceType::InnerProduct,
raft::distance::DistanceType::L1,
raft::distance::DistanceType::Canberra,
raft::distance::DistanceType::Linf,
raft::distance::DistanceType::LpUnexpanded,
raft::distance::DistanceType::JaccardExpanded,
raft::distance::DistanceType::CosineExpanded,
raft::distance::DistanceType::HellingerExpanded};

/**
* Compute pairwise distances between A and B, using the provided
* input configuration and distance function.
Expand All @@ -60,12 +75,20 @@ void pairwiseDistance(value_t *out,
case raft::distance::DistanceType::L2Expanded:
l2_expanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
l2_sqrt_expanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::InnerProduct:
ip_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::L2Unexpanded:
l2_unexpanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
l2_sqrt_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::L1:
l1_unexpanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
Expand Down
6 changes: 0 additions & 6 deletions cpp/include/raft/sparse/distance/ip_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -190,18 +190,12 @@ class ip_distances_gemm_t : public ip_trans_getters_t<value_idx, value_t> {
value_t *csr_out_data) {
value_idx m = config_->a_nrows, n = config_->b_nrows, k = config_->a_ncols;

int start = raft::curTimeMillis();

CUDA_CHECK(cudaStreamSynchronize(config_->stream));

CUSPARSE_CHECK(raft::sparse::cusparsecsrgemm2<value_t>(
config_->handle, m, n, k, &alpha, matA, config_->a_nnz, config_->a_data,
config_->a_indptr, config_->a_indices, matB, config_->b_nnz,
csc_data.data(), csc_indptr.data(), csc_indices.data(), NULL, matD, 0,
NULL, NULL, NULL, matC, csr_out_data, csr_out_indptr, csr_out_indices,
info, workspace.data(), config_->stream));

CUDA_CHECK(cudaStreamSynchronize(config_->stream));
}

void transpose_b() {
Expand Down
37 changes: 30 additions & 7 deletions cpp/include/raft/sparse/distance/l2_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,13 @@

#pragma once

#include <limits.h>
#include <cmath>

#include <raft/cudart_utils.h>
#include <raft/sparse/distance/common.h>
#include <raft/spatial/knn/knn.hpp>

#include <raft/cudart_utils.h>
#include <raft/linalg/distance_type.h>
#include <raft/sparse/cusparse_wrappers.h>
#include <raft/cuda_utils.cuh>
#include <raft/linalg/unary_op.cuh>

#include <raft/mr/device/allocator.hpp>
#include <raft/mr/device/buffer.hpp>

Expand Down Expand Up @@ -149,12 +144,40 @@ class l2_expanded_distances_t : public distances_t<value_t> {

~l2_expanded_distances_t() = default;

private:
protected:
const distances_config_t<value_idx, value_t> *config_;
raft::mr::device::buffer<char> workspace;
ip_distances_t<value_idx, value_t> ip_dists;
};

/**
* L2 sqrt distance performing the sqrt operation after the distance computation
* The expanded form is more efficient for sparse data.
*/
template <typename value_idx = int, typename value_t = float>
class l2_sqrt_expanded_distances_t
: public l2_expanded_distances_t<value_idx, value_t> {
public:
explicit l2_sqrt_expanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: l2_expanded_distances_t<value_idx, value_t>(config) {}

void compute(value_t *out_dists) override {
l2_expanded_distances_t<value_idx, value_t>::compute(out_dists);
// Sqrt Post-processing
value_t p = 0.5; // standard l2
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
raft::linalg::unaryOp<value_t>(
out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows,
[p] __device__(value_t input) {
int neg = input < 0 ? -1 : 1;
return powf(fabs(input), p) * neg;
},
this->config_->stream);
}

~l2_sqrt_expanded_distances_t() = default;
};

/**
* Cosine distance using the expanded form: 1 - ( sum(x_k * y_k) / (sqrt(sum(x_k)^2) * sqrt(sum(y_k)^2)))
* The expanded form is more efficient for sparse data.
Expand Down
24 changes: 23 additions & 1 deletion cpp/include/raft/sparse/distance/lp_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,32 @@ class l2_unexpanded_distances_t : public distances_t<value_t> {
Sum(), AtomicAdd());
}

private:
protected:
const distances_config_t<value_idx, value_t> *config_;
};

template <typename value_idx = int, typename value_t = float>
class l2_sqrt_unexpanded_distances_t
: public l2_unexpanded_distances_t<value_idx, value_t> {
public:
l2_sqrt_unexpanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: l2_unexpanded_distances_t<value_idx, value_t>(config) {}

void compute(value_t *out_dists) {
l2_unexpanded_distances_t<value_idx, value_t>::compute(out_dists);
// Sqrt Post-processing
value_t p = 0.5; // standard l2
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
raft::linalg::unaryOp<value_t>(
out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows,
[p] __device__(value_t input) {
int neg = input < 0 ? -1 : 1;
return powf(fabs(input), p) * neg;
},
this->config_->stream);
}
};

template <typename value_idx = int, typename value_t = float>
class linf_unexpanded_distances_t : public distances_t<value_t> {
public:
Expand Down
84 changes: 27 additions & 57 deletions cpp/include/raft/sparse/selection/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#pragma once

#include <rmm/device_uvector.hpp>

#include <raft/cudart_utils.h>
#include <raft/linalg/distance_type.h>
#include <raft/sparse/cusparse_wrappers.h>
Expand All @@ -34,13 +36,10 @@
#include <raft/spatial/knn/detail/brute_force_knn.hpp>
#include <raft/spatial/knn/knn.hpp>

#include <raft/linalg/distance_type.h>

#include <raft/cudart_utils.h>

#include <raft/cuda_utils.cuh>

#include <raft/linalg/distance_type.h>
#include <raft/sparse/cusparse_wrappers.h>
#include <raft/cuda_utils.cuh>

#include <cusparse_v2.h>

Expand Down Expand Up @@ -128,7 +127,7 @@ class sparse_knn_t {
size_t batch_size_query_ = 2 << 14,
raft::distance::DistanceType metric_ =
raft::distance::DistanceType::L2Expanded,
float metricArg_ = 0, bool expanded_form_ = false)
float metricArg_ = 0)
: idxIndptr(idxIndptr_),
idxIndices(idxIndices_),
idxData(idxData_),
Expand All @@ -150,8 +149,7 @@ class sparse_knn_t {
batch_size_index(batch_size_index_),
batch_size_query(batch_size_query_),
metric(metric_),
metricArg(metricArg_),
expanded_form(expanded_form_) {}
metricArg(metricArg_) {}

void run() {
using namespace raft::sparse;
Expand All @@ -172,26 +170,23 @@ class sparse_knn_t {
* Slice CSR to rows in batch
*/

raft::mr::device::buffer<value_idx> query_batch_indptr(
allocator, stream, query_batcher.batch_rows() + 1);
rmm::device_uvector<value_idx> query_batch_indptr(
query_batcher.batch_rows() + 1, stream);

value_idx n_query_batch_nnz = query_batcher.get_batch_csr_indptr_nnz(
query_batch_indptr.data(), stream);

raft::mr::device::buffer<value_idx> query_batch_indices(
allocator, stream, n_query_batch_nnz);
raft::mr::device::buffer<value_t> query_batch_data(allocator, stream,
n_query_batch_nnz);
rmm::device_uvector<value_idx> query_batch_indices(n_query_batch_nnz,
stream);
rmm::device_uvector<value_t> query_batch_data(n_query_batch_nnz, stream);

query_batcher.get_batch_csr_indices_data(query_batch_indices.data(),
query_batch_data.data(), stream);

// A 3-partition temporary merge space to scale the batching. 2 parts for subsequent
// batches and 1 space for the results of the merge, which get copied back to the top
raft::mr::device::buffer<value_idx> merge_buffer_indices(allocator,
stream, 0);
raft::mr::device::buffer<value_t> merge_buffer_dists(allocator, stream,
0);
rmm::device_uvector<value_idx> merge_buffer_indices(0, stream);
rmm::device_uvector<value_t> merge_buffer_dists(0, stream);

value_t *dists_merge_buffer_ptr;
value_idx *indices_merge_buffer_ptr;
Expand All @@ -209,11 +204,10 @@ class sparse_knn_t {
/**
* Slice CSR to rows in batch
*/
raft::mr::device::buffer<value_idx> idx_batch_indptr(
allocator, stream, idx_batcher.batch_rows() + 1);
raft::mr::device::buffer<value_idx> idx_batch_indices(allocator, stream,
0);
raft::mr::device::buffer<value_t> idx_batch_data(allocator, stream, 0);
rmm::device_uvector<value_idx> idx_batch_indptr(
idx_batcher.batch_rows() + 1, stream);
rmm::device_uvector<value_idx> idx_batch_indices(0, stream);
rmm::device_uvector<value_t> idx_batch_data(0, stream);

value_idx idx_batch_nnz =
idx_batcher.get_batch_csr_indptr_nnz(idx_batch_indptr.data(), stream);
Expand All @@ -229,8 +223,7 @@ class sparse_knn_t {
*/
size_t dense_size =
idx_batcher.batch_rows() * query_batcher.batch_rows();
raft::mr::device::buffer<value_t> batch_dists(allocator, stream,
dense_size);
rmm::device_uvector<value_t> batch_dists(dense_size, stream);

CUDA_CHECK(cudaMemset(batch_dists.data(), 0,
batch_dists.size() * sizeof(value_t)));
Expand All @@ -241,13 +234,9 @@ class sparse_knn_t {
query_batch_indptr.data(), query_batch_indices.data(),
query_batch_data.data(), batch_dists.data());

idx_batch_indptr.release(stream);
idx_batch_indices.release(stream);
idx_batch_data.release(stream);

// Build batch indices array
raft::mr::device::buffer<value_idx> batch_indices(allocator, stream,
batch_dists.size());
rmm::device_uvector<value_idx> batch_indices(batch_dists.size(),
stream);

// populate batch indices array
value_idx batch_rows = query_batcher.batch_rows(),
Expand All @@ -268,8 +257,6 @@ class sparse_knn_t {
batch_indices.data(), dists_merge_buffer_ptr,
indices_merge_buffer_ptr);

perform_postprocessing(dists_merge_buffer_ptr, batch_rows);

value_t *dists_merge_buffer_tmp_ptr = dists_merge_buffer_ptr;
value_idx *indices_merge_buffer_tmp_ptr = indices_merge_buffer_ptr;

Expand Down Expand Up @@ -307,23 +294,6 @@ class sparse_knn_t {
}
}

void perform_postprocessing(value_t *dists, size_t batch_rows) {
// Perform necessary post-processing
if (metric == raft::distance::DistanceType::L2Expanded && !expanded_form) {
/**
* post-processing
*/
value_t p = 0.5; // standard l2
raft::linalg::unaryOp<value_t>(
dists, dists, batch_rows * k,
[p] __device__(value_t input) {
int neg = input < 0 ? -1 : 1;
return powf(fabs(input), p) * neg;
},
stream);
}
}

private:
void merge_batches(csr_batcher_t<value_idx, value_t> &idx_batcher,
csr_batcher_t<value_idx, value_t> &query_batcher,
Expand All @@ -335,8 +305,7 @@ class sparse_knn_t {
id_ranges.push_back(0);
id_ranges.push_back(idx_batcher.batch_start());

raft::mr::device::buffer<value_idx> trans(allocator, stream,
id_ranges.size());
rmm::device_uvector<value_idx> trans(id_ranges.size(), stream);
raft::update_device(trans.data(), id_ranges.data(), id_ranges.size(),
stream);

Expand Down Expand Up @@ -403,6 +372,10 @@ class sparse_knn_t {
dist_config.allocator = allocator;
dist_config.stream = stream;

if (raft::sparse::distance::supportedDistance.find(metric) ==
raft::sparse::distance::supportedDistance.end())
THROW("DistanceType not supported: %d", metric);

raft::sparse::distance::pairwiseDistance(batch_dists, dist_config, metric,
metricArg);
}
Expand All @@ -418,8 +391,6 @@ class sparse_knn_t {

float metricArg;

bool expanded_form;

int n_idx_rows, n_idx_cols, n_query_rows, n_query_cols, k;

cusparseHandle_t cusparseHandle;
Expand Down Expand Up @@ -453,7 +424,6 @@ class sparse_knn_t {
* @param[in] batch_size_query maximum number of rows to use from query matrix per batch
* @param[in] metric distance metric/measure to use
* @param[in] metricArg potential argument for metric (currently unused)
* @param[in] expanded_form whether or not Lp variants should be reduced by the pth-root
*/
template <typename value_idx = int, typename value_t = float, int TPB_X = 32>
void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices,
Expand All @@ -469,12 +439,12 @@ void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices,
size_t batch_size_query = 2 << 14,
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Expanded,
float metricArg = 0, bool expanded_form = false) {
float metricArg = 0) {
sparse_knn_t<value_idx, value_t>(
idxIndptr, idxIndices, idxData, idxNNZ, n_idx_rows, n_idx_cols, queryIndptr,
queryIndices, queryData, queryNNZ, n_query_rows, n_query_cols,
output_indices, output_dists, k, cusparseHandle, allocator, stream,
batch_size_index, batch_size_query, metric, metricArg, expanded_form)
batch_size_index, batch_size_query, metric, metricArg)
.run();
}

Expand Down
Loading