Skip to content

Commit

Permalink
Remaining sparse semiring distances (#261)
Browse files Browse the repository at this point in the history
This PR is intended to be merged after #207 (hash table strategy) has been merged.

This PR introduces the following distances:
- Hamming
- Jensen-Shannon
- Russell-Rao
- KL-Divergence
- Correlation

Most of the changes here are from #207 and will be reviewed in that PR. The only files that need to be reviewed for this PR are `sparse/distance/l2_distance.cuh`, `sparse/distance/bin_distance.cuh`, `sparse/distance/lp_distances.cuh`, and their corresponding gtests: `test/sparse/distance.cuh`

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #261
  • Loading branch information
cjnolet authored Jul 12, 2021
1 parent 22a16dd commit f94780c
Show file tree
Hide file tree
Showing 6 changed files with 387 additions and 45 deletions.
25 changes: 10 additions & 15 deletions cpp/include/raft/sparse/distance/bin_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <raft/cuda_utils.cuh>

#include <raft/mr/device/allocator.hpp>
#include <raft/mr/device/buffer.hpp>

#include <raft/sparse/distance/common.h>
#include <raft/sparse/utils.h>
Expand Down Expand Up @@ -87,8 +86,8 @@ void compute_bin_distance(value_t *out, const value_idx *Q_coo_rows,
value_idx R_nnz, value_idx m, value_idx n,
std::shared_ptr<raft::mr::device::allocator> alloc,
cudaStream_t stream, expansion_f expansion_func) {
raft::mr::device::buffer<value_t> Q_norms(alloc, stream, m);
raft::mr::device::buffer<value_t> R_norms(alloc, stream, n);
rmm::device_uvector<value_t> Q_norms(m, stream);
rmm::device_uvector<value_t> R_norms(n, stream);
CUDA_CHECK(
cudaMemsetAsync(Q_norms.data(), 0, Q_norms.size() * sizeof(value_t)));
CUDA_CHECK(
Expand All @@ -113,8 +112,7 @@ class jaccard_expanded_distances_t : public distances_t<value_t> {
explicit jaccard_expanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config),
workspace(config.handle.get_device_allocator(),
config.handle.get_stream(), 0),
workspace(0, config.handle.get_stream()),
ip_dists(config) {}

void compute(value_t *out_dists) {
Expand All @@ -123,9 +121,8 @@ class jaccard_expanded_distances_t : public distances_t<value_t> {
value_idx *b_indices = ip_dists.b_rows_coo();
value_t *b_data = ip_dists.b_data_coo();

raft::mr::device::buffer<value_idx> search_coo_rows(
config_->handle.get_device_allocator(), config_->handle.get_stream(),
config_->a_nnz);
rmm::device_uvector<value_idx> search_coo_rows(
config_->a_nnz, config_->handle.get_stream());
raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows,
search_coo_rows.data(), config_->a_nnz,
config_->handle.get_stream());
Expand All @@ -150,7 +147,7 @@ class jaccard_expanded_distances_t : public distances_t<value_t> {

private:
const distances_config_t<value_idx, value_t> *config_;
raft::mr::device::buffer<char> workspace;
rmm::device_uvector<char> workspace;
ip_distances_t<value_idx, value_t> ip_dists;
};

Expand All @@ -164,8 +161,7 @@ class dice_expanded_distances_t : public distances_t<value_t> {
explicit dice_expanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config),
workspace(config.handle.get_device_allocator(),
config.handle.get_stream(), 0),
workspace(0, config.handle.get_stream()),
ip_dists(config) {}

void compute(value_t *out_dists) {
Expand All @@ -174,9 +170,8 @@ class dice_expanded_distances_t : public distances_t<value_t> {
value_idx *b_indices = ip_dists.b_rows_coo();
value_t *b_data = ip_dists.b_data_coo();

raft::mr::device::buffer<value_idx> search_coo_rows(
config_->handle.get_device_allocator(), config_->handle.get_stream(),
config_->a_nnz);
rmm::device_uvector<value_idx> search_coo_rows(
config_->a_nnz, config_->handle.get_stream());
raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows,
search_coo_rows.data(), config_->a_nnz,
config_->handle.get_stream());
Expand All @@ -197,7 +192,7 @@ class dice_expanded_distances_t : public distances_t<value_t> {

private:
const distances_config_t<value_idx, value_t> *config_;
raft::mr::device::buffer<char> workspace;
rmm::device_uvector<char> workspace;
ip_distances_t<value_idx, value_t> ip_dists;
};

Expand Down
27 changes: 26 additions & 1 deletion cpp/include/raft/sparse/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ static const std::unordered_set<raft::distance::DistanceType> supportedDistance{
raft::distance::DistanceType::JaccardExpanded,
raft::distance::DistanceType::CosineExpanded,
raft::distance::DistanceType::HellingerExpanded,
raft::distance::DistanceType::DiceExpanded};
raft::distance::DistanceType::DiceExpanded,
raft::distance::DistanceType::CorrelationExpanded,
raft::distance::DistanceType::RusselRaoExpanded,
raft::distance::DistanceType::HammingUnexpanded,
raft::distance::DistanceType::JensenShannon,
raft::distance::DistanceType::KLDivergence};

/**
* Compute pairwise distances between A and B, using the provided
Expand Down Expand Up @@ -120,6 +125,26 @@ void pairwiseDistance(value_t *out,
case raft::distance::DistanceType::DiceExpanded:
dice_expanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::CorrelationExpanded:
correlation_expanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
russelrao_expanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::HammingUnexpanded:
hamming_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::JensenShannon:
jensen_shannon_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::KLDivergence:
kl_divergence_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;

default:
THROW("Unsupported distance: %d", metric);
Expand Down
7 changes: 2 additions & 5 deletions cpp/include/raft/sparse/distance/ip_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <raft/cuda_utils.cuh>

#include <raft/mr/device/allocator.hpp>
#include <raft/mr/device/buffer.hpp>

#include <raft/sparse/distance/common.h>
#include <raft/sparse/linalg/transpose.h>
Expand All @@ -47,9 +46,7 @@ class ip_distances_t : public distances_t<value_t> {
* @param[in] config specifies inputs, outputs, and sizes
*/
ip_distances_t(const distances_config_t<value_idx, value_t> &config)
: config_(&config),
coo_rows_b(config.handle.get_device_allocator(),
config.handle.get_stream(), config.b_nnz) {
: config_(&config), coo_rows_b(config.b_nnz, config.handle.get_stream()) {
raft::sparse::convert::csr_to_coo(config_->b_indptr, config_->b_nrows,
coo_rows_b.data(), config_->b_nnz,
config_->handle.get_stream());
Expand All @@ -74,7 +71,7 @@ class ip_distances_t : public distances_t<value_t> {

private:
const distances_config_t<value_idx, value_t> *config_;
raft::mr::device::buffer<value_idx> coo_rows_b;
rmm::device_uvector<value_idx> coo_rows_b;
};
}; // END namespace distance
}; // END namespace sparse
Expand Down
Loading

0 comments on commit f94780c

Please sign in to comment.