Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Remaining sparse semiring distances #261

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
5b28699
moving sparse dist optim to raft
divyegala Mar 18, 2021
8e841ac
bucketing bloom filter
divyegala Mar 24, 2021
efcf972
trying stuff
divyegala Mar 26, 2021
8f5cc26
dropping tpb as template in strategies
divyegala Mar 29, 2021
78aff5a
New distances
cjnolet Apr 1, 2021
9aa58f2
Uncommenting hash strategy
cjnolet Apr 1, 2021
cf57c9f
Udating hash strategy
cjnolet Apr 1, 2021
530dd87
Updating to add hash table and bloom strategies
cjnolet Apr 1, 2021
62b3df4
Updates
cjnolet Apr 1, 2021
28d6e2d
More updates
cjnolet Apr 2, 2021
c79ee7a
Updating hash strategy
cjnolet Apr 2, 2021
e5d93dd
Merge branch 'branch-0.19' into HEAD
cjnolet Apr 2, 2021
628b169
Updates
cjnolet Apr 7, 2021
fc708a9
Merge branch 'branch-0.19' into fea-020-sparse_spmv_optim
cjnolet Apr 7, 2021
d9ecc40
trying to merge chunking bug fixes
divyegala Apr 7, 2021
d299cb8
correcting int to value_idx
divyegala Apr 7, 2021
900ad8c
correction to expansion and start index
divyegala Apr 7, 2021
634bc99
Adding correlation distance
cjnolet Apr 7, 2021
82cdbe1
Baseline
cjnolet Apr 9, 2021
4e819c0
Enabling baseline
cjnolet Apr 9, 2021
6bd450f
Enabling optimized primitive
cjnolet Apr 9, 2021
f1c1d06
Fixing style
cjnolet Apr 19, 2021
219a813
Cleanup
cjnolet Apr 19, 2021
0fee984
Style
cjnolet Apr 19, 2021
b7105d7
Removing unecessary deltas
cjnolet Apr 20, 2021
e264514
Updating distances_config_t to use handle directly
cjnolet Apr 20, 2021
65f82bd
Merge branch 'branch-0.20' into semiring_primitives_optim_final
cjnolet May 4, 2021
2975fe3
Adding tests for newer distances
cjnolet May 4, 2021
0d064cc
NOrmalizing
cjnolet May 20, 2021
2057985
Separating new distances from optimizations
cjnolet May 21, 2021
3a669f5
Fixing style
cjnolet May 21, 2021
a415ebf
Merge branch 'branch-21.06' into semiring_primitives_optim_final
cjnolet May 21, 2021
99b4e14
Trying to get cuco working
cjnolet May 21, 2021
46c7d10
Removing dependencies.cmake
cjnolet May 21, 2021
ef9efa7
Raft is building all gpu archs. Checking this in the meantime
cjnolet May 21, 2021
a9d3608
changing cuo to dev branch
divyegala May 26, 2021
d613a77
working through build
divyegala Jun 7, 2021
2394084
fixing build
divyegala Jun 7, 2021
09b5b2c
Updating correlation
cjnolet Jun 7, 2021
339ae67
Merge branch 'semiring_primitives_optim_final' into semiring_primitiv…
cjnolet Jun 7, 2021
b40d7ce
Corr updates
cjnolet Jun 7, 2021
d0e0ea7
Yes! Correlation distance works!
cjnolet Jun 7, 2021
2d3855d
Adding russelrao dist
cjnolet Jun 7, 2021
69ff625
tests passing for all strategies
divyegala Jun 8, 2021
67562a0
Adding russelrao
cjnolet Jun 8, 2021
c43c496
Merge branch 'semiring_primitives_optim_final' into semiring_primitiv…
cjnolet Jun 8, 2021
149d67f
Getting tests to work for russellrao
cjnolet Jun 8, 2021
deb01b2
Adding hamming distance
cjnolet Jun 8, 2021
fad730a
Adding jensenshannon and kldivergence
cjnolet Jun 8, 2021
c97e8e4
Testing remaining distances
cjnolet Jun 8, 2021
7739d1f
integrating cuco changes and some small refactors
divyegala Jun 9, 2021
1098d89
merging upstream
divyegala Jun 9, 2021
a3e263b
removing thrust::device_vector usage because stream unsafe
divyegala Jun 15, 2021
b991081
Merge branch 'branch-21.08' of https://github.com/rapidsai/raft into …
divyegala Jun 15, 2021
6ef7a02
restructuring tests and addressing review related to tests
divyegala Jun 15, 2021
9b63fc8
addressing other review comments
divyegala Jun 15, 2021
a270833
Merge branch 'branch-21.08' of https://github.com/rapidsai/raft into …
divyegala Jun 15, 2021
cca6fa6
pointing back to dev cuco commit
divyegala Jun 23, 2021
6c2f707
Merge branch 'branch-21.08' of https://github.com/rapidsai/raft into …
divyegala Jun 23, 2021
ea597d6
removing print
divyegala Jun 23, 2021
318a95a
Merge remote-tracking branch 'divye/semiring_primitives_optim_final' …
cjnolet Jun 23, 2021
591f9f2
Merge branch 'branch-21.08' into semiring_primitive_additional_distances
cjnolet Jun 23, 2021
fc185e5
Additional consts needed
cjnolet Jun 24, 2021
924d8db
Using rmm::device_uvector
cjnolet Jul 7, 2021
72dd60e
More removal of device buffer
cjnolet Jul 7, 2021
cc59ad6
Merge branch 'branch-21.08' of github.com:rapidsai/raft into branch-2…
cjnolet Jul 8, 2021
1da9ef0
Merge branch 'branch-21.08' of github.com:rapidsai/raft into branch-2…
cjnolet Jul 12, 2021
6e68e40
Merge branch 'branch-21.08' into semiring_primitive_additional_distances
cjnolet Jul 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions cpp/include/raft/sparse/distance/bin_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <raft/cuda_utils.cuh>

