diff --git a/cpp/include/raft/linalg/detail/eig.cuh b/cpp/include/raft/linalg/detail/eig.cuh index 2a4cfd52ec..ba7ed3dcdf 100644 --- a/cpp/include/raft/linalg/detail/eig.cuh +++ b/cpp/include/raft/linalg/detail/eig.cuh @@ -19,10 +19,12 @@ #include "cusolver_wrappers.hpp" #include +#include #include #include #include +#include #include #include @@ -90,7 +92,19 @@ void eigDC(raft::resources const& handle, { #if CUDART_VERSION < 11010 eigDC_legacy(handle, in, n_rows, n_cols, eig_vectors, eig_vals, stream); + return; +#endif + +#if CUDART_VERSION <= 12040 + // Use a new stream instead of `cudaStreamPerThread` to avoid cusolver bug # 4580093. + rmm::cuda_stream stream_new_wrapper; + cudaStream_t stream_new = stream_new_wrapper.value(); + cudaEvent_t sync_event = resource::detail::get_cuda_stream_sync_event(handle); + RAFT_CUDA_TRY(cudaEventRecord(sync_event, stream)); + RAFT_CUDA_TRY(cudaStreamWaitEvent(stream_new, sync_event)); #else + cudaStream_t stream_new = stream; +#endif cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); cusolverDnParams_t dn_params = nullptr; @@ -108,15 +122,13 @@ void eigDC(raft::resources const& handle, eig_vals, &workspaceDevice, &workspaceHost, - stream)); + stream_new)); - rmm::device_uvector d_work(workspaceDevice / sizeof(math_t), stream); - rmm::device_scalar d_dev_info(stream); + rmm::device_uvector d_work(workspaceDevice / sizeof(math_t), stream_new); + rmm::device_scalar d_dev_info(stream_new); std::vector h_work(workspaceHost / sizeof(math_t)); - raft::matrix::copy(handle, - make_device_matrix_view(in, n_rows, n_cols), - make_device_matrix_view(eig_vectors, n_rows, n_cols)); + raft::copy(eig_vectors, in, n_rows * n_cols, stream_new); RAFT_CUSOLVER_TRY(cusolverDnxsyevd(cusolverH, dn_params, @@ -131,14 +143,19 @@ void eigDC(raft::resources const& handle, h_work.data(), workspaceHost, d_dev_info.data(), - stream)); + stream_new)); RAFT_CUDA_TRY(cudaGetLastError()); RAFT_CUSOLVER_TRY(cusolverDnDestroyParams(dn_params)); - int dev_info = d_dev_info.value(stream); + int dev_info = d_dev_info.value(stream_new); ASSERT(dev_info == 0, "eig.cuh: eigensolver couldn't converge to a solution. " "This usually occurs when some of the features do not vary enough."); + +#if CUDART_VERSION <= 12040 + // Synchronize the created stream with the original stream before return + RAFT_CUDA_TRY(cudaEventRecord(sync_event, stream_new)); + RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_event)); #endif } diff --git a/cpp/test/linalg/eig.cu b/cpp/test/linalg/eig.cu index 460b99aaa0..3ff117cf08 100644 --- a/cpp/test/linalg/eig.cu +++ b/cpp/test/linalg/eig.cu @@ -156,6 +156,24 @@ class EigTest : public ::testing::TestWithParam> { eig_vals_large, eig_vals_jacobi_large; }; +TEST(Raft, EigStream) +{ + // Separate test to check eig_dc stream workaround for CUDA 12+ + raft::resources handle; + auto n_rows = 5000; + auto cov_matrix_stream = + raft::make_device_matrix(handle, n_rows, n_rows); + auto eig_vectors_stream = + raft::make_device_matrix(handle, n_rows, n_rows); + auto eig_vals_stream = raft::make_device_vector(handle, n_rows); + + raft::linalg::eig_dc(handle, + raft::make_const_mdspan(cov_matrix_stream.view()), + eig_vectors_stream.view(), + eig_vals_stream.view()); + raft::resource::sync_stream(handle, raft::resource::get_cuda_stream(handle)); +} + const std::vector> inputsf2 = {{0.001f, 4 * 4, 4, 4, 1234ULL, 256}}; const std::vector> inputsd2 = {{0.001, 4 * 4, 4, 4, 1234ULL, 256}};