diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh index 0a5a3ba5aa..5328fbf35f 100644 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -102,7 +103,7 @@ inline std::enable_if_t> predict_core( auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( handle, mr, make_extents(n_rows)); raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - thrust::fill(resource::get_thrust_policy(handle), + thrust::fill(resource::get_thrust_nosync_policy(handle), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), initial_value); @@ -128,7 +129,7 @@ inline std::enable_if_t> predict_core( // todo(lsugy): use KVP + iterator in caller. // Copy keys to output labels - thrust::transform(resource::get_thrust_policy(handle), + thrust::transform(resource::get_thrust_nosync_policy(handle), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + n_rows, labels, diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index d2021728c4..c849bfcebb 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -42,6 +42,7 @@ enum resource_type { STREAM_VIEW, // view of a cuda stream or a placeholder in // CUDA-free builds THRUST_POLICY, // thrust execution policy + THRUST_NOSYNC_POLICY, // thrust nosync execution policy WORKSPACE_RESOURCE, // rmm device memory resource CUBLASLT_HANDLE, // cublasLt handle CUSTOM, // runtime-shared default-constructible resource diff --git a/cpp/include/raft/core/resource/thrust_nosync_policy.hpp b/cpp/include/raft/core/resource/thrust_nosync_policy.hpp new file mode 100644 index 0000000000..77d4a7ee1a --- /dev/null +++ 
b/cpp/include/raft/core/resource/thrust_nosync_policy.hpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <raft/core/resource/cuda_stream.hpp>
+#include <raft/core/resource/resource_types.hpp>
+#include <raft/core/resources.hpp>
+
+#include <rmm/exec_policy.hpp>
+
+#include <memory>
+
+namespace raft::resource {
+
+/**
+ * @brief Resource wrapping an rmm::exec_policy_nosync built from a CUDA stream view.
+ *
+ * The policy object is created once in the constructor and exposed as a
+ * type-erased pointer via get_resource().
+ */
+class thrust_nosync_policy_resource : public resource {
+ public:
+  /**
+   * @param stream_view stream view forwarded to the rmm::exec_policy_nosync constructor
+   */
+  thrust_nosync_policy_resource(rmm::cuda_stream_view stream_view)
+    : thrust_nosync_policy_(std::make_unique<rmm::exec_policy_nosync>(stream_view))
+  {
+  }
+
+  /** @return type-erased pointer to the owned rmm::exec_policy_nosync */
+  void* get_resource() override { return thrust_nosync_policy_.get(); }
+
+  ~thrust_nosync_policy_resource() override = default;
+
+ private:
+  std::unique_ptr<rmm::exec_policy_nosync> thrust_nosync_policy_;
+};
+
+/**
+ * Factory that knows how to construct a
+ * specific raft::resource to populate
+ * the res_t.
+ */
+class thrust_nosync_policy_resource_factory : public resource_factory {
+ public:
+  /**
+   * @param stream_view stream view stored and later passed to each resource this factory makes
+   */
+  thrust_nosync_policy_resource_factory(rmm::cuda_stream_view stream_view)
+    : stream_view_(stream_view)
+  {
+  }
+
+  /** @return resource_type::THRUST_NOSYNC_POLICY */
+  resource_type get_resource_type() override { return resource_type::THRUST_NOSYNC_POLICY; }
+
+  /** @return a newly allocated thrust_nosync_policy_resource for the stored stream view */
+  resource* make_resource() override { return new thrust_nosync_policy_resource(stream_view_); }
+
+ private:
+  rmm::cuda_stream_view stream_view_;
+};
+
+/**
+ * @defgroup resource_thrust_nosync_policy Thrust nosync policy resource functions
+ * @{
+ */
+
+/**
+ * Load a thrust nosync policy from a res (and populate it on the res if needed).
+ *
+ * If no THRUST_NOSYNC_POLICY factory is registered on res yet, one is added using the
+ * stream returned by get_cuda_stream(res).
+ *
+ * @param res raft res object for managing resources
+ * @return thrust execution policy nosync
+ */
+inline rmm::exec_policy_nosync& get_thrust_nosync_policy(resources const& res)
+{
+  if (!res.has_resource_factory(resource_type::THRUST_NOSYNC_POLICY)) {
+    rmm::cuda_stream_view stream = get_cuda_stream(res);
+    res.add_resource_factory(std::make_shared<thrust_nosync_policy_resource_factory>(stream));
+  }
+  return *res.get_resource<rmm::exec_policy_nosync>(resource_type::THRUST_NOSYNC_POLICY);
+}
+
+/**
+ * @}
+ */
+
+}  // namespace raft::resource