Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Remaining sparse semiring distances #261

Merged
Merged
Show file tree
Hide file tree
Changes from 62 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
5b28699
moving sparse dist optim to raft
divyegala Mar 18, 2021
8e841ac
bucketing bloom filter
divyegala Mar 24, 2021
efcf972
trying stuff
divyegala Mar 26, 2021
8f5cc26
dropping tpb as template in strategies
divyegala Mar 29, 2021
78aff5a
New distances
cjnolet Apr 1, 2021
9aa58f2
Uncommenting hash strategy
cjnolet Apr 1, 2021
cf57c9f
Udating hash strategy
cjnolet Apr 1, 2021
530dd87
Updating to add hash table and bloom strategies
cjnolet Apr 1, 2021
62b3df4
Updates
cjnolet Apr 1, 2021
28d6e2d
More updates
cjnolet Apr 2, 2021
c79ee7a
Updating hash strategy
cjnolet Apr 2, 2021
e5d93dd
Merge branch 'branch-0.19' into HEAD
cjnolet Apr 2, 2021
628b169
Updates
cjnolet Apr 7, 2021
fc708a9
Merge branch 'branch-0.19' into fea-020-sparse_spmv_optim
cjnolet Apr 7, 2021
d9ecc40
trying to merge chunking bug fixes
divyegala Apr 7, 2021
d299cb8
correcting int to value_idx
divyegala Apr 7, 2021
900ad8c
correction to expansion and start index
divyegala Apr 7, 2021
634bc99
Adding correlation distance
cjnolet Apr 7, 2021
82cdbe1
Baseline
cjnolet Apr 9, 2021
4e819c0
Enabling baseline
cjnolet Apr 9, 2021
6bd450f
Enabling optimized primitive
cjnolet Apr 9, 2021
f1c1d06
Fixing style
cjnolet Apr 19, 2021
219a813
Cleanup
cjnolet Apr 19, 2021
0fee984
Style
cjnolet Apr 19, 2021
b7105d7
Removing unecessary deltas
cjnolet Apr 20, 2021
e264514
Updating distances_config_t to use handle directly
cjnolet Apr 20, 2021
65f82bd
Merge branch 'branch-0.20' into semiring_primitives_optim_final
cjnolet May 4, 2021
2975fe3
Adding tests for newer distances
cjnolet May 4, 2021
0d064cc
NOrmalizing
cjnolet May 20, 2021
2057985
Separating new distances from optimizations
cjnolet May 21, 2021
3a669f5
Fixing style
cjnolet May 21, 2021
a415ebf
Merge branch 'branch-21.06' into semiring_primitives_optim_final
cjnolet May 21, 2021
99b4e14
Trying to get cuco working
cjnolet May 21, 2021
46c7d10
Removing dependencies.cmake
cjnolet May 21, 2021
ef9efa7
Raft is building all gpu archs. Checking this in the meantime
cjnolet May 21, 2021
a9d3608
changing cuo to dev branch
divyegala May 26, 2021
d613a77
working through build
divyegala Jun 7, 2021
2394084
fixing build
divyegala Jun 7, 2021
09b5b2c
Updating correlation
cjnolet Jun 7, 2021
339ae67
Merge branch 'semiring_primitives_optim_final' into semiring_primitiv…
cjnolet Jun 7, 2021
b40d7ce
Corr updates
cjnolet Jun 7, 2021
d0e0ea7
Yes! Correlation distance works!
cjnolet Jun 7, 2021
2d3855d
Adding russelrao dist
cjnolet Jun 7, 2021
69ff625
tests passing for all strategies
divyegala Jun 8, 2021
67562a0
Adding russelrao
cjnolet Jun 8, 2021
c43c496
Merge branch 'semiring_primitives_optim_final' into semiring_primitiv…
cjnolet Jun 8, 2021
149d67f
Getting tests to work for russellrao
cjnolet Jun 8, 2021
deb01b2
Adding hamming distance
cjnolet Jun 8, 2021
fad730a
Adding jensenshannon and kldivergence
cjnolet Jun 8, 2021
c97e8e4
Testing remaining distances
cjnolet Jun 8, 2021
7739d1f
integrating cuco changes and some small refactors
divyegala Jun 9, 2021
1098d89
merging upstream
divyegala Jun 9, 2021
a3e263b
removing thrust::device_vector usage because stream unsafe
divyegala Jun 15, 2021
b991081
Merge branch 'branch-21.08' of https://github.com/rapidsai/raft into …
divyegala Jun 15, 2021
6ef7a02
restructuring tests and addressing review related to tests
divyegala Jun 15, 2021
9b63fc8
addressing other review comments
divyegala Jun 15, 2021
a270833
Merge branch 'branch-21.08' of https://github.com/rapidsai/raft into …
divyegala Jun 15, 2021
cca6fa6
pointing back to dev cuco commit
divyegala Jun 23, 2021
6c2f707
Merge branch 'branch-21.08' of https://github.com/rapidsai/raft into …
divyegala Jun 23, 2021
ea597d6
removing print
divyegala Jun 23, 2021
318a95a
Merge remote-tracking branch 'divye/semiring_primitives_optim_final' …
cjnolet Jun 23, 2021
591f9f2
Merge branch 'branch-21.08' into semiring_primitive_additional_distances
cjnolet Jun 23, 2021
fc185e5
Additional consts needed
cjnolet Jun 24, 2021
924d8db
Using rmm::device_uvector
cjnolet Jul 7, 2021
72dd60e
More removal of device buffer
cjnolet Jul 7, 2021
cc59ad6
Merge branch 'branch-21.08' of github.com:rapidsai/raft into branch-2…
cjnolet Jul 8, 2021
1da9ef0
Merge branch 'branch-21.08' of github.com:rapidsai/raft into branch-2…
cjnolet Jul 12, 2021
6e68e40
Merge branch 'branch-21.08' into semiring_primitive_additional_distances
cjnolet Jul 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion cpp/include/raft/sparse/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ static const std::unordered_set<raft::distance::DistanceType> supportedDistance{
raft::distance::DistanceType::JaccardExpanded,
raft::distance::DistanceType::CosineExpanded,
raft::distance::DistanceType::HellingerExpanded,
raft::distance::DistanceType::DiceExpanded};
raft::distance::DistanceType::DiceExpanded,
raft::distance::DistanceType::CorrelationExpanded,
raft::distance::DistanceType::RusselRaoExpanded,
raft::distance::DistanceType::HammingUnexpanded,
raft::distance::DistanceType::JensenShannon,
raft::distance::DistanceType::KLDivergence};

