Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Expose sparse distances via semiring to Python API #3516

Merged
merged 29 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0eadbbe
Initial work on pairwise distance python api
lowener Feb 18, 2021
3ad4dc7
Added function for python sparse pairwiseDist
lowener Feb 19, 2021
65989dc
Fix mistake on type for pairwise dist
lowener Feb 22, 2021
6abccb0
Changed sparse_pairwise prototype & modified distances to dict
lowener Feb 28, 2021
7c0ba8e
Merge branch 'branch-0.19' into 019-expose-spmv
lowener Mar 1, 2021
10d41aa
Fixed error in struct creation for pairwise dist
lowener Mar 1, 2021
4ca0abd
Merge branch 'branch-0.19' into 019-expose-spmv
lowener Mar 3, 2021
34007a4
Fixed configuration struct and naming and tests
lowener Mar 4, 2021
aeb41c3
Fix style
lowener Mar 5, 2021
08efba1
Fix style & reduce test precision
lowener Mar 5, 2021
e79b835
Augment test precision
lowener Mar 5, 2021
41adb22
Fix style
lowener Mar 6, 2021
55b4c83
Changed function naming
lowener Mar 9, 2021
3289196
Added Jaccard data conversion to boolean
lowener Mar 12, 2021
40799e6
fixing precision for sqrt distances
lowener Mar 15, 2021
524f2cb
Added doc for sparse pairwise dist
lowener Mar 15, 2021
0a370a7
Merge branch 'branch-0.19' into 019-expose-spmv
lowener Mar 16, 2021
28d16e6
Added tests on sparse pairwise dist
lowener Mar 16, 2021
e04b8fa
Merge branch 'branch-0.19' into 019-expose-spmv
lowener Mar 17, 2021
1ce7816
Fixed sparse_pairwise after review, added test on output type
lowener Mar 23, 2021
5b0a3ce
Fix style
lowener Mar 23, 2021
ea197ce
fix style
lowener Mar 23, 2021
893d648
Added data normalisation for hellinger
lowener Mar 24, 2021
3fb04d0
Fix style
lowener Mar 24, 2021
da5a9be
Restored float32 and int32 test
lowener Mar 24, 2021
ab122b0
Add dice distance and fix precision
lowener Mar 30, 2021
0e90f8c
Merge branch 'branch-0.19' into 019-expose-spmv
lowener Mar 30, 2021
5cbf694
fix style
lowener Mar 30, 2021
4d1c1c9
Update raft, use cuml normalization, add hellinger to neighors
lowener Apr 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ else(DEFINED ENV{RAFT_PATH})

ExternalProject_Add(raft
GIT_REPOSITORY https://github.com/rapidsai/raft.git
GIT_TAG a57cf7df757b24230454e442c83f8491f97a4843
GIT_TAG d1fd927bc4ec67bfd765620b5fa93f17c54cfa70
PREFIX ${RAFT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand Down
13 changes: 13 additions & 0 deletions cpp/include/cuml/metrics/metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,5 +321,18 @@ void pairwise_distance(const raft::handle_t &handle, const float *x,
raft::distance::DistanceType metric,
bool isRowMajor = true);

void pairwiseDistance_sparse(const raft::handle_t &handle, double *x, double *y,
double *dist, int x_nrows, int y_nrows, int n_cols,
int x_nnz, int y_nnz, int *x_indptr, int *y_indptr,
int *x_indices, int *y_indices,
raft::distance::DistanceType metric,
float metric_arg);
void pairwiseDistance_sparse(const raft::handle_t &handle, float *x, float *y,
float *dist, int x_nrows, int y_nrows, int n_cols,
int x_nnz, int y_nnz, int *x_indptr, int *y_indptr,
int *x_indices, int *y_indices,
raft::distance::DistanceType metric,
float metric_arg);

} // namespace Metrics
} // namespace ML
58 changes: 57 additions & 1 deletion cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,8 +15,10 @@
* limitations under the License.
*/

#include <raft/sparse/distance/common.h>
#include <cuml/metrics/metrics.hpp>
#include <metrics/pairwise_distance.cuh>
#include <raft/sparse/distance/distance.cuh>

namespace ML {

Expand All @@ -37,5 +39,59 @@ void pairwise_distance(const raft::handle_t &handle, const float *x,
handle.get_stream(), isRowMajor);
}

template <typename value_idx = int, typename value_t = float>
void pairwiseDistance_sparse(const raft::handle_t &handle, value_t *x,
value_t *y, value_t *dist, value_idx x_nrows,
value_idx y_nrows, value_idx n_cols,
value_idx x_nnz, value_idx y_nnz,
value_idx *x_indptr, value_idx *y_indptr,
value_idx *x_indices, value_idx *y_indices,
raft::distance::DistanceType metric,
float metric_arg) {
raft::sparse::distance::distances_config_t<value_idx, value_t> dist_config;

dist_config.b_nrows = x_nrows;
dist_config.b_ncols = n_cols;
dist_config.b_nnz = x_nnz;
dist_config.b_indptr = x_indptr;
dist_config.b_indices = x_indices;
dist_config.b_data = x;

dist_config.a_nrows = y_nrows;
dist_config.a_ncols = n_cols;
dist_config.a_nnz = y_nnz;
dist_config.a_indptr = y_indptr;
dist_config.a_indices = y_indices;
dist_config.a_data = y;

dist_config.handle = handle.get_cusparse_handle();
dist_config.allocator = handle.get_device_allocator();
dist_config.stream = handle.get_stream();

raft::sparse::distance::pairwiseDistance(dist, dist_config, metric,
metric_arg);
}

void pairwiseDistance_sparse(const raft::handle_t &handle, float *x, float *y,
float *dist, int x_nrows, int y_nrows, int n_cols,
int x_nnz, int y_nnz, int *x_indptr, int *y_indptr,
int *x_indices, int *y_indices,
raft::distance::DistanceType metric,
float metric_arg) {
pairwiseDistance_sparse<int, float>(handle, x, y, dist, x_nrows, y_nrows,
n_cols, x_nnz, y_nnz, x_indptr, y_indptr,
x_indices, y_indices, metric, metric_arg);
}

void pairwiseDistance_sparse(const raft::handle_t &handle, double *x, double *y,
double *dist, int x_nrows, int y_nrows, int n_cols,
int x_nnz, int y_nnz, int *x_indptr, int *y_indptr,
int *x_indices, int *y_indices,
raft::distance::DistanceType metric,
float metric_arg) {
pairwiseDistance_sparse<int, double>(
handle, x, y, dist, x_nrows, y_nrows, n_cols, x_nnz, y_nnz, x_indptr,
y_indptr, x_indices, y_indices, metric, metric_arg);
}
} // namespace Metrics
} // namespace ML
6 changes: 4 additions & 2 deletions python/cuml/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
cython_mutual_info_score as mutual_info_score
from cuml.metrics.confusion_matrix import confusion_matrix
from cuml.metrics.cluster.entropy import cython_entropy as entropy
from cuml.metrics.pairwise_distances import pairwise_distances, \
PAIRWISE_DISTANCE_METRICS
from cuml.metrics.pairwise_distances import pairwise_distances
from cuml.metrics.pairwise_distances import sparse_pairwise_distances
from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_METRICS
from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_SPARSE_METRICS
from cuml.metrics.hinge_loss import hinge_loss
Loading