Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward-merge branch-24.02 to branch-24.04 #2114

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cpp/include/raft/core/resource/cublas_handle.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -60,8 +60,8 @@ class cublas_resource_factory : public resource_factory {
*/

/**
* Load a cublasres_t from raft res if it exists, otherwise
* add it and return it.
* Load a `cublasHandle_t` from raft res if it exists, otherwise add it and return it.
*
* @param[in] res the raft resources object
* @return cublas handle
*/
Expand Down
68 changes: 68 additions & 0 deletions cpp/include/raft/core/resource/cublaslt_handle.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cublasLt.h>
#include <raft/core/cublas_macros.hpp>
#include <raft/core/resource/resource_types.hpp>
#include <raft/core/resources.hpp>

#include <memory>

namespace raft::resource {

class cublaslt_resource : public resource {
public:
cublaslt_resource() { RAFT_CUBLAS_TRY(cublasLtCreate(&handle_)); }
~cublaslt_resource() noexcept override { RAFT_CUBLAS_TRY_NO_THROW(cublasLtDestroy(handle_)); }
auto get_resource() -> void* override { return &handle_; }

private:
cublasLtHandle_t handle_;
};

/** Factory that knows how to construct a specific raft::resource to populate the res_t. */
class cublaslt_resource_factory : public resource_factory {
public:
auto get_resource_type() -> resource_type override { return resource_type::CUBLASLT_HANDLE; }
auto make_resource() -> resource* override { return new cublaslt_resource(); }
};

/**
* @defgroup resource_cublaslt cuBLASLt handle resource functions
* @{
*/

/**
* Load a `cublasLtHandle_t` from raft res if it exists, otherwise add it and return it.
*
* @param[in] res the raft resources object
* @return cublasLt handle
*/
inline auto get_cublaslt_handle(resources const& res) -> cublasLtHandle_t
{
if (!res.has_resource_factory(resource_type::CUBLASLT_HANDLE)) {
res.add_resource_factory(std::make_shared<cublaslt_resource_factory>());
}
auto ret = *res.get_resource<cublasLtHandle_t>(resource_type::CUBLASLT_HANDLE);
return ret;
};

/**
* @}
*/

} // namespace raft::resource
93 changes: 93 additions & 0 deletions cpp/include/raft/core/resource/custom_resource.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/resource/resource_types.hpp>
#include <raft/core/resources.hpp>

#include <algorithm>
#include <memory>
#include <typeindex>

namespace raft::resource {

class custom_resource : public resource {
public:
custom_resource() = default;
~custom_resource() noexcept override = default;
auto get_resource() -> void* override { return this; }

template <typename ResourceT>
auto load() -> ResourceT*
{
std::lock_guard<std::mutex> _(lock_);
auto key = std::type_index{typeid(ResourceT)};
auto pos = std::lower_bound(store_.begin(), store_.end(), kv{key, {nullptr}});
if ((pos != store_.end()) && std::get<0>(*pos) == key) {
return reinterpret_cast<ResourceT*>(std::get<1>(*pos).get());
}
auto store_ptr = new ResourceT{};
store_.insert(pos, kv{key, std::shared_ptr<void>(store_ptr, [](void* ptr) {
delete reinterpret_cast<ResourceT*>(ptr);
})});
return store_ptr;
}

private:
using kv = std::tuple<std::type_index, std::shared_ptr<void>>;
std::mutex lock_{};
std::vector<kv> store_{};
};

/** Factory that knows how to construct a specific raft::resource to populate the res_t. */
class custom_resource_factory : public resource_factory {
public:
auto get_resource_type() -> resource_type override { return resource_type::CUSTOM; }
auto make_resource() -> resource* override { return new custom_resource(); }
};

/**
* @defgroup resource_custom custom resource functions
* @{
*/

/**
* Get the custom default-constructible resource if it exists, create it otherwise.
*
* Note: in contrast to the other, hard-coded resources, there's no information about the custom
* resources at compile time. Hence, custom resources are kept in a hashmap and looked-up at
* runtime. This leads to slightly slower access times.
*
* @tparam ResourceT the type of the resource; it must be complete and default-constructible.
*
* @param[in] res the raft resources object
* @return a pointer to the custom resource.
*/
template <typename ResourceT>
auto get_custom_resource(resources const& res) -> ResourceT*
{
static_assert(std::is_default_constructible_v<ResourceT>);
if (!res.has_resource_factory(resource_type::CUSTOM)) {
res.add_resource_factory(std::make_shared<custom_resource_factory>());
}
return res.get_resource<custom_resource>(resource_type::CUSTOM)->load<ResourceT>();
};

/**
* @}
*/

} // namespace raft::resource
4 changes: 3 additions & 1 deletion cpp/include/raft/core/resource/resource_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,6 +43,8 @@ enum resource_type {
// CUDA-free builds
THRUST_POLICY, // thrust execution policy
WORKSPACE_RESOURCE, // rmm device memory resource
CUBLASLT_HANDLE, // cublasLt handle
CUSTOM, // runtime-shared default-constructible resource

LAST_KEY // reserved for the last key
};
Expand Down
Loading
Loading