/**
* Compute pairwise distances between A and B, using the provided
Expand Down Expand Up @@ -120,6 +125,26 @@ void pairwiseDistance(value_t *out,
case raft::distance::DistanceType::DiceExpanded:
dice_expanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::CorrelationExpanded:
correlation_expanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
russelrao_expanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::HammingUnexpanded:
hamming_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::JensenShannon:
jensen_shannon_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::KLDivergence:
kl_divergence_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;

default:
THROW("Unsupported distance: %d", metric);
Expand Down
149 changes: 149 additions & 0 deletions cpp/include/raft/sparse/distance/l2_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,35 @@ __global__ void compute_euclidean_warp_kernel(
C[(size_t)i * n_cols + j] = val * (fabs(val) >= 0.0001);
}

template <typename value_idx, typename value_t>
__global__ void compute_correlation_warp_kernel(
value_t *__restrict__ C, const value_t *__restrict__ Q_sq_norms,
const value_t *__restrict__ R_sq_norms, const value_t *__restrict__ Q_norms,
const value_t *__restrict__ R_norms, value_idx n_rows, value_idx n_cols,
value_idx n) {
value_idx tid = blockDim.x * blockIdx.x + threadIdx.x;
value_idx i = tid / n_cols;
value_idx j = tid % n_cols;

if (i >= n_rows || j >= n_cols) return;

value_t dot = C[(size_t)i * n_cols + j];
value_t Q_l1 = Q_norms[i];
value_t R_l1 = R_norms[j];

value_t Q_l2 = Q_sq_norms[i];
value_t R_l2 = R_sq_norms[j];

value_t numer = n * dot - (Q_l1 * R_l1);
value_t Q_denom = n * Q_l2 - (Q_l1 * Q_l1);
value_t R_denom = n * R_l2 - (R_l1 * R_l1);

value_t val = 1 - (numer / sqrt(Q_denom * R_denom));

// correct for small instabilities
C[(size_t)i * n_cols + j] = val * (fabs(val) >= 0.0001);
}

template <typename value_idx, typename value_t, int tpb = 256,
typename expansion_f>
void compute_euclidean(value_t *C, const value_t *Q_sq_norms,
Expand Down Expand Up @@ -116,6 +145,55 @@ void compute_l2(value_t *out, const value_idx *Q_coo_rows,
expansion_func);
}

template <typename value_idx, typename value_t, int tpb = 256>
void compute_correlation(value_t *C, const value_t *Q_sq_norms,
const value_t *R_sq_norms, const value_t *Q_norms,
const value_t *R_norms, value_idx n_rows,
value_idx n_cols, value_idx n, cudaStream_t stream) {
int blocks = raft::ceildiv<size_t>((size_t)n_rows * n_cols, tpb);
compute_correlation_warp_kernel<<<blocks, tpb, 0, stream>>>(
C, Q_sq_norms, R_sq_norms, Q_norms, R_norms, n_rows, n_cols, n);
}

template <typename value_idx, typename value_t, int tpb = 256>
void compute_corr(value_t *out, const value_idx *Q_coo_rows,
const value_t *Q_data, value_idx Q_nnz,
const value_idx *R_coo_rows, const value_t *R_data,
value_idx R_nnz, value_idx m, value_idx n, value_idx n_cols,
std::shared_ptr<raft::mr::device::allocator> alloc,
cudaStream_t stream) {
// sum_sq for std dev
raft::mr::device::buffer<value_t> Q_sq_norms(alloc, stream, m);
raft::mr::device::buffer<value_t> R_sq_norms(alloc, stream, n);

// sum for mean
raft::mr::device::buffer<value_t> Q_norms(alloc, stream, m);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it wouldn't be too many changes, could you use rmm uvectors in general instead? (just adding a single comment to avoid repeated ones)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I don't think that would be too much to do for this PR.

raft::mr::device::buffer<value_t> R_norms(alloc, stream, n);

CUDA_CHECK(
cudaMemsetAsync(Q_sq_norms.data(), 0, Q_sq_norms.size() * sizeof(value_t)));
CUDA_CHECK(
cudaMemsetAsync(R_sq_norms.data(), 0, R_sq_norms.size() * sizeof(value_t)));

CUDA_CHECK(
cudaMemsetAsync(Q_norms.data(), 0, Q_norms.size() * sizeof(value_t)));
CUDA_CHECK(
cudaMemsetAsync(R_norms.data(), 0, R_norms.size() * sizeof(value_t)));

compute_row_norm_kernel<<<raft::ceildiv(Q_nnz, tpb), tpb, 0, stream>>>(
Q_sq_norms.data(), Q_coo_rows, Q_data, Q_nnz);
compute_row_norm_kernel<<<raft::ceildiv(R_nnz, tpb), tpb, 0, stream>>>(
R_sq_norms.data(), R_coo_rows, R_data, R_nnz);

compute_row_sum_kernel<<<raft::ceildiv(Q_nnz, tpb), tpb, 0, stream>>>(
Q_norms.data(), Q_coo_rows, Q_data, Q_nnz);
compute_row_sum_kernel<<<raft::ceildiv(R_nnz, tpb), tpb, 0, stream>>>(
R_norms.data(), R_coo_rows, R_data, R_nnz);

compute_correlation(out, Q_sq_norms.data(), R_sq_norms.data(), Q_norms.data(),
R_norms.data(), m, n, n_cols, stream);
}

/**
* L2 distance using the expanded form: sum(x_k)^2 + sum(y_k)^2 - 2 * sum(x_k * y_k)
* The expanded form is more efficient for sparse data.
Expand Down Expand Up @@ -183,6 +261,40 @@ class l2_sqrt_expanded_distances_t
~l2_sqrt_expanded_distances_t() = default;
};

template <typename value_idx, typename value_t>
class correlation_expanded_distances_t : public distances_t<value_t> {
public:
explicit correlation_expanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config), ip_dists(config) {}

void compute(value_t *out_dists) {
ip_dists.compute(out_dists);

value_idx *b_indices = ip_dists.b_rows_coo();
value_t *b_data = ip_dists.b_data_coo();

raft::mr::device::buffer<value_idx> search_coo_rows(
config_->handle.get_device_allocator(), config_->handle.get_stream(),
config_->a_nnz);
raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows,
search_coo_rows.data(), config_->a_nnz,
config_->handle.get_stream());

compute_corr(out_dists, search_coo_rows.data(), config_->a_data,
config_->a_nnz, b_indices, b_data, config_->b_nnz,
config_->a_nrows, config_->b_nrows, config_->b_ncols,
config_->handle.get_device_allocator(),
config_->handle.get_stream());
}

~correlation_expanded_distances_t() = default;

protected:
const distances_config_t<value_idx, value_t> *config_;
ip_distances_t<value_idx, value_t> ip_dists;
};

/**
* Cosine distance using the expanded form: 1 - ( sum(x_k * y_k) / (sqrt(sum(x_k)^2) * sqrt(sum(y_k)^2)))
* The expanded form is more efficient for sparse data.
Expand Down Expand Up @@ -282,6 +394,43 @@ class hellinger_expanded_distances_t : public distances_t<value_t> {
raft::mr::device::buffer<char> workspace;
};

template <typename value_idx = int, typename value_t = float>
class russelrao_expanded_distances_t : public distances_t<value_t> {
public:
explicit russelrao_expanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config),
workspace(config.handle.get_device_allocator(),
config.handle.get_stream(), 0),
ip_dists(config) {}

void compute(value_t *out_dists) {
ip_dists.compute(out_dists);

value_t n_cols = config_->a_ncols;
value_t n_cols_inv = 1.0 / n_cols;
raft::linalg::unaryOp<value_t>(
out_dists, out_dists, config_->a_nrows * config_->b_nrows,
[=] __device__(value_t input) { return (n_cols - input) * n_cols_inv; },
config_->handle.get_stream());

auto exec_policy = rmm::exec_policy(config_->handle.get_stream());
auto diags = thrust::counting_iterator<value_idx>(0);
value_idx b_nrows = config_->b_nrows;
thrust::for_each(exec_policy, diags, diags + config_->a_nrows,
[=] __device__(value_idx input) {
out_dists[input * b_nrows + input] = 0.0;
});
}

~russelrao_expanded_distances_t() = default;

private:
const distances_config_t<value_idx, value_t> *config_;
raft::mr::device::buffer<char> workspace;
ip_distances_t<value_idx, value_t> ip_dists;
};

}; // END namespace distance
}; // END namespace sparse
}; // END namespace raft
90 changes: 90 additions & 0 deletions cpp/include/raft/sparse/distance/lp_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,96 @@ class lp_unexpanded_distances_t : public distances_t<value_t> {
const distances_config_t<value_idx, value_t> *config_;
value_t p;
};

template <typename value_idx = int, typename value_t = float>
class hamming_unexpanded_distances_t : public distances_t<value_t> {
public:
explicit hamming_unexpanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config) {}

void compute(value_t *out_dists) {
unexpanded_lp_distances<value_idx, value_t>(out_dists, config_, NotEqual(),
Sum(), AtomicAdd());

value_t n_cols = 1.0 / config_->a_ncols;
raft::linalg::unaryOp<value_t>(
out_dists, out_dists, config_->a_nrows * config_->b_nrows,
[=] __device__(value_t input) { return input * n_cols; },
config_->handle.get_stream());
}

private:
const distances_config_t<value_idx, value_t> *config_;
};

template <typename value_idx = int, typename value_t = float>
class jensen_shannon_unexpanded_distances_t : public distances_t<value_t> {
public:
explicit jensen_shannon_unexpanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config) {}

void compute(value_t *out_dists) {
unexpanded_lp_distances<value_idx, value_t>(
out_dists, config_,
[] __device__(value_t a, value_t b) {
value_t m = 0.5f * (a + b);
bool a_zero = a == 0;
bool b_zero = b == 0;

value_t x = (!a_zero * m) / (a_zero + a);
value_t y = (!b_zero * m) / (b_zero + b);

bool x_zero = x == 0;
bool y_zero = y == 0;

return (-a * (!x_zero * log(x + x_zero))) +
(-b * (!y_zero * log(y + y_zero)));
},
Sum(), AtomicAdd());

raft::linalg::unaryOp<value_t>(
out_dists, out_dists, config_->a_nrows * config_->b_nrows,
[=] __device__(value_t input) { return sqrt(0.5 * input); },
config_->handle.get_stream());
}

private:
const distances_config_t<value_idx, value_t> *config_;
};

template <typename value_idx = int, typename value_t = float>
class kl_divergence_unexpanded_distances_t : public distances_t<value_t> {
public:
explicit kl_divergence_unexpanded_distances_t(
const distances_config_t<value_idx, value_t> &config)
: config_(&config) {}

void compute(value_t *out_dists) {
raft::mr::device::buffer<value_idx> coo_rows(
config_->handle.get_device_allocator(), config_->handle.get_stream(),
max(config_->b_nnz, config_->a_nnz));

raft::sparse::convert::csr_to_coo(config_->b_indptr, config_->b_nrows,
coo_rows.data(), config_->b_nnz,
config_->handle.get_stream());

balanced_coo_pairwise_generalized_spmv<value_idx, value_t>(
out_dists, *config_, coo_rows.data(),
[] __device__(value_t a, value_t b) { return a * log(a / b); }, Sum(),
AtomicAdd());

raft::linalg::unaryOp<value_t>(
out_dists, out_dists, config_->a_nrows * config_->b_nrows,
[=] __device__(value_t input) { return 0.5 * input; },
config_->handle.get_stream());
}

private:
const distances_config_t<value_idx, value_t> *config_;
};

}; // END namespace distance
}; // END namespace sparse
}; // END namespace raft
Loading