Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace ML::MetricType with raft::distance::DistanceType #3389

Merged
merged 38 commits into from
Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
962357f
Added raft::MetricType on multiple knn functions
lowener Jan 4, 2021
77decf2
Added raft distance to sparse functions & removed 'expanded forms' ar…
lowener Jan 4, 2021
9e33273
Continued conversion of Cpp code for KNN metrics
lowener Jan 6, 2021
a7e9093
Conversion of python side of knn
lowener Jan 8, 2021
134d251
Removed expanded argument of brute_force_knn on python side
lowener Jan 8, 2021
234898a
Removed 'Euc' from distance name to follow Raft
lowener Jan 20, 2021
5261873
Changed distance naming for prims processing
lowener Jan 20, 2021
050f4d0
Changing distance type in src_prims
lowener Jan 26, 2021
c96a8c4
Changing tests and bench distancetype naming
lowener Jan 27, 2021
fcfc8b7
Fixing compilation errors with rebase
lowener Jan 29, 2021
fc3cc75
Default L2 metric to Expanded
lowener Jan 29, 2021
50c29c6
Restoring 2 raft distance metrics to knn
lowener Jan 30, 2021
09200e0
Correct logic for unexpanded post-processing
lowener Jan 30, 2021
af05da6
Fixed logic for python side NN
lowener Jan 31, 2021
380883c
Merge branch 'branch-0.18' into 018_replace_distancetype
lowener Jan 31, 2021
2441139
Update copyrights & fix style
lowener Feb 1, 2021
a53ed06
Fix style
lowener Feb 1, 2021
521b983
Fix Distancetype call in tsne
lowener Feb 1, 2021
44dcb80
Fix copyright
lowener Feb 2, 2021
fa136e8
Change distancetype selection from python
lowener Feb 3, 2021
b3aa41e
fix style
lowener Feb 4, 2021
3eb8ccb
Adding L2Sqrt sparse distance class
lowener Feb 8, 2021
6f724b8
Fixed new l2 sqrt expanded distance
lowener Feb 8, 2021
241d913
Added l2SqrtUnexpanded distance
lowener Feb 8, 2021
ee20152
Add a constant for the supported sparse metrics
lowener Feb 9, 2021
d40a787
Re-factor distancetype & remove expanded of python knn
lowener Feb 9, 2021
01e09af
Merge branch 'branch-0.18' into 018_replace_distancetype
lowener Feb 11, 2021
debdb3a
Fix style
lowener Feb 11, 2021
5a694aa
Merge remote-tracking branch 'nvidia/branch-0.19' into 018_replace_di…
lowener Feb 11, 2021
e5f25f9
Fix style
lowener Feb 11, 2021
cc95e28
Fix distance type for umap algorighm
lowener Feb 11, 2021
11f668e
Changed DistanceType cast in C API
lowener Feb 19, 2021
a419589
Changed default metric from L2Unexpanded to L2Expanded
lowener Feb 22, 2021
0ff0bbe
Changed metric for tests according to the new DistanceType
lowener Feb 24, 2021
819991d
Merge branch 'branch-0.19' into 018_replace_distancetype
lowener Feb 24, 2021
6259a0b
Update copyright
lowener Feb 24, 2021
9c1b466
Fixed mistake during merge to branch-0.19
lowener Feb 25, 2021
03fd464
Temporarily disabling test_hinge_loss
lowener Mar 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 6 additions & 24 deletions cpp/include/cuml/neighbors/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,11 @@

#include <faiss/gpu/GpuIndex.h>
#include <faiss/gpu/StandardGpuResources.h>
#include <raft/linalg/distance_type.h>
#include <cuml/common/logger.hpp>
#include <cuml/cuml.hpp>

namespace ML {

enum MetricType {
METRIC_INNER_PRODUCT = 0,
METRIC_L2,
METRIC_L1,
METRIC_Linf,
METRIC_Lp,

METRIC_Canberra = 20,
METRIC_BrayCurtis,
METRIC_JensenShannon,

METRIC_Cosine = 100,
METRIC_Correlation,
METRIC_Jaccard,
METRIC_Hellinger
};

struct knnIndex {
faiss::gpu::StandardGpuResources *gpu_res;
faiss::gpu::GpuIndex *index;
Expand Down Expand Up @@ -102,20 +85,19 @@ struct IVFSQParam : IVFParam {
* default
* @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This
* is ignored if the metric_type is not Minkowski.
* @param[in] expanded should lp-based distances be returned in their expanded
* form (e.g., without raising to the 1/p power).
*/
void brute_force_knn(raft::handle_t &handle, std::vector<float *> &input,
std::vector<int> &sizes, int D, float *search_items, int n,
int64_t *res_I, float *res_D, int k,
bool rowMajorIndex = false, bool rowMajorQuery = false,
MetricType metric = MetricType::METRIC_L2,
float metric_arg = 2.0f, bool expanded = false);
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Unexpanded,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may want L2SqrtUnexpanded here to have the euclidean distance as default. At least if the results is seen by the end-user and not only used internally. Normally, METRIC_L2 in FAISS provides the euclidean distance before root-squaring. Then post-processing should apply the root-square. @cjnolet probably knows better about this though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would leave these in expanded form, actually. It's the most used metric and the difference in performance is pretty huge

float metric_arg = 2.0f);

void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
ML::knnIndexParam *params, int D,
ML::MetricType metric, float metricArg,
float *index_items, int n);
raft::distance::DistanceType metric,
float metricArg, float *index_items, int n);