#include <raft/mr/device/allocator.hpp>
#include <raft/mr/device/buffer.hpp>

#include <raft/sparse/distance/common.h>
#include <raft/sparse/utils.h>
Expand Down Expand Up @@ -87,8 +86,8 @@ void compute_bin_distance(value_t *out, const value_idx *Q_coo_rows,
value_idx R_nnz, value_idx m, value_idx n,
std::shared_ptr<raft::mr::device::allocator> alloc,
cudaStream_t stream, expansion_f expansion_func) {
raft::mr::device::buffer<value_t> Q_norms(alloc, stream, m);
raft::mr::device::buffer<value_t> R_norms(alloc, stream, n);
rmm::device_uvector<value_t> Q_norms(m, stream);
rmm::device_uvector<value_t> R_norms(n, stream);
CUDA_CHECK(
cudaMemsetAsync(Q_norms.data(), 0, Q_norms.size() * sizeof(value_t)));
CUDA_CHECK(
Expand All @@ -113,8 +112,7 @@ class jaccard_expanded_distances_t : public distances_t<value_t> {
explicit jaccard_expanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config),
workspace(config.handle.get_device_allocator(),
config.handle.get_stream(), 0),
workspace(0, config.handle.get_stream()),
ip_dists(config) {}

void compute(value_t *out_dists) {
Expand All @@ -123,9 +121,8 @@ class jaccard_expanded_distances_t : public distances_t<value_t> {
value_idx *b_indices = ip_dists.b_rows_coo();
value_t *b_data = ip_dists.b_data_coo();

raft::mr::device::buffer<value_idx> search_coo_rows(
config_->handle.get_device_allocator(), config_->handle.get_stream(),
config_->a_nnz);
rmm::device_uvector<value_idx> search_coo_rows(
config_->a_nnz, config_->handle.get_stream());
raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows,
search_coo_rows.data(), config_->a_nnz,
config_->handle.get_stream());
Expand All @@ -150,7 +147,7 @@ class jaccard_expanded_distances_t : public distances_t<value_t> {

private:
const distances_config_t<value_idx, value_t> *config_;
raft::mr::device::buffer<char> workspace;
rmm::device_uvector<char> workspace;
ip_distances_t<value_idx, value_t> ip_dists;
};

Expand All @@ -164,8 +161,7 @@ class dice_expanded_distances_t : public distances_t<value_t> {
explicit dice_expanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config),
workspace(config.handle.get_device_allocator(),
config.handle.get_stream(), 0),
workspace(0, config.handle.get_stream()),
ip_dists(config) {}

void compute(value_t *out_dists) {
Expand All @@ -174,9 +170,8 @@ class dice_expanded_distances_t : public distances_t<value_t> {
value_idx *b_indices = ip_dists.b_rows_coo();
value_t *b_data = ip_dists.b_data_coo();

raft::mr::device::buffer<value_idx> search_coo_rows(
config_->handle.get_device_allocator(), config_->handle.get_stream(),
config_->a_nnz);
rmm::device_uvector<value_idx> search_coo_rows(
config_->a_nnz, config_->handle.get_stream());
raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows,
search_coo_rows.data(), config_->a_nnz,
config_->handle.get_stream());
Expand All @@ -197,7 +192,7 @@ class dice_expanded_distances_t : public distances_t<value_t> {

private:
const distances_config_t<value_idx, value_t> *config_;
raft::mr::device::buffer<char> workspace;
rmm::device_uvector<char> workspace;
ip_distances_t<value_idx, value_t> ip_dists;
};

Expand Down
27 changes: 26 additions & 1 deletion cpp/include/raft/sparse/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ static const std::unordered_set<raft::distance::DistanceType> supportedDistance{
raft::distance::DistanceType::JaccardExpanded,
raft::distance::DistanceType::CosineExpanded,
raft::distance::DistanceType::HellingerExpanded,
raft::distance::DistanceType::DiceExpanded};
raft::distance::DistanceType::DiceExpanded,
raft::distance::DistanceType::CorrelationExpanded,
raft::distance::DistanceType::RusselRaoExpanded,
raft::distance::DistanceType::HammingUnexpanded,
raft::distance::DistanceType::JensenShannon,
raft::distance::DistanceType::KLDivergence};

/**
* Compute pairwise distances between A and B, using the provided
Expand Down Expand Up @@ -120,6 +125,26 @@ void pairwiseDistance(value_t *out,
case raft::distance::DistanceType::DiceExpanded:
dice_expanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::CorrelationExpanded:
correlation_expanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
russelrao_expanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::HammingUnexpanded:
hamming_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::JensenShannon:
jensen_shannon_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::KLDivergence:
kl_divergence_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;

default:
THROW("Unsupported distance: %d", metric);
Expand Down
7 changes: 2 additions & 5 deletions cpp/include/raft/sparse/distance/ip_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <raft/cuda_utils.cuh>

#include <raft/mr/device/allocator.hpp>
#include <raft/mr/device/buffer.hpp>

#include <raft/sparse/distance/common.h>
#include <raft/sparse/linalg/transpose.h>
Expand All @@ -47,9 +46,7 @@ class ip_distances_t : public distances_t<value_t> {
* @param[in] config specifies inputs, outputs, and sizes
*/
ip_distances_t(const distances_config_t<value_idx, value_t> &config)
: config_(&config),
coo_rows_b(config.handle.get_device_allocator(),
config.handle.get_stream(), config.b_nnz) {
: config_(&config), coo_rows_b(config.b_nnz, config.handle.get_stream()) {
raft::sparse::convert::csr_to_coo(config_->b_indptr, config_->b_nrows,
coo_rows_b.data(), config_->b_nnz,
config_->handle.get_stream());
Expand All @@ -74,7 +71,7 @@ class ip_distances_t : public distances_t<value_t> {

private:
const distances_config_t<value_idx, value_t> *config_;
raft::mr::device::buffer<value_idx> coo_rows_b;
rmm::device_uvector<value_idx> coo_rows_b;
};
}; // END namespace distance
}; // END namespace sparse
Expand Down
Loading