Skip to content

Commit

Permalink
moved kernel computation to public section
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Mar 30, 2023
1 parent 3f61b64 commit 2b6090a
Show file tree
Hide file tree
Showing 16 changed files with 65 additions and 77 deletions.
17 changes: 8 additions & 9 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -304,16 +304,7 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/specializations/detail/inner_product_double_double_double_int.cu
src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu
src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu
src/distance/specializations/detail/kernels/gram_matrix_base_double.cu
src/distance/specializations/detail/kernels/gram_matrix_base_float.cu
src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu
src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu
# These are somehow missing a kernel definition which is causing a compile error.
# src/distance/specializations/detail/kernels/rbf_kernel_double.cu
# src/distance/specializations/detail/kernels/rbf_kernel_float.cu
src/neighbors/brute_force_knn_int64_t_float.cu
src/distance/specializations/detail/kernels/tanh_kernel_double.cu
src/distance/specializations/detail/kernels/tanh_kernel_float.cu
src/distance/specializations/detail/kl_divergence_float_float_float_int.cu
src/distance/specializations/detail/kl_divergence_double_double_double_int.cu
src/distance/specializations/detail/l1_float_float_float_int.cu
Expand All @@ -332,6 +323,14 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/specializations/fused_l2_nn_double_int64.cu
src/distance/specializations/fused_l2_nn_float_int.cu
src/distance/specializations/fused_l2_nn_float_int64.cu
src/distance/specializations/kernels/gram_matrix_base_double.cu
src/distance/specializations/kernels/gram_matrix_base_float.cu
src/distance/specializations/kernels/polynomial_kernel_double_int.cu
src/distance/specializations/kernels/polynomial_kernel_float_int.cu
src/distance/specializations/kernels/rbf_kernel_double.cu
src/distance/specializations/kernels/rbf_kernel_float.cu
src/distance/specializations/kernels/tanh_kernel_double.cu
src/distance/specializations/kernels/tanh_kernel_float.cu
src/matrix/specializations/detail/select_k_float_uint32_t.cu
src/matrix/specializations/detail/select_k_float_int64_t.cu
src/matrix/specializations/detail/select_k_half_uint32_t.cu
Expand Down
12 changes: 2 additions & 10 deletions cpp/include/raft/distance/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,9 @@

#pragma once

#include <raft/distance/detail/kernels/gram_matrix.cuh>
#include <raft/distance/detail/kernels/kernel_factory.cuh>
#include <raft/distance/kernels/gram_matrix.cuh>
#include <raft/distance/kernels/kernel_factory.cuh>
#include <raft/util/cuda_utils.cuh>

#include <raft/distance/distance.cuh>
#include <raft/linalg/gemm.cuh>

namespace raft::distance::kernels {

// TODO: Need to expose formal APIs for this that are more consistent w/ other APIs in RAFT
using raft::distance::kernels::detail::GramMatrixBase;
using raft::distance::kernels::detail::KernelFactory;

}; // end namespace raft::distance::kernels
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
//#include <raft/sparse/detail/cusparse_wrappers.h>
#include <raft/sparse/distance/distance.cuh>
#include <raft/sparse/linalg/spmm.cuh>

#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/sparse/distance/distance.cuh>
#include <raft/sparse/linalg/spmm.cuh>

namespace raft::distance::kernels::detail {
namespace raft::distance::kernels {

template <typename math_t>
using dense_input_matrix_view_t = raft::device_matrix_view<const math_t, int, layout_stride>;
Expand Down Expand Up @@ -507,4 +505,4 @@ class GramMatrixBase {
}
};

}; // end namespace raft::distance::kernels::detail
}; // end namespace raft::distance::kernels
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include <raft/distance/distance_types.hpp>
#include <raft/util/cudart_utils.hpp>

namespace raft::distance::kernels::detail {
namespace raft::distance::kernels {

template <typename math_t>
class KernelFactory {
Expand Down Expand Up @@ -61,4 +61,4 @@ class KernelFactory {
}
};

}; // end namespace raft::distance::kernels::detail
}; // end namespace raft::distance::kernels
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <raft/linalg/gemm.cuh>
#include <raft/sparse/linalg/norm.cuh>

