diff --git a/cpp/include/raft/linalg/detail/eig.cuh b/cpp/include/raft/linalg/detail/eig.cuh index ba7ed3dcdf..561187178c 100644 --- a/cpp/include/raft/linalg/detail/eig.cuh +++ b/cpp/include/raft/linalg/detail/eig.cuh @@ -95,16 +95,19 @@ void eigDC(raft::resources const& handle, return; #endif -#if CUDART_VERSION <= 12040 - // Use a new stream instead of `cudaStreamPerThread` to avoid cusolver bug # 4580093. + int cudart_version = 0; + RAFT_CUDA_TRY(cudaRuntimeGetVersion(&cudart_version)); + cudaStream_t stream_new; + cudaEvent_t sync_event = resource::detail::get_cuda_stream_sync_event(handle); rmm::cuda_stream stream_new_wrapper; - cudaStream_t stream_new = stream_new_wrapper.value(); - cudaEvent_t sync_event = resource::detail::get_cuda_stream_sync_event(handle); - RAFT_CUDA_TRY(cudaEventRecord(sync_event, stream)); - RAFT_CUDA_TRY(cudaStreamWaitEvent(stream_new, sync_event)); -#else - cudaStream_t stream_new = stream; -#endif + if (cudart_version < 12050) { + // Use a new stream instead of `cudaStreamPerThread` to avoid cusolver bug # 4580093. + stream_new = stream_new_wrapper.value(); + RAFT_CUDA_TRY(cudaEventRecord(sync_event, stream)); + RAFT_CUDA_TRY(cudaStreamWaitEvent(stream_new, sync_event)); + } else { + stream_new = stream; + } cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); cusolverDnParams_t dn_params = nullptr; @@ -152,11 +155,11 @@ void eigDC(raft::resources const& handle, "eig.cuh: eigensolver couldn't converge to a solution. " "This usually occurs when some of the features do not vary enough."); -#if CUDART_VERSION <= 12040 - // Synchronize the created stream with the original stream before return - RAFT_CUDA_TRY(cudaEventRecord(sync_event, stream_new)); - RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_event)); -#endif + if (cudart_version < 12050) { + // Synchronize the created stream with the original stream before return + RAFT_CUDA_TRY(cudaEventRecord(sync_event, stream_new)); + RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_event)); + } } enum EigVecMemUsage { OVERWRITE_INPUT, COPY_INPUT };