Skip to content

Commit

Permalink
Removing individual pairwise dists
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Oct 12, 2021
1 parent 0c7ae87 commit 6d9cde3
Show file tree
Hide file tree
Showing 26 changed files with 8 additions and 1,357 deletions.
12 changes: 0 additions & 12 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -279,18 +279,6 @@ if(BUILD_CUML_CPP_LIBRARY)
src/metrics/kl_divergence.cu
src/metrics/mutual_info_score.cu
src/metrics/pairwise_distance.cu
src/metrics/pairwise_distance_canberra.cu
src/metrics/pairwise_distance_chebyshev.cu
src/metrics/pairwise_distance_correlation.cu
src/metrics/pairwise_distance_cosine.cu
src/metrics/pairwise_distance_euclidean.cu
src/metrics/pairwise_distance_hamming.cu
src/metrics/pairwise_distance_hellinger.cu
src/metrics/pairwise_distance_jensen_shannon.cu
src/metrics/pairwise_distance_kl_divergence.cu
src/metrics/pairwise_distance_l1.cu
src/metrics/pairwise_distance_minkowski.cu
src/metrics/pairwise_distance_russell_rao.cu
src/metrics/r2_score.cu
src/metrics/rand_index.cu
src/metrics/silhouette_score.cu
Expand Down
106 changes: 8 additions & 98 deletions cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,13 @@
* limitations under the License.
*/

#include <raft/sparse/distance/common.h>
#include <cuml/metrics/metrics.hpp>
#include <raft/distance/distance.hpp>

#include <raft/handle.hpp>

#include <raft/distance/distance.hpp>
#include <raft/sparse/distance/common.h>
#include <raft/sparse/distance/distance.cuh>
#include "pairwise_distance_canberra.cuh"
#include "pairwise_distance_chebyshev.cuh"
#include "pairwise_distance_correlation.cuh"
#include "pairwise_distance_cosine.cuh"
#include "pairwise_distance_euclidean.cuh"
#include "pairwise_distance_hamming.cuh"
#include "pairwise_distance_hellinger.cuh"
#include "pairwise_distance_jensen_shannon.cuh"
#include "pairwise_distance_kl_divergence.cuh"
#include "pairwise_distance_l1.cuh"
#include "pairwise_distance_minkowski.cuh"
#include "pairwise_distance_russell_rao.cuh"

namespace ML {

Expand All @@ -47,48 +37,8 @@ void pairwise_distance(const raft::handle_t& handle,
bool isRowMajor,
double metric_arg)
{
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Unexpanded:
case raft::distance::DistanceType::L2SqrtUnexpanded:
pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CosineExpanded:
pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::L1:
pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Linf:
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HellingerExpanded:
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::LpUnexpanded:
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Canberra:
pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CorrelationExpanded:
pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HammingUnexpanded:
pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::JensenShannon:
pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::KLDivergence:
pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
raft::distance::pairwise_distance<double, int>(
handle, x, y, dist, m, n, k, raft::distance::DistanceType::Canberra, isRowMajor);
}

void pairwise_distance(const raft::handle_t& handle,
Expand All @@ -102,48 +52,8 @@ void pairwise_distance(const raft::handle_t& handle,
bool isRowMajor,
float metric_arg)
{
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Unexpanded:
case raft::distance::DistanceType::L2SqrtUnexpanded:
pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CosineExpanded:
pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::L1:
pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Linf:
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HellingerExpanded:
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::LpUnexpanded:
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Canberra:
pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CorrelationExpanded:
pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HammingUnexpanded:
pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::JensenShannon:
pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::KLDivergence:
pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
raft::distance::pairwise_distance<float, int>(
handle, x, y, dist, m, n, k, raft::distance::DistanceType::Canberra, isRowMajor);
}

template <typename value_idx = int, typename value_t = float>
Expand Down
56 changes: 0 additions & 56 deletions cpp/src/metrics/pairwise_distance_canberra.cu

This file was deleted.

47 changes: 0 additions & 47 deletions cpp/src/metrics/pairwise_distance_canberra.cuh

This file was deleted.

54 changes: 0 additions & 54 deletions cpp/src/metrics/pairwise_distance_chebyshev.cu

This file was deleted.

45 changes: 0 additions & 45 deletions cpp/src/metrics/pairwise_distance_chebyshev.cuh

This file was deleted.

Loading

0 comments on commit 6d9cde3

Please sign in to comment.