From 8256217c62ba9d6fcb2be52be7232d9f9415b9e4 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 3 Oct 2022 11:14:33 -0400 Subject: [PATCH 1/3] Fixing compile errors on cuml side --- cpp/include/raft/cluster/single_linkage_types.hpp | 4 ++-- cpp/include/raft/linalg/axpy.cuh | 7 +++++-- cpp/include/raft/linalg/mean_squared_error.cuh | 8 ++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp index d97e6afed3..503b5f5231 100644 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ b/cpp/include/raft/cluster/single_linkage_types.hpp @@ -53,9 +53,9 @@ class linkage_output { } }; -class linkage_output_int_float : public linkage_output { +class linkage_output_int : public linkage_output { }; -class linkage_output__int64_float : public linkage_output { +class linkage_output__int64 : public linkage_output { }; }; // namespace raft::cluster \ No newline at end of file diff --git a/cpp/include/raft/linalg/axpy.cuh b/cpp/include/raft/linalg/axpy.cuh index 6d54f87e91..96cf4277f4 100644 --- a/cpp/include/raft/linalg/axpy.cuh +++ b/cpp/include/raft/linalg/axpy.cuh @@ -70,6 +70,7 @@ void axpy(const raft::handle_t& handle, */ template , typename = raft::enable_if_output_device_mdspan> void axpy(const raft::handle_t& handle, @@ -79,7 +80,7 @@ void axpy(const raft::handle_t& handle, const int incx, const int incy) { - RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input") + RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input"); axpy(handle, y.size(), @@ -105,6 +106,8 @@ void axpy(const raft::handle_t& handle, * @param [in] incy stride between consecutive elements of y */ template , typename = raft::enable_if_output_device_mdspan> void axpy(const raft::handle_t& handle, @@ -114,7 +117,7 @@ void axpy(const raft::handle_t& handle, const int incx, const int incy) { - RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input") + RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input"); axpy(handle, y.size(), diff --git a/cpp/include/raft/linalg/mean_squared_error.cuh b/cpp/include/raft/linalg/mean_squared_error.cuh index ddfe58dad7..582bab2acc 100644 --- a/cpp/include/raft/linalg/mean_squared_error.cuh +++ b/cpp/include/raft/linalg/mean_squared_error.cuh @@ -34,9 +34,9 @@ namespace linalg { * @param weight weight to apply to every term in the mean squared error calculation * @param stream cuda-stream where to launch this kernel */ -template +template void meanSquaredError( - math_t* out, const math_t* A, const math_t* B, size_t len, math_t weight, cudaStream_t stream) + out_t* out, const in_t* A, const in_t* B, size_t len, in_t weight, cudaStream_t stream) { detail::meanSquaredError(out, A, B, len, weight, stream); } @@ -58,7 +58,7 @@ void meanSquaredError( * @param[out] out the output mean squared error value of type raft::device_scalar_view * @param[in] weight weight to apply to every term in the mean squared error calculation */ -template +template void mean_squared_error(const raft::handle_t& handle, raft::device_vector_view A, raft::device_vector_view B, @@ -68,7 +68,7 @@ void mean_squared_error(const raft::handle_t& handle, RAFT_EXPECTS(A.size() == B.size(), "Size mismatch between inputs"); meanSquaredError( - out.data_handle(), A.data_handle(), B.data_handle(), A.extent(0), weight, stream); + out.data_handle(), A.data_handle(), B.data_handle(), A.extent(0), weight, handle.get_stream()); } /** @} */ // end of group mean_squared_error From 3df3398ef85fbf98ed1bcad1f718493edfa0eca0 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 3 Oct 2022 12:17:13 -0400 Subject: [PATCH 2/3] Fixing typo in linkage_output --- cpp/include/raft/cluster/single_linkage_types.hpp | 2 +- cpp/include/raft/sparse/hierarchy/common.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp index 503b5f5231..79f2ede482 100644 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ b/cpp/include/raft/cluster/single_linkage_types.hpp @@ -55,7 +55,7 @@ class linkage_output { class linkage_output_int : public linkage_output { }; -class linkage_output__int64 : public linkage_output { +class linkage_output_int64 : public linkage_output { }; }; // namespace raft::cluster \ No newline at end of file diff --git a/cpp/include/raft/sparse/hierarchy/common.h b/cpp/include/raft/sparse/hierarchy/common.h index 5440ae4ae6..619e063fef 100644 --- a/cpp/include/raft/sparse/hierarchy/common.h +++ b/cpp/include/raft/sparse/hierarchy/common.h @@ -28,7 +28,7 @@ namespace raft::hierarchy { using raft::cluster::linkage_output; -using raft::cluster::linkage_output__int64_float; -using raft::cluster::linkage_output_int_float; +using raft::cluster::linkage_output_int64; +using raft::cluster::linkage_output_int; using raft::cluster::LinkageDistance; } // namespace raft::hierarchy \ No newline at end of file From 3cd7faae92479f1a3457e5fd883d1d6e48b9f45c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 3 Oct 2022 12:19:58 -0400 Subject: [PATCH 3/3] Fixing style --- cpp/include/raft/sparse/hierarchy/common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/sparse/hierarchy/common.h b/cpp/include/raft/sparse/hierarchy/common.h index 619e063fef..3c3b92b739 100644 --- a/cpp/include/raft/sparse/hierarchy/common.h +++ b/cpp/include/raft/sparse/hierarchy/common.h @@ -28,7 +28,7 @@ namespace raft::hierarchy { using raft::cluster::linkage_output; -using raft::cluster::linkage_output_int64; using raft::cluster::linkage_output_int; +using raft::cluster::linkage_output_int64; using raft::cluster::LinkageDistance; } // namespace raft::hierarchy \ No newline at end of file