From 13326c2ea2b77b05d5290eebd5955cb657d3c9cb Mon Sep 17 00:00:00 2001 From: Mark Harris Date: Tue, 8 Jun 2021 11:35:02 +1000 Subject: [PATCH 1/2] Update Raft's get_rmm.cmake to better support CalVer --- cpp/cmake/thirdparty/get_rmm.cmake | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cpp/cmake/thirdparty/get_rmm.cmake b/cpp/cmake/thirdparty/get_rmm.cmake index dbb4715736..51f959a8d9 100644 --- a/cpp/cmake/thirdparty/get_rmm.cmake +++ b/cpp/cmake/thirdparty/get_rmm.cmake @@ -20,13 +20,19 @@ function(find_and_configure_rmm VERSION) return() endif() + if(${VERSION} MATCHES [=[([0-9]+)\.([0-9]+)\.([0-9]+)]=]) + set(MAJOR_AND_MINOR "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}") + else() + set(MAJOR_AND_MINOR "${VERSION}") + endif() + rapids_cpm_find(rmm ${VERSION} GLOBAL_TARGETS rmm::rmm BUILD_EXPORT_SET raft-exports INSTALL_EXPORT_SET raft-exports CPM_ARGS GIT_REPOSITORY https://github.com/rapidsai/rmm.git - GIT_TAG branch-${VERSION} + GIT_TAG branch-${MAJOR_AND_MINOR} GIT_SHALLOW TRUE OPTIONS "BUILD_TESTS OFF" "BUILD_BENCHMARKS OFF" @@ -36,6 +42,6 @@ function(find_and_configure_rmm VERSION) endfunction() -set(RAFT_MIN_VERSION_rmm "${RAFT_VERSION_MAJOR}.${RAFT_VERSION_MINOR}") +set(RAFT_MIN_VERSION_rmm "${RAFT_VERSION_MAJOR}.${RAFT_VERSION_MINOR}.00") find_and_configure_rmm(${RAFT_MIN_VERSION_rmm}) From 57de1954b83eba70a49d4142f439520f28e34f61 Mon Sep 17 00:00:00 2001 From: Mark Harris Date: Tue, 8 Jun 2021 12:04:42 +1000 Subject: [PATCH 2/2] Pass stream to device_scalar::value() calls. --- cpp/include/raft/comms/test.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/test.hpp index 2ba9e406be..4e95c4eef0 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/test.hpp @@ -378,7 +378,7 @@ bool test_pointToPoint_device_send_or_recv(const handle_t &h, int numTrials) { communicator.sync_stream(stream); - if (!sender && received_data.value() != rank - 1) { + if (!sender && received_data.value(stream) != rank - 1) { ret = false; } @@ -424,8 +424,8 @@ bool test_pointToPoint_device_sendrecv(const handle_t &h, int numTrials) { communicator.sync_stream(stream); - if (((rank % 2 == 0) && (received_data.value() != rank + 1)) || - ((rank % 2 == 1) && (received_data.value() != rank - 1))) { + if (((rank % 2 == 0) && (received_data.value(stream) != rank + 1)) || + ((rank % 2 == 1) && (received_data.value(stream) != rank - 1))) { ret = false; }