Skip to content

Commit

Permalink
Changing tests and bench distancetype naming
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Jan 28, 2021
1 parent 050f4d0 commit c96a8c4
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion cpp/src_prims/sparse/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ void pairwiseDistance(value_t *out,
raft::distance::DistanceType metric) {
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
// EucExpandedL2
// L2Expanded
l2_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::InnerProduct:
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/prims/sparse/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ struct SparseKNNInputs {
int batch_size_index = 2;
int batch_size_query = 2;

raft::distance::DistanceType metric = raft::distance::DistanceType::EucUnexpandedL2;
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded;
};

template <typename value_idx, typename value_t>
Expand Down Expand Up @@ -165,7 +165,7 @@ const std::vector<SparseKNNInputs<int, float>> inputs_i32_f = {
{0, 3, 1, 0, 2, 0, 3, 0}, // inds
2,
2,
raft::distance::DistanceType::EucUnexpandedL2}};
raft::distance::DistanceType::L2Unexpanded}};
typedef SparseKNNTest<int, float> KNNTestF;
TEST_P(KNNTestF, Result) { compare(); }
INSTANTIATE_TEST_CASE_P(SparseKNNTest, KNNTestF,
Expand Down
8 changes: 4 additions & 4 deletions cpp/test/sg/umap_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class UMAPTest : public ::testing::Test {

CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));

xformed_score = trustworthiness_score<float, EucUnexpandedL2Sqrt>(
xformed_score = trustworthiness_score<float, L2SqrtUnexpanded>(
handle, X_d.data(), xformed.data(), n_samples, n_features,
umap_params->n_components, umap_params->n_neighbors);
}
Expand Down Expand Up @@ -117,7 +117,7 @@ class UMAPTest : public ::testing::Test {

CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));

fit_score = trustworthiness_score<float, EucUnexpandedL2Sqrt>(
fit_score = trustworthiness_score<float, L2SqrtUnexpanded>(
handle, X_d.data(), embeddings.data(), n_samples, n_features,
umap_params->n_components, umap_params->n_neighbors);
}
Expand Down Expand Up @@ -154,7 +154,7 @@ class UMAPTest : public ::testing::Test {

CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));

supervised_score = trustworthiness_score<float, EucUnexpandedL2Sqrt>(
supervised_score = trustworthiness_score<float, L2SqrtUnexpanded>(
handle, X_d.data(), embeddings.data(), n_samples, n_features,
umap_params->n_components, umap_params->n_neighbors);
}
Expand Down Expand Up @@ -213,7 +213,7 @@ class UMAPTest : public ::testing::Test {

CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));

fit_with_knn_score = trustworthiness_score<float, EucUnexpandedL2Sqrt>(
fit_with_knn_score = trustworthiness_score<float, L2SqrtUnexpanded>(
handle, X_d.data(), embeddings.data(), n_samples, n_features,
umap_params->n_components, umap_params->n_neighbors);
}
Expand Down

0 comments on commit c96a8c4

Please sign in to comment.