namespace raft::distance::kernels::detail {
namespace raft::distance::kernels {

/** Epiloge function for polynomial kernel without padding.
* Calculates output = (gain*in + offset)^exponent
Expand Down Expand Up @@ -738,4 +738,4 @@ class RBFKernel : public GramMatrixBase<math_t> {
}
};
}; // end namespace raft::distance::kernels::detail
}; // end namespace raft::distance::kernels
31 changes: 0 additions & 31 deletions cpp/include/raft/distance/specializations/detail/kernels.cuh

This file was deleted.

2 changes: 1 addition & 1 deletion cpp/include/raft/distance/specializations/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <raft/distance/specializations/detail/hellinger_expanded.cuh>
#include <raft/distance/specializations/detail/inner_product.cuh>
#include <raft/distance/specializations/detail/jensen_shannon.cuh>
#include <raft/distance/specializations/detail/kernels.cuh>
#include <raft/distance/specializations/detail/kl_divergence.cuh>
#include <raft/distance/specializations/detail/l1.cuh>
#include <raft/distance/specializations/detail/l2_expanded.cuh>
Expand All @@ -32,3 +31,4 @@
#include <raft/distance/specializations/detail/lp_unexpanded.cuh>
#include <raft/distance/specializations/detail/russel_rao.cuh>
#include <raft/distance/specializations/fused_l2_nn_min.cuh>
#include <raft/distance/specializations/kernels.cuh>
30 changes: 30 additions & 0 deletions cpp/include/raft/distance/specializations/kernels.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/distance/kernels/gram_matrix.cuh>
#include <raft/distance/kernels/kernel_matrices.cuh>

extern template class raft::distance::kernels::GramMatrixBase<double>;
extern template class raft::distance::kernels::GramMatrixBase<float>;

extern template class raft::distance::kernels::PolynomialKernel<double, int>;
extern template class raft::distance::kernels::PolynomialKernel<float, int>;

extern template class raft::distance::kernels::TanhKernel<double>;
extern template class raft::distance::kernels::TanhKernel<float>;

extern template class raft::distance::kernels::RBFKernel<double>;
extern template class raft::distance::kernels::RBFKernel<float>;
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include <raft/distance/detail/kernels/gram_matrix.cuh>
#include <raft/distance/kernels/gram_matrix.cuh>
#include <raft/distance/specializations.cuh>

template class raft::distance::kernels::detail::GramMatrixBase<double>;
template class raft::distance::kernels::GramMatrixBase<double>;
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include <raft/distance/detail/kernels/gram_matrix.cuh>
#include <raft/distance/kernels/gram_matrix.cuh>
#include <raft/distance/specializations.cuh>

template class raft::distance::kernels::detail::GramMatrixBase<float>;
template class raft::distance::kernels::GramMatrixBase<float>;
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include <raft/distance/detail/kernels/kernel_matrices.cuh>
#include <raft/distance/kernels/kernel_matrices.cuh>
#include <raft/distance/specializations.cuh>

template class raft::distance::kernels::detail::PolynomialKernel<double, int>;
template class raft::distance::kernels::PolynomialKernel<double, int>;
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include <raft/distance/detail/kernels/kernel_matrices.cuh>
#include <raft/distance/kernels/kernel_matrices.cuh>
#include <raft/distance/specializations.cuh>

template class raft::distance::kernels::detail::PolynomialKernel<float, int>;
template class raft::distance::kernels::PolynomialKernel<float, int>;
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include <raft/distance/detail/kernels/kernel_matrices.cuh>
#include <raft/distance/kernels/kernel_matrices.cuh>
#include <raft/distance/specializations.cuh>

template class raft::distance::kernels::detail::RBFKernel<double>;
template class raft::distance::kernels::RBFKernel<double>;
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include <raft/distance/detail/kernels/kernel_matrices.cuh>
#include <raft/distance/kernels/kernel_matrices.cuh>
#include <raft/distance/specializations.cuh>

template class raft::distance::kernels::detail::RBFKernel<float>;
template class raft::distance::kernels::RBFKernel<float>;
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include <raft/distance/detail/kernels/kernel_matrices.cuh>
#include <raft/distance/kernels/kernel_matrices.cuh>
#include <raft/distance/specializations.cuh>

template class raft::distance::kernels::detail::TanhKernel<double>;
template class raft::distance::kernels::TanhKernel<double>;
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include <raft/distance/detail/kernels/kernel_matrices.cuh>
#include <raft/distance/kernels/kernel_matrices.cuh>
#include <raft/distance/specializations.cuh>

template class raft::distance::kernels::detail::TanhKernel<float>;
template class raft::distance::kernels::TanhKernel<float>;

0 comments on commit 2b6090a

Please sign in to comment.