From e9959acd3600b4e1e75b1800f5bb400c7d6d71f5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 3 Oct 2022 13:54:44 -0400 Subject: [PATCH] Fixing a few compile errors in new APIs (#874) Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/874 --- cpp/include/raft/cluster/single_linkage_types.hpp | 4 ++-- cpp/include/raft/linalg/axpy.cuh | 7 +++++-- cpp/include/raft/linalg/mean_squared_error.cuh | 8 ++++---- cpp/include/raft/sparse/hierarchy/common.h | 4 ++-- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp index d97e6afed3..79f2ede482 100644 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ b/cpp/include/raft/cluster/single_linkage_types.hpp @@ -53,9 +53,9 @@ class linkage_output { } }; -class linkage_output_int_float : public linkage_output { +class linkage_output_int : public linkage_output { }; -class linkage_output__int64_float : public linkage_output { +class linkage_output_int64 : public linkage_output { }; }; // namespace raft::cluster \ No newline at end of file diff --git a/cpp/include/raft/linalg/axpy.cuh b/cpp/include/raft/linalg/axpy.cuh index 6d54f87e91..96cf4277f4 100644 --- a/cpp/include/raft/linalg/axpy.cuh +++ b/cpp/include/raft/linalg/axpy.cuh @@ -70,6 +70,7 @@ void axpy(const raft::handle_t& handle, */ template , typename = raft::enable_if_output_device_mdspan> void axpy(const raft::handle_t& handle, @@ -79,7 +80,7 @@ void axpy(const raft::handle_t& handle, const int incx, const int incy) { - RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input") + RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input"); axpy(handle, y.size(), @@ -105,6 +106,8 @@ void axpy(const raft::handle_t& handle, * @param [in] incy stride between consecutive elements of y */ template , typename = raft::enable_if_output_device_mdspan> void axpy(const raft::handle_t& handle, @@ -114,7 +117,7 @@ void axpy(const raft::handle_t& handle, const int incx, const int incy) { - RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input") + RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input"); axpy(handle, y.size(), diff --git a/cpp/include/raft/linalg/mean_squared_error.cuh b/cpp/include/raft/linalg/mean_squared_error.cuh index ddfe58dad7..582bab2acc 100644 --- a/cpp/include/raft/linalg/mean_squared_error.cuh +++ b/cpp/include/raft/linalg/mean_squared_error.cuh @@ -34,9 +34,9 @@ namespace linalg { * @param weight weight to apply to every term in the mean squared error calculation * @param stream cuda-stream where to launch this kernel */ -template +template void meanSquaredError( - math_t* out, const math_t* A, const math_t* B, size_t len, math_t weight, cudaStream_t stream) + out_t* out, const in_t* A, const in_t* B, size_t len, in_t weight, cudaStream_t stream) { detail::meanSquaredError(out, A, B, len, weight, stream); } @@ -58,7 +58,7 @@ void meanSquaredError( * @param[out] out the output mean squared error value of type raft::device_scalar_view * @param[in] weight weight to apply to every term in the mean squared error calculation */ -template +template void mean_squared_error(const raft::handle_t& handle, raft::device_vector_view A, raft::device_vector_view B, @@ -68,7 +68,7 @@ void mean_squared_error(const raft::handle_t& handle, RAFT_EXPECTS(A.size() == B.size(), "Size mismatch between inputs"); meanSquaredError( - out.data_handle(), A.data_handle(), B.data_handle(), A.extent(0), weight, stream); + out.data_handle(), A.data_handle(), B.data_handle(), A.extent(0), weight, handle.get_stream()); } /** @} */ // end of group mean_squared_error diff --git a/cpp/include/raft/sparse/hierarchy/common.h b/cpp/include/raft/sparse/hierarchy/common.h index 5440ae4ae6..3c3b92b739 100644 --- a/cpp/include/raft/sparse/hierarchy/common.h +++ b/cpp/include/raft/sparse/hierarchy/common.h @@ -28,7 +28,7 @@ namespace raft::hierarchy { using raft::cluster::linkage_output; -using raft::cluster::linkage_output__int64_float; -using raft::cluster::linkage_output_int_float; +using raft::cluster::linkage_output_int; +using raft::cluster::linkage_output_int64; using raft::cluster::LinkageDistance; } // namespace raft::hierarchy \ No newline at end of file