Skip to content

Commit

Permalink
Sparse semirings cleanup + hash table & batching strategies (#269)
Browse files Browse the repository at this point in the history
This branch includes several new features and optimizations:

1. Introduces a hash-table strategy that sparsifies the vector held in shared memory by the COO SpMV kernel
2. Adds a batching strategy for rows whose number of nonzeros (nnz) is too large to fit into shared memory
3. Removes the dependency on the cuSPARSE csrgemm routine
4. Uses a raft::handle_t in distances_config_t rather than accepting each resource (cuSPARSE handle, allocator, stream) explicitly
5. Removes the naive CSR semiring code

This PR is also a prerequisite for merging #261, which introduces the remaining distances.

Authors:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #269
  • Loading branch information
divyegala authored Jun 23, 2021
1 parent 2ba5d76 commit caa44e6
Show file tree
Hide file tree
Showing 22 changed files with 1,309 additions and 1,936 deletions.
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_cuco.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function(find_and_configure_cuco VERSION)
GLOBAL_TARGETS cuco cuco::cuco
CPM_ARGS
GIT_REPOSITORY https://github.com/NVIDIA/cuCollections.git
GIT_TAG 0b672bbde7c85a79df4d7ca5f82e15e5b4a57700
GIT_TAG e5e2abe55152608ef449ecf162a1ef52ded19801
OPTIONS "BUILD_TESTS OFF"
"BUILD_BENCHMARKS OFF"
"BUILD_EXAMPLES OFF"
Expand Down
10 changes: 8 additions & 2 deletions cpp/include/raft/linalg/distance_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,16 @@ enum DistanceType : unsigned short {
Haversine = 13,
/** Bray-Curtis distance **/
BrayCurtis = 14,
/** Jensen-Shannon distance **/
/** Jensen-Shannon distance**/
JensenShannon = 15,
/** Hamming distance **/
HammingUnexpanded = 16,
/** Kullback-Leibler divergence **/
KLDivergence = 17,
/** Russell-Rao distance **/
RusselRaoExpanded = 18,
/** Dice-Sorensen distance **/
DiceExpanded = 16,
DiceExpanded = 19,
/** Precomputed (special value) **/
Precomputed = 100
};
Expand Down
21 changes: 12 additions & 9 deletions cpp/include/raft/sparse/distance/bin_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ void compute_bin_distance(value_t *out, const value_idx *Q_coo_rows,
const value_t *Q_data, value_idx Q_nnz,
const value_idx *R_coo_rows, const value_t *R_data,
value_idx R_nnz, value_idx m, value_idx n,
cusparseHandle_t handle,
std::shared_ptr<raft::mr::device::allocator> alloc,
cudaStream_t stream, expansion_f expansion_func) {
raft::mr::device::buffer<value_t> Q_norms(alloc, stream, m);
Expand Down Expand Up @@ -114,7 +113,8 @@ class jaccard_expanded_distances_t : public distances_t<value_t> {
explicit jaccard_expanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config),
workspace(config.allocator, config.stream, 0),
workspace(config.handle.get_device_allocator(),
config.handle.get_stream(), 0),
ip_dists(config) {}

void compute(value_t *out_dists) {
Expand All @@ -124,15 +124,16 @@ class jaccard_expanded_distances_t : public distances_t<value_t> {
value_t *b_data = ip_dists.b_data_coo();

raft::mr::device::buffer<value_idx> search_coo_rows(
config_->allocator, config_->stream, config_->a_nnz);
config_->handle.get_device_allocator(), config_->handle.get_stream(),
config_->a_nnz);
raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows,
search_coo_rows.data(), config_->a_nnz,
config_->stream);
config_->handle.get_stream());

compute_bin_distance(
out_dists, search_coo_rows.data(), config_->a_data, config_->a_nnz,
b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows,
config_->handle, config_->allocator, config_->stream,
config_->handle.get_device_allocator(), config_->handle.get_stream(),
[] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) {
value_t q_r_union = q_norm + r_norm;
value_t denom = q_r_union - dot;
Expand Down Expand Up @@ -163,7 +164,8 @@ class dice_expanded_distances_t : public distances_t<value_t> {
explicit dice_expanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config),
workspace(config.allocator, config.stream, 0),
workspace(config.handle.get_device_allocator(),
config.handle.get_stream(), 0),
ip_dists(config) {}

void compute(value_t *out_dists) {
Expand All @@ -173,15 +175,16 @@ class dice_expanded_distances_t : public distances_t<value_t> {
value_t *b_data = ip_dists.b_data_coo();

raft::mr::device::buffer<value_idx> search_coo_rows(
config_->allocator, config_->stream, config_->a_nnz);
config_->handle.get_device_allocator(), config_->handle.get_stream(),
config_->a_nnz);
raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows,
search_coo_rows.data(), config_->a_nnz,
config_->stream);
config_->handle.get_stream());

compute_bin_distance(
out_dists, search_coo_rows.data(), config_->a_data, config_->a_nnz,
b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows,
config_->handle, config_->allocator, config_->stream,
config_->handle.get_device_allocator(), config_->handle.get_stream(),
[] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) {
value_t q_r_union = q_norm + r_norm;
value_t dice = (2 * dot) / q_r_union;
Expand Down
10 changes: 4 additions & 6 deletions cpp/include/raft/sparse/distance/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@

#pragma once

#include <cusparse_v2.h>
#include <raft/mr/device/allocator.hpp>
#include <raft/handle.hpp>

namespace raft {
namespace sparse {
namespace distance {

template <typename value_idx, typename value_t>
struct distances_config_t {
distances_config_t(raft::handle_t &handle_) : handle(handle_) {}

// left side
value_idx a_nrows;
value_idx a_ncols;
Expand All @@ -41,10 +42,7 @@ struct distances_config_t {
value_idx *b_indices;
value_t *b_data;

cusparseHandle_t handle;

std::shared_ptr<raft::mr::device::allocator> allocator;
cudaStream_t stream;
raft::handle_t &handle;
};

template <typename value_t>
Expand Down
Loading

0 comments on commit caa44e6

Please sign in to comment.