Skip to content

Commit

Permalink
Porting over recent updates to distance prim from cuml (#172)
Browse files Browse the repository at this point in the history
The sparse distance primitive was recently removed from cuml but Cosine and Jaccard were updated to support empty rows shortly before being removed. This PR ports those changes to the RAFT versions once and for all.

Authors:
  - Corey J. Nolet (@cjnolet)

Approvers:
  - Micka (@lowener)
  - John Zedlewski (@JohnZed)

URL: #172
  • Loading branch information
cjnolet authored Mar 17, 2021
1 parent fc46618 commit 2ef0a51
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 49 deletions.
2 changes: 1 addition & 1 deletion cpp/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ if(NOT CUB_IS_PART_OF_CTK)
set(CUB_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub CACHE STRING "Path to cub repo")
ExternalProject_Add(cub
GIT_REPOSITORY https://github.com/thrust/cub.git
GIT_TAG 1.12.0
GIT_TAG 1.8.0
PREFIX ${CUB_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand Down
8 changes: 7 additions & 1 deletion cpp/include/raft/sparse/distance/bin_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,13 @@ class jaccard_expanded_distances_t : public distances_t<value_t> {
config_->handle, config_->allocator, config_->stream,
[] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) {
value_t q_r_union = q_norm + r_norm;
return 1 - (dot / (q_r_union - dot));
value_t denom = q_r_union - dot;

value_t jacc = ((denom != 0) * dot) / ((denom == 0) + denom);

// flip the similarity when both rows are 0
bool both_empty = q_r_union == 0;
return 1 - ((!both_empty * jacc) + both_empty);
});
}

Expand Down
5 changes: 4 additions & 1 deletion cpp/include/raft/sparse/distance/l2_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ class cosine_expanded_distances_t : public distances_t<value_t> {
value_t norms = sqrt(q_norm) * sqrt(r_norm);
// deal with potential for 0 in denominator by forcing 0/1 instead
value_t cos = ((norms != 0) * dot) / ((norms == 0) + norms);
return 1 - cos;

// flip the similarity when both rows are 0
bool both_empty = (q_norm == 0) && (r_norm == 0);
return 1 - ((!both_empty * cos) + both_empty);
});
}

Expand Down
2 changes: 0 additions & 2 deletions cpp/include/raft/sparse/selection/knn_graph.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ void knn_graph(const handle_t &handle, const value_t *X, size_t m, size_t n,
raft::sparse::COO<value_t, value_idx> &out, int c = 15) {
int k = build_k(m, c);

printf("K=%d\n", k);

auto d_alloc = handle.get_device_allocator();
auto stream = handle.get_stream();

Expand Down
6 changes: 0 additions & 6 deletions cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,6 @@ void brute_force_knn_impl(
userStream, trans.data());
}

raft::print_device_vector("before sqrt", res_D, n * k, std::cout);

// Perform necessary post-processing
if ((m == faiss::MetricType::METRIC_L2 ||
m == faiss::MetricType::METRIC_Lp) &&
Expand All @@ -346,10 +344,6 @@ void brute_force_knn_impl(
[p] __device__(float input) { return powf(input, p); }, userStream);
}

CUDA_CHECK(cudaStreamSynchronize(userStream));

raft::print_device_vector("after sqrt", res_D, n * k, std::cout);