void approx_knn_search(ML::knnIndex *index, int n, const float *x, int k,
float *distances, int64_t *labels);
Expand Down
6 changes: 4 additions & 2 deletions cpp/include/cuml/neighbors/knn_sparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <cusparse_v2.h>

#include <raft/linalg/distance_type.h>
#include <cuml/neighbors/knn.hpp>

namespace ML {
Expand All @@ -36,7 +37,8 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr,
float *output_dists, int k,
size_t batch_size_index = DEFAULT_BATCH_SIZE,
size_t batch_size_query = DEFAULT_BATCH_SIZE,
ML::MetricType metric = ML::MetricType::METRIC_L2,
float metricArg = 0, bool expanded_form = false);
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Unexpanded,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, L2SqrtUnexpanded might be needed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, I would leave this in expanded form.

float metricArg = 0);
}; // end namespace Sparse
}; // end namespace ML
10 changes: 5 additions & 5 deletions cpp/src/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ namespace ML {
void brute_force_knn(raft::handle_t &handle, std::vector<float *> &input,
std::vector<int> &sizes, int D, float *search_items, int n,
int64_t *res_I, float *res_D, int k, bool rowMajorIndex,
bool rowMajorQuery, MetricType metric, float metric_arg,
bool expanded) {
bool rowMajorQuery, raft::distance::DistanceType metric,
float metric_arg) {
ASSERT(input.size() == sizes.size(),
"input and sizes vectors must be the same size");

Expand All @@ -44,13 +44,13 @@ void brute_force_knn(raft::handle_t &handle, std::vector<float *> &input,
input, sizes, D, search_items, n, res_I, res_D, k,
handle.get_device_allocator(), handle.get_stream(), int_streams.data(),
handle.get_num_internal_streams(), rowMajorIndex, rowMajorQuery, nullptr,
metric, metric_arg, expanded);
metric, metric_arg);
}

void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
ML::knnIndexParam *params, int D,
ML::MetricType metric, float metricArg,
float *index_items, int n) {
raft::distance::DistanceType metric,
float metricArg, float *index_items, int n) {
MLCommon::Selection::approx_knn_build_index(
index, params, D, metric, metricArg, index_items, n, handle.get_stream());
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/knn/knn_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ cumlError_t knn_search(const cumlHandle_t handle, float **input, int *sizes,
try {
ML::brute_force_knn(*handle_ptr, input_vec, sizes_vec, D, search_items, n,
res_I, res_D, k, rowMajorIndex, rowMajorQuery,
(ML::MetricType)metric_type, metric_arg, expanded);
(raft::distance::DistanceType)metric_type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to keep making this conversion explicit?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're using the same enum type everywhere now, I think this conversion can be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conversion from int to DistanceType is needed to not modify the C API of knn_search. I changed it to a static_cast.

metric_arg);
} catch (...) {
status = CUML_ERROR_UNKNOWN;
}
Expand Down
7 changes: 3 additions & 4 deletions cpp/src/knn/knn_sparse.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr,
int n_query_rows, int n_query_cols, int *output_indices,
float *output_dists, int k,
size_t batch_size_index, // approx 1M
size_t batch_size_query, ML::MetricType metric,
float metricArg, bool expanded_form) {
size_t batch_size_query,
raft::distance::DistanceType metric, float metricArg) {
auto d_alloc = handle.get_device_allocator();
cusparseHandle_t cusparse_handle = handle.get_cusparse_handle();
cudaStream_t stream = handle.get_stream();
Expand All @@ -42,8 +42,7 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr,
idx_indptr, idx_indices, idx_data, idx_nnz, n_idx_rows, n_idx_cols,
query_indptr, query_indices, query_data, query_nnz, n_query_rows,
n_query_cols, output_indices, output_dists, k, cusparse_handle, d_alloc,
stream, batch_size_index, batch_size_query, metric, metricArg,
expanded_form);
stream, batch_size_index, batch_size_query, metric, metricArg);
}
}; // namespace Sparse
}; // namespace ML
4 changes: 2 additions & 2 deletions cpp/src/tsne/distances.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -89,7 +89,7 @@ void get_distances(const raft::handle_t &handle,
k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors,
handle.get_cusparse_handle(), handle.get_device_allocator(), stream,
ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE,
ML::MetricType::METRIC_L2);
raft::distance::DistanceType::L2Expanded);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need L2SqrtUnexpanded here, unless this distance value doesn't reach user's eye (only used by TSNE internally).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't reach the users so I think it's okay not to use the sqrt here. I think the expanded form is also good to use here for speed.

}

