Add dispatch based on compute architecture #1335

Merged: 77 commits, Mar 17, 2023
Changes from 72 commits

Commits (77):
8d3e8a0  contractions: Concentrate tile index calculations (ahendriksen, Sep 2, 2022)
cb7baab  pairwise_distance_base: Remove all ldgXY(0) calls (ahendriksen, Sep 2, 2022)
066bf3b  pairwise_distance_base: Move all logic into run loop (ahendriksen, Sep 2, 2022)
a15d5fc  pairwise_distance_base: Fix typo (ahendriksen, Oct 5, 2022)
71c6da6  Remove deprecated header (ahendriksen, Jan 11, 2023)
4bbedf6  Replace lambdas by raft::void_op (ahendriksen, Jan 12, 2023)
c3d1f6e  Use an operator for L1 distance (ahendriksen, Jan 12, 2023)
3e3478b  Add launch function (ahendriksen, Jan 12, 2023)
264a9d2  l1: Replace run-time -> compile-time dispatch (ahendriksen, Jan 13, 2023)
b232057  pairwise matrix: move files into subdirectories (ahendriksen, Jan 13, 2023)
06f6ffa  pairwise matrix: Untangle dispatching and kernel template parameters (ahendriksen, Jan 13, 2023)
2f41faa  l2 unexp: Use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
7938614  l2 exp: Use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
7afe6cc  Add template for distance operator (ahendriksen, Jan 13, 2023)
5fe3292  Reenable cutlass-based kernels for CUDA 12.0 (ahendriksen, Jan 13, 2023)
c623332  pairwise matrix l2: Add support for CUTLASS kernels (ahendriksen, Jan 13, 2023)
27511fc  Canberra: use dispatching mechanism (ahendriksen, Jan 13, 2023)
58ce6f8  Chebyshev: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
d397c17  Correlation: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
7005a4f  Hamming: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
7831deb  Hellinger: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
4dc72ce  Jensen-Shannon: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
b0d36c1  remove old hamming code (ahendriksen, Jan 13, 2023)
e95a65b  KL divergence: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
f1c105b  Minkowski: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
ac66e3f  Russel-Rao: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
a89896a  Cosine: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
16b2acd  Fix include for l1 op (ahendriksen, Jan 13, 2023)
1326e34  kl_divergence: Use raft::log instead of raft::myLog (ahendriksen, Feb 10, 2023)
0169b26  distance_op: Add expensive_inner_loop marker (ahendriksen, Feb 10, 2023)
52e95e1  Update copyright notices (ahendriksen, Feb 10, 2023)
28cd57b  Reusable dispatch mechanism (ahendriksen, Feb 10, 2023)
c44aece  Dispatch mechanism using switch statement (ahendriksen, Feb 10, 2023)
7c3bd76  Remove one ".template" from kernel_sm60 (ahendriksen, Feb 10, 2023)
d62eeb7  Dispatch on veclen instead of byte_alignment (ahendriksen, Feb 10, 2023)
5c3dcaf  Use many template parameters again (ahendriksen, Feb 20, 2023)
2613e8a  Remove duplicate DistanceType enum definition (ahendriksen, Feb 20, 2023)
62ed53a  Remove pairwiseDistanceMatKernel (ahendriksen, Feb 20, 2023)
c334ba3  Remove distance::detail::pairwise_distance_impl (ahendriksen, Feb 20, 2023)
8e43238  distance_ops: Include cuda_utils.cuh (ahendriksen, Feb 21, 2023)
e176351  Replace DistanceImpl with method overloads (ahendriksen, Feb 21, 2023)
6ddd14f  Remove impl files and move doc strings (ahendriksen, Feb 21, 2023)
34ccddc  Update readme (ahendriksen, Feb 21, 2023)
b27cdca  Merge branch 'rapids/branch-23.04' into wip-refactor-distance (ahendriksen, Feb 21, 2023)
6a12ded  Reenable device code generation (ahendriksen, Feb 21, 2023)
486393e  Readd overload of raft::distance::detail::distance (ahendriksen, Feb 21, 2023)
ca29e2d  Fix style (ahendriksen, Feb 21, 2023)
28c95a1  Fix 11.8 compilation error (ahendriksen, Feb 22, 2023)
a5592b9  Rename minkowski -> lp_unexp (ahendriksen, Feb 22, 2023)
265ba07  Rename Chebyshev -> l_inf (ahendriksen, Feb 22, 2023)
7ccb8a7  Rename euc -> l2 (ahendriksen, Feb 22, 2023)
874d014  Update copyright headers (ahendriksen, Feb 22, 2023)
757fb44  Remove misleading note about workspace nullptr (ahendriksen, Feb 22, 2023)
d6e9261  Remove notes file (ahendriksen, Feb 22, 2023)
885bda6  Put template on struct instead of methods (ahendriksen, Feb 22, 2023)
cd38ec6  Fix style (ahendriksen, Feb 22, 2023)
749d000  Add dispatch based on compute architecture (ahendriksen, Feb 22, 2023)
7262861  Fix style (ahendriksen, Feb 22, 2023)
e7a8e89  Merge branch 'branch-23.04' into wip-refactor-distance (cjnolet, Feb 22, 2023)
1ef8520  Merge remote-tracking branch 'rapids/pull-request/1142' into enh-arch… (ahendriksen, Feb 23, 2023)
09a3050  Fix linker error: multiple definition.. (ahendriksen, Mar 6, 2023)
6467221  Update cpp/include/raft/distance/detail/distance_ops/canberra.cuh (ahendriksen, Mar 6, 2023)
a83461e  Update cpp/include/raft/distance/detail/distance.cuh (ahendriksen, Mar 6, 2023)
393edf3  Add note about alignment in case of byte input (ahendriksen, Mar 6, 2023)
48a0c21  Fix (ahendriksen, Mar 7, 2023)
f8daf48  Merge remote-tracking branch 'rapids/pull-request/1142' into enh-arch… (ahendriksen, Mar 7, 2023)
1a6636f  Implement review feedback (ahendriksen, Mar 7, 2023)
e696f24  Merge branch 'branch-23.04' into enh-arch-dispatch (ahendriksen, Mar 10, 2023)
3164802  Determine runtime arch using kernel pointer (ahendriksen, Mar 13, 2023)
35014e6  Fix Gram compilation error (ahendriksen, Mar 14, 2023)
1516198  Reformat comments (ahendriksen, Mar 14, 2023)
37cdc51  Merge branch 'branch-23.04' into enh-arch-dispatch (ahendriksen, Mar 15, 2023)
a8d98ca  Merge branch 'branch-23.04' into enh-arch-dispatch (cjnolet, Mar 15, 2023)
c9ab1b8  Fix col major errors on SM80 (ahendriksen, Mar 16, 2023)
584897f  Merge remote-tracking branch 'rapids/branch-23.04' into enh-arch-disp… (ahendriksen, Mar 16, 2023)
30c3391  Fix build failure (ahendriksen, Mar 17, 2023)
9eaf9b5  Fix spelling (ahendriksen, Mar 17, 2023)
102 changes: 20 additions & 82 deletions cpp/include/raft/distance/detail/distance.cuh
@@ -27,24 +27,12 @@

#include <raft/core/operators.hpp>

#include <raft/distance/detail/distance_ops/canberra.cuh>
#include <raft/distance/detail/distance_ops/correlation.cuh>
#include <raft/distance/detail/distance_ops/cosine.cuh>
#include <raft/distance/detail/distance_ops/hamming.cuh>
#include <raft/distance/detail/distance_ops/hellinger.cuh>
#include <raft/distance/detail/distance_ops/jensen_shannon.cuh>
#include <raft/distance/detail/distance_ops/kl_divergence.cuh>
#include <raft/distance/detail/distance_ops/l1.cuh>
#include <raft/distance/detail/distance_ops/l2_exp.cuh>
#include <raft/distance/detail/distance_ops/l2_unexp.cuh>
#include <raft/distance/detail/distance_ops/l_inf.cuh>
#include <raft/distance/detail/distance_ops/lp_unexp.cuh>
#include <raft/distance/detail/distance_ops/russel_rao.cuh>

#include <raft/distance/detail/distance_ops/all_ops.cuh>
#include <raft/distance/detail/pairwise_matrix/dispatch.cuh>

#include <raft/distance/distance_types.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/util/arch.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_uvector.hpp>

@@ -126,7 +114,7 @@ void distance_impl(raft::resources const& handle,
const DataT* y_norm = nullptr;

cudaStream_t stream = raft::resource::get_cuda_stream(handle);
distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

@@ -205,7 +193,7 @@

using OpT = ops::correlation_distance_op<DataT, AccT, IdxT>;
OpT corr_op(is_row_major, sq_norm_col_vec, sq_norm_row_vec, m, n, k);
distance_matrix_dispatch<decltype(corr_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(corr_op), DataT, AccT, OutT, FinOpT, IdxT>(
corr_op, m, n, k, x, y, norm_col_vec, norm_row_vec, out, fin_op, stream, is_row_major);
}

@@ -248,34 +236,9 @@ void distance_impl(raft::resources const& handle,
norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
}

// On CUDA 12:
// - always execute normal kernel
//
// On CUDA 11 and below:
// - execute CUTLASS-based kernel on SM_80 and above
// - execute normal kernel otherwise.

if constexpr (__CUDACC_VER_MAJOR__ == 12) {
// Always execute legacy kernels on CUDA 12
ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};
distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
} else {
const auto deviceVersion = getComputeCapability();
if (deviceVersion.first >= 8) {
// If device is SM_80 or later, use CUTLASS-based kernel.
using Op = ops::cosine_cutlass_op<DataT, AccT>;
Op distance_op{};

distance_matrix_cutlass_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
} else {
// Else use "legacy" cosine kernel
ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};
distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
}
}
ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
}

template <typename DataT, typename AccT, typename OutT, typename FinOpT, typename IdxT = int>
@@ -300,7 +263,7 @@ void distance_impl(raft::resources const& handle,

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

@@ -362,7 +325,7 @@ void distance_impl(raft::resources const& handle,
const DataT* x_norm = nullptr;
const DataT* y_norm = nullptr;

distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);

// Finally revert sqrt of x and y
@@ -394,7 +357,7 @@ void distance_impl(raft::resources const& handle,

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

@@ -438,7 +401,7 @@ void distance_impl(raft::resources const& handle,
const DataT* x_norm = nullptr;
const DataT* y_norm = nullptr;

distance_matrix_dispatch<decltype(kl_divergence), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(kl_divergence), DataT, AccT, OutT, FinOpT, IdxT>(
kl_divergence, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);

if (x != y) {
@@ -469,7 +432,7 @@ void distance_impl(raft::resources const& handle,
const DataT* y_norm = nullptr;

cudaStream_t stream = raft::resource::get_cuda_stream(handle);
distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

@@ -514,34 +477,9 @@ void distance_impl_l2_expanded( // NOTE: different name
norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
}

// On CUDA 12:
// - always execute normal kernel
//
// On CUDA 11 and below:
// - execute CUTLASS-based kernel on SM_80 and above
// - execute normal kernel otherwise.

if constexpr (__CUDACC_VER_MAJOR__ == 12) {
// Always execute legacy kernels on CUDA 12
ops::l2_exp_distance_op<DataT, AccT, IdxT> l2_op(perform_sqrt);
distance_matrix_dispatch<decltype(l2_op), DataT, AccT, OutT, FinOpT, IdxT>(
l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
} else {
const auto deviceVersion = getComputeCapability();
if (deviceVersion.first >= 8) {
// If device is SM_80 or later, use CUTLASS-based kernel.
using L2Op = ops::l2_exp_cutlass_op<DataT, AccT>;
L2Op l2_op(perform_sqrt);

distance_matrix_cutlass_dispatch<decltype(l2_op), DataT, AccT, OutT, FinOpT, IdxT>(
l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
} else {
// Else use "legacy" L2
ops::l2_exp_distance_op<DataT, AccT, IdxT> l2_op(perform_sqrt);
distance_matrix_dispatch<decltype(l2_op), DataT, AccT, OutT, FinOpT, IdxT>(
l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
}
}
ops::l2_exp_distance_op<DataT, AccT, IdxT> distance_op{perform_sqrt};
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
}

template <typename DataT, typename AccT, typename OutT, typename FinOpT, typename IdxT = int>
@@ -610,7 +548,7 @@ void distance_impl(raft::resources const& handle,

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

distance_matrix_dispatch<decltype(l2_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(l2_op), DataT, AccT, OutT, FinOpT, IdxT>(
l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
}

@@ -638,7 +576,7 @@ void distance_impl(raft::resources const& handle,

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

distance_matrix_dispatch<decltype(l2_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(l2_op), DataT, AccT, OutT, FinOpT, IdxT>(
l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
}

@@ -664,7 +602,7 @@ void distance_impl(raft::resources const& handle,

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

@@ -690,7 +628,7 @@ void distance_impl(raft::resources const& handle,

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

@@ -716,7 +654,7 @@ void distance_impl(raft::resources const& handle,

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

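The removed blocks in distance.cuh above show the pattern this PR eliminates: each distance_impl with a CUTLASS variant repeated the same compile-time CUDA-version check and runtime compute-capability check before choosing a kernel, and that branching now lives behind pairwise_matrix_dispatch. The sketch below restates the removed logic as a standalone helper purely for illustration; the function name, the device-property query, and the launch comments are assumptions, not the actual raft implementation (which, per the commit log, determines the runtime architecture from a kernel pointer).

// Sketch only: a hypothetical helper, not raft's pairwise_matrix_dispatch.
#include <cuda_runtime.h>

template <typename OpT>
void choose_kernel_path_sketch(OpT distance_op, int device_id)
{
  bool use_cutlass = false;
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ < 12)
  // The removed code took the CUTLASS path only on CUDA 11 and below,
  // and only when the device is SM_80 or newer.
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  use_cutlass = (prop.major >= 8);
#endif
  if (use_cutlass) {
    // launch the CUTLASS-based kernel, e.g. via distance_op.get_cutlass_op()
  } else {
    // launch the generic pairwise matrix kernel
  }
  (void)distance_op;  // silence unused warning when the CUTLASS branch is compiled out
}

With the check centralized, the per-op distance_impl functions no longer branch on architecture at all; a distance op only has to advertise whether it supports CUTLASS by exposing get_cutlass_op() (see cutlass.cuh below).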
35 changes: 35 additions & 0 deletions cpp/include/raft/distance/detail/distance_ops/all_ops.cuh
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

// Defines a named requirement "has_cutlass_op"
#include <raft/distance/detail/distance_ops/cutlass.cuh>

// The distance operations:
#include <raft/distance/detail/distance_ops/canberra.cuh>
#include <raft/distance/detail/distance_ops/correlation.cuh>
#include <raft/distance/detail/distance_ops/cosine.cuh>
#include <raft/distance/detail/distance_ops/hamming.cuh>
#include <raft/distance/detail/distance_ops/hellinger.cuh>
#include <raft/distance/detail/distance_ops/jensen_shannon.cuh>
#include <raft/distance/detail/distance_ops/kl_divergence.cuh>
#include <raft/distance/detail/distance_ops/l1.cuh>
#include <raft/distance/detail/distance_ops/l2_exp.cuh>
#include <raft/distance/detail/distance_ops/l2_unexp.cuh>
#include <raft/distance/detail/distance_ops/l_inf.cuh>
#include <raft/distance/detail/distance_ops/lp_unexp.cuh>
#include <raft/distance/detail/distance_ops/russel_rao.cuh>
cpp/include/raft/distance/detail/distance_ops/canberra.cuh
@@ -27,8 +27,12 @@ namespace raft::distance::detail::ops {
*
* c_ij = sum_k |x_ik - y_kj| / ( |x_ik| + |y_kj| )
*/
template <typename DataT, typename AccT, typename IdxT>
template <typename DataType, typename AccType, typename IdxType>
struct canberra_distance_op {
using DataT = DataType;
using AccT = AccType;
using IdxT = IdxType;

// Load norms of input data
static constexpr bool use_norms = false;
// Whether the core function requires so many instructions that it makes sense
cpp/include/raft/distance/detail/distance_ops/correlation.cuh
@@ -28,8 +28,12 @@ namespace raft::distance::detail::ops {
* /
* (|| x - mean(x) ||_2 || y - mean(y) ||_2)
*/
template <typename DataT, typename AccT, typename IdxT>
template <typename DataType, typename AccType, typename IdxType>
struct correlation_distance_op {
using DataT = DataType;
using AccT = AccType;
using IdxT = IdxType;

const DataT* x2n;
const DataT* y2n;
IdxT m;
27 changes: 17 additions & 10 deletions cpp/include/raft/distance/detail/distance_ops/cosine.cuh
@@ -20,15 +20,30 @@

namespace raft::distance::detail::ops {

// Epilogue operator for CUTLASS based kernel
template <typename DataT, typename AccT>
struct cosine_cutlass_op {
__device__ cosine_cutlass_op() noexcept {}
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
{
return static_cast<AccT>(1.0) - (AccT)(accVal / (aNorm * bNorm));
}
__device__ AccT operator()(DataT aData) const noexcept { return aData; }
};

/**
* @brief the expanded cosine distance matrix calculation
*
* It computes the following equation:
*
* d(x, y) = 1 - (x ⋅ y) / ( ||x||_2 ||y||_2)
*/
template <typename DataT, typename AccT, typename IdxT>
template <typename DataType, typename AccType, typename IdxType>
struct cosine_distance_op {
using DataT = DataType;
using AccT = AccType;
using IdxT = IdxType;

// Load norms of input data
static constexpr bool use_norms = true;
// Whether the core function requires so many instructions that it makes sense
@@ -60,16 +75,8 @@ struct cosine_distance_op {
}
}
}
};

template <typename DataT, typename AccT>
struct cosine_cutlass_op {
__device__ cosine_cutlass_op() noexcept {}
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
{
return static_cast<AccT>(1.0) - (AccT)(accVal / (aNorm * bNorm));
}
__device__ AccT operator()(DataT aData) const noexcept { return aData; }
cosine_cutlass_op<DataT, AccT> get_cutlass_op() { return cosine_cutlass_op<DataT, AccT>(); }
};

} // namespace raft::distance::detail::ops
40 changes: 40 additions & 0 deletions cpp/include/raft/distance/detail/distance_ops/cutlass.cuh
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <type_traits>

namespace raft::distance::detail::ops {

// This file defines the named requirement "has_cutlass_op" that can be used to
// determine if a distance operation has a CUTLASS op that can be used to pass
// to CUTLASS. Examples of distance operations that satisfy this requirement are
// cosine_distance_op and l2_exp_distance_op.

// Primary template handles types that do not support CUTLASS.
// This pattern is described in:
// https://en.cppreference.com/w/cpp/types/void_t
template <typename, typename = void>
struct has_cutlass_op : std::false_type {
};

// Specialization recognizes types that do support CUTLASS
template <typename T>
struct has_cutlass_op<T, std::void_t<decltype(&T::get_cutlass_op)>> : std::true_type {
};

} // namespace raft::distance::detail::ops
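To make the detection idiom above concrete, here is a minimal, self-contained sketch of how a dispatcher might consume has_cutlass_op: any op that exposes get_cutlass_op(), like cosine_distance_op in this PR, makes the trait true, and ops without it fall back to the generic kernel path. The dispatch_sketch function and the two toy op types are hypothetical stand-ins; only the trait itself mirrors the code in this file.

// Standalone sketch (assumes C++17); the trait is copied here so the example compiles on its own.
#include <type_traits>

template <typename, typename = void>
struct has_cutlass_op : std::false_type {};

template <typename T>
struct has_cutlass_op<T, std::void_t<decltype(&T::get_cutlass_op)>> : std::true_type {};

struct generic_only_op {};                      // no get_cutlass_op(): generic path only
struct cutlass_capable_op {
  int get_cutlass_op() const { return 42; }     // stand-in for returning the CUTLASS epilogue op
};

template <typename OpT>
void dispatch_sketch(OpT op)
{
  if constexpr (has_cutlass_op<OpT>::value) {
    auto cutlass_op = op.get_cutlass_op();      // would be handed to the CUTLASS-based kernel
    (void)cutlass_op;
  } else {
    // would launch the generic pairwise matrix kernel
  }
}

static_assert(!has_cutlass_op<generic_only_op>::value);
static_assert(has_cutlass_op<cutlass_capable_op>::value);

The compile-time check means the CUTLASS branch is only instantiated for ops that actually provide an epilogue operator, so no runtime flag or virtual call is needed.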
6 changes: 5 additions & 1 deletion cpp/include/raft/distance/detail/distance_ops/hamming.cuh
@@ -26,8 +26,12 @@ namespace raft::distance::detail::ops {
*
* c_ij = sum_k (x_ik != y_kj) / k
*/
template <typename DataT, typename AccT, typename IdxT>
template <typename DataType, typename AccType, typename IdxType>
struct hamming_distance_op {
using DataT = DataType;
using AccT = AccType;
using IdxT = IdxType;

IdxT k;

hamming_distance_op(IdxT k_) noexcept : k(k_) {}
cpp/include/raft/distance/detail/distance_ops/hellinger.cuh
@@ -27,8 +27,12 @@ namespace raft::distance::detail::ops {
* c_ij = sqrt(1 - sum_k sqrt(x_ik * y_kj))
*
*/
template <typename DataT, typename AccT, typename IdxT>
template <typename DataType, typename AccType, typename IdxType>
struct hellinger_distance_op {
using DataT = DataType;
using AccT = AccType;
using IdxT = IdxType;

// Load norms of input data
static constexpr bool use_norms = false;
// Whether the core function requires so many instructions that it makes sense