query_metric_processor->revert(search_items);
query_metric_processor->postprocess(out_D);
for (size_t i = 0; i < input.size(); i++) {
Expand Down
53 changes: 15 additions & 38 deletions cpp/test/sparse/connect_components.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ struct ConnectComponentsInputs {
value_idx n_col;
std::vector<value_t> data;

std::vector<value_idx> expected_labels;

int n_clusters;

int c;
};

Expand All @@ -67,20 +63,14 @@ class ConnectComponentsTest : public ::testing::TestWithParam<
params = ::testing::TestWithParam<
ConnectComponentsInputs<value_t, value_idx>>::GetParam();

out_edges = new raft::sparse::COO<value_t, value_idx>(
raft::sparse::COO<value_t, value_idx> out_edges(
handle.get_device_allocator(), handle.get_stream());

rmm::device_uvector<value_t> data(params.n_row * params.n_col,
handle.get_stream());

// Allocate result labels and expected labels on device
raft::allocate(labels, params.n_row);
raft::allocate(labels_ref, params.n_row);

raft::copy(data.data(), params.data.data(), data.size(),
handle.get_stream());
raft::copy(labels_ref, params.expected_labels.data(), params.n_row,
handle.get_stream());

rmm::device_uvector<value_idx> indptr(params.n_row + 1, stream);

Expand Down Expand Up @@ -114,33 +104,34 @@ class ConnectComponentsTest : public ::testing::TestWithParam<
handle, indptr.data(), knn_graph_coo.cols(), knn_graph_coo.vals(),
params.n_row, knn_graph_coo.nnz, colors.data(), stream, false);

raft::print_device_vector("colors", colors.data(), colors.size(),
std::cout);
CUDA_CHECK(cudaStreamSynchronize(stream));

printf("Got here.\n");

raft::print_device_vector("colors", colors.data(), params.n_row, std::cout);

/**
* 3. connect_components to fix connectivities
*/
raft::linkage::connect_components<value_idx, value_t>(
handle, *out_edges, data.data(), colors.data(), params.n_row,
handle, out_edges, data.data(), colors.data(), params.n_row,
params.n_col);

int final_nnz = out_edges->nnz + mst_coo.n_edges;
int final_nnz = out_edges.nnz + mst_coo.n_edges;

mst_coo.src.resize(final_nnz, stream);
mst_coo.dst.resize(final_nnz, stream);
mst_coo.weights.resize(final_nnz, stream);

printf("New nnz: %d\n", final_nnz);

/**
* Construct final edge list
*/
raft::copy_async(mst_coo.src.data() + mst_coo.n_edges, out_edges->rows(),
out_edges->nnz, stream);
raft::copy_async(mst_coo.dst.data() + mst_coo.n_edges, out_edges->cols(),
out_edges->nnz, stream);
raft::copy_async(mst_coo.weights.data() + mst_coo.n_edges,
out_edges->vals(), out_edges->nnz, stream);
raft::copy_async(mst_coo.src.data() + mst_coo.n_edges, out_edges.rows(),
out_edges.nnz, stream);
raft::copy_async(mst_coo.dst.data() + mst_coo.n_edges, out_edges.cols(),
out_edges.nnz, stream);
raft::copy_async(mst_coo.weights.data() + mst_coo.n_edges, out_edges.vals(),
out_edges.nnz, stream);

raft::sparse::COO<value_t, value_idx> final_coo(d_alloc, stream);
raft::sparse::linalg::symmetrize(
Expand Down Expand Up @@ -168,22 +159,15 @@ class ConnectComponentsTest : public ::testing::TestWithParam<

CUDA_CHECK(cudaStreamSynchronize(stream));

printf("output edges: %d\n", output_mst.n_edges);

final_edges = output_mst.n_edges;
}

void SetUp() override { basicTest(); }

void TearDown() override {
// CUDA_CHECK(cudaFree(labels));
// CUDA_CHECK(cudaFree(labels_ref));
}
void TearDown() override {}

protected:
ConnectComponentsInputs<value_t, value_idx> params;
value_idx *labels, *labels_ref;
raft::sparse::COO<value_t, value_idx> *out_edges;

value_idx final_edges;
};
Expand All @@ -201,8 +185,6 @@ const std::vector<ConnectComponentsInputs<float, int>> fix_conn_inputsf2 = {
0.27864171, 0.70911132, 0.21338564, 0.32035554, 0.73788331, 0.46926692,
0.57570162, 0.42559178, 0.87120209, 0.22734951, 0.01847905, 0.75549396,
0.76166195, 0.66613745},
{9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
10,
-1},
// Test n_points == 100
{100,
Expand Down Expand Up @@ -543,11 +525,6 @@ const std::vector<ConnectComponentsInputs<float, int>> fix_conn_inputsf2 = {
8.66342445e-01

},
{0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 4, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
10,
-4}};

typedef ConnectComponentsTest<int, float> ConnectComponentsTestF_Int;
Expand Down
16 changes: 16 additions & 0 deletions cpp/test/sparse/distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ class SparseDistanceTest
};

const std::vector<SparseDistanceInputs<int, float>> inputs_i32_f = {
{5,
{0, 0, 1, 2},

{1, 2},
{0.5, 0.5},
{0, 1, 1, 1, 0, 1, 1, 1, 0},
raft::distance::DistanceType::CosineExpanded,
0.0},
{5,
{0, 0, 1, 2},

{1, 2},
{1.0, 1.0},
{0, 1, 1, 1, 0, 1, 1, 1, 0},
raft::distance::DistanceType::JaccardExpanded,
0.0},
{2,
{0, 2, 4, 6, 8},
{0, 1, 0, 1, 0, 1, 0, 1}, // indices
Expand Down

0 comments on commit 2ef0a51

Please sign in to comment.