// sparse, int64
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/umap/knn_graph/algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <cuml/manifold/umapparams.h>
#include <raft/linalg/distance_type.h>
#include <cuml/manifold/common.hpp>
#include <cuml/neighbors/knn_sparse.hpp>
#include <iostream>
Expand Down Expand Up @@ -91,7 +92,7 @@ void launcher(const raft::handle_t &handle,
inputsB.n, inputsB.d, out.knn_indices, out.knn_dists, n_neighbors,
handle.get_cusparse_handle(), d_alloc, stream,
ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE,
ML::MetricType::METRIC_L2);
raft::distance::DistanceType::L2Expanded);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, L2SqrtUnexpanded might be needed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would leave all of these in expanded form. The unexpanded is more stable under some conditions but in general it's a better (and faster) starting point.

}

template <>
Expand Down
59 changes: 44 additions & 15 deletions cpp/src_prims/selection/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <thrust/device_vector.h>
#include <thrust/iterator/transform_iterator.h>

#include <raft/linalg/distance_type.h>
#include "processing.cuh"

#include <cuml/common/cuml_allocator.hpp>
Expand Down Expand Up @@ -186,14 +187,37 @@ inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK,
inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
}

inline faiss::MetricType build_faiss_metric(ML::MetricType metric) {
inline faiss::MetricType build_faiss_metric(
raft::distance::DistanceType metric) {
switch (metric) {
case ML::MetricType::METRIC_Cosine:
case raft::distance::DistanceType::CosineExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case ML::MetricType::METRIC_Correlation:
case raft::distance::DistanceType::CorrelationExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::L2Expanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2Unexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2SqrtExpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2SqrtUnexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L1:
return faiss::MetricType::METRIC_L1;
case raft::distance::DistanceType::InnerProduct:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::LpUnexpanded:
return faiss::MetricType::METRIC_Lp;
case raft::distance::DistanceType::Linf:
return faiss::MetricType::METRIC_Linf;
case raft::distance::DistanceType::Canberra:
return faiss::MetricType::METRIC_Canberra;
case raft::distance::DistanceType::BrayCurtis:
return faiss::MetricType::METRIC_BrayCurtis;
case raft::distance::DistanceType::JensenShannon:
return faiss::MetricType::METRIC_JensenShannon;
default:
return (faiss::MetricType)metric;
THROW("MetricType not supported: %d", metric);
}
Comment on lines +190 to 221
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this. It's a lot better and avoids future issues (in case enums are modified). However for now, no post-processing is applied to distances produced by approximate nearest neighbors. It means that only the raw FAISS metrics should be available. It is actually an issue that I should probably work on next.

}

Expand All @@ -219,7 +243,8 @@ inline faiss::ScalarQuantizer::QuantizerType build_faiss_qtype(

template <typename IntType = int>
void approx_knn_ivfflat_build_index(ML::knnIndex *index, ML::IVFParam *params,
IntType D, ML::MetricType metric,
IntType D,
raft::distance::DistanceType metric,
IntType n) {
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = index->device;
Expand All @@ -232,7 +257,9 @@ void approx_knn_ivfflat_build_index(ML::knnIndex *index, ML::IVFParam *params,

template <typename IntType = int>
void approx_knn_ivfpq_build_index(ML::knnIndex *index, ML::IVFPQParam *params,
IntType D, ML::MetricType metric, IntType n) {
IntType D,
raft::distance::DistanceType metric,
IntType n) {
faiss::gpu::GpuIndexIVFPQConfig config;
config.device = index->device;
config.usePrecomputedTables = params->usePrecomputedTables;
Expand All @@ -246,7 +273,9 @@ void approx_knn_ivfpq_build_index(ML::knnIndex *index, ML::IVFPQParam *params,

template <typename IntType = int>
void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params,
IntType D, ML::MetricType metric, IntType n) {
IntType D,
raft::distance::DistanceType metric,
IntType n) {
faiss::gpu::GpuIndexIVFScalarQuantizerConfig config;
config.device = index->device;
faiss::MetricType faiss_metric = build_faiss_metric(metric);
Expand All @@ -262,8 +291,8 @@ void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params,

template <typename IntType = int>
void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params,
IntType D, ML::MetricType metric, float metricArg,
float *index_items, IntType n,
IntType D, raft::distance::DistanceType metric,
float metricArg, float *index_items, IntType n,
cudaStream_t userStream) {
int device;
CUDA_CHECK(cudaGetDevice(&device));
Expand Down Expand Up @@ -330,7 +359,6 @@ void approx_knn_search(ML::knnIndex *index, IntType n, const float *x,
* non-contiguous partitions
* @param[in] metric corresponds to the FAISS::metricType enum (default is euclidean)
* @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm
* @param[in] expanded_form whether or not lp variants should be reduced w/ lp-root
*/
template <typename IntType = int>
void brute_force_knn(std::vector<float *> &input, std::vector<int> &sizes,
Expand All @@ -342,8 +370,9 @@ void brute_force_knn(std::vector<float *> &input, std::vector<int> &sizes,
int n_int_streams = 0, bool rowMajorIndex = true,
bool rowMajorQuery = true,
std::vector<int64_t> *translations = nullptr,
ML::MetricType metric = ML::MetricType::METRIC_L2,
float metricArg = 0, bool expanded_form = false) {
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Expanded,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, L2SqrtUnexpanded might be needed here.

float metricArg = 0) {
ASSERT(input.size() == sizes.size(),
"input and sizes vectors should be the same size");

Expand Down Expand Up @@ -452,9 +481,9 @@ void brute_force_knn(std::vector<float *> &input, std::vector<int> &sizes,
}

// Perform necessary post-processing
if ((m == faiss::MetricType::METRIC_L2 ||
m == faiss::MetricType::METRIC_Lp) &&
!expanded_form) {
if (metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::LpUnexpanded) {
Comment on lines +484 to +486
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be wrong, but I think that in this case, only unexpanded forms (that need post-procesing) : L2SqrtUnexpanded and LpUnexpanded should have post-processing. @cjnolet probably knows more about this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FAISS only supports the expanded form but I believe we're converting both the Expanded and Unexpanded L2 forms into faiss::METRIC_L2 so we'll need to sqrt both of them.

/**
* post-processing
*/
Expand Down
6 changes: 3 additions & 3 deletions cpp/src_prims/selection/processing.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,17 @@ class DefaultMetricProcessor : public MetricProcessor<math_t> {

template <typename math_t>
inline std::unique_ptr<MetricProcessor<math_t>> create_processor(
ML::MetricType metric, int n, int D, int k, bool rowMajorQuery,
raft::distance::DistanceType metric, int n, int D, int k, bool rowMajorQuery,
cudaStream_t userStream, std::shared_ptr<deviceAllocator> allocator) {
MetricProcessor<math_t> *mp = nullptr;

switch (metric) {
case ML::MetricType::METRIC_Cosine:
case raft::distance::DistanceType::CosineExpanded:
mp = new CosineMetricProcessor<math_t>(n, D, k, rowMajorQuery, userStream,
allocator);
break;

case ML::MetricType::METRIC_Correlation:
case raft::distance::DistanceType::CorrelationExpanded:
mp = new CorrelationMetricProcessor<math_t>(n, D, k, rowMajorQuery,
userStream, allocator);
break;
Expand Down
23 changes: 23 additions & 0 deletions cpp/src_prims/sparse/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/cudart_utils.h>
#include <unordered_set>

#include <raft/linalg/distance_type.h>
#include <raft/sparse/cusparse_wrappers.h>
Expand All @@ -42,6 +43,20 @@ namespace raft {
namespace sparse {
namespace distance {

static const std::unordered_set<raft::distance::DistanceType> supportedDistance{
raft::distance::DistanceType::L2Expanded,
raft::distance::DistanceType::L2Unexpanded,
raft::distance::DistanceType::L2SqrtExpanded,
raft::distance::DistanceType::L2SqrtUnexpanded,
raft::distance::DistanceType::InnerProduct,
raft::distance::DistanceType::L1,
raft::distance::DistanceType::Canberra,
raft::distance::DistanceType::Linf,
raft::distance::DistanceType::LpUnexpanded,
raft::distance::DistanceType::JaccardExpanded,
raft::distance::DistanceType::CosineExpanded,
raft::distance::DistanceType::HellingerExpanded};

/**
* Compute pairwise distances between A and B, using the provided
* input configuration and distance function.
Expand All @@ -60,12 +75,20 @@ void pairwiseDistance(value_t *out,
case raft::distance::DistanceType::L2Expanded:
l2_expanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
l2_sqrt_expanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::InnerProduct:
ip_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::L2Unexpanded:
l2_unexpanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
l2_sqrt_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::L1:
l1_unexpanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
Expand Down
Loading