diff --git a/cpp/cmake/thirdparty/get_rmm.cmake b/cpp/cmake/thirdparty/get_rmm.cmake index dbb4715736..51f959a8d9 100644 --- a/cpp/cmake/thirdparty/get_rmm.cmake +++ b/cpp/cmake/thirdparty/get_rmm.cmake @@ -20,13 +20,19 @@ function(find_and_configure_rmm VERSION) return() endif() + if(${VERSION} MATCHES [=[([0-9]+)\.([0-9]+)\.([0-9]+)]=]) + set(MAJOR_AND_MINOR "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}") + else() + set(MAJOR_AND_MINOR "${VERSION}") + endif() + rapids_cpm_find(rmm ${VERSION} GLOBAL_TARGETS rmm::rmm BUILD_EXPORT_SET raft-exports INSTALL_EXPORT_SET raft-exports CPM_ARGS GIT_REPOSITORY https://github.com/rapidsai/rmm.git - GIT_TAG branch-${VERSION} + GIT_TAG branch-${MAJOR_AND_MINOR} GIT_SHALLOW TRUE OPTIONS "BUILD_TESTS OFF" "BUILD_BENCHMARKS OFF" @@ -36,6 +42,6 @@ function(find_and_configure_rmm VERSION) endfunction() -set(RAFT_MIN_VERSION_rmm "${RAFT_VERSION_MAJOR}.${RAFT_VERSION_MINOR}") +set(RAFT_MIN_VERSION_rmm "${RAFT_VERSION_MAJOR}.${RAFT_VERSION_MINOR}.00") find_and_configure_rmm(${RAFT_MIN_VERSION_rmm}) diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/test.hpp index 2ba9e406be..4e95c4eef0 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/test.hpp @@ -378,7 +378,7 @@ bool test_pointToPoint_device_send_or_recv(const handle_t &h, int numTrials) { communicator.sync_stream(stream); - if (!sender && received_data.value() != rank - 1) { + if (!sender && received_data.value(stream) != rank - 1) { ret = false; } @@ -424,8 +424,8 @@ bool test_pointToPoint_device_sendrecv(const handle_t &h, int numTrials) { communicator.sync_stream(stream); - if (((rank % 2 == 0) && (received_data.value() != rank + 1)) || - ((rank % 2 == 1) && (received_data.value() != rank - 1))) { + if (((rank % 2 == 0) && (received_data.value(stream) != rank + 1)) || + ((rank % 2 == 1) && (received_data.value(stream) != rank - 1))) { ret = false; }