Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide CPU/GPU interoperable device_id type #991

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
2a63aff
Provide memory_type enum
wphicks Nov 3, 2022
5760977
Update style
wphicks Nov 3, 2022
bf4d172
Provide an enum for indicating compute device type
wphicks Nov 3, 2022
241f353
Reuse accessibility functions for host/device check
wphicks Nov 3, 2022
27a2d84
Merge branch 'fea-memory_types' into fea-device_type
wphicks Nov 3, 2022
14d97bf
Add missing arguments to refactor
wphicks Nov 3, 2022
1c15377
Merge branch 'fea-memory_types' into fea-device_type
wphicks Nov 3, 2022
777c006
Provide constants and infrastructure for CUDA-free usage
wphicks Nov 3, 2022
481f5de
Add missing copyright to header
wphicks Nov 3, 2022
15a7965
Switch to RAFT_DISABLE_CUDA identifier
wphicks Nov 3, 2022
68310ae
Merge branch 'fea-cuda_free' into fea-device_id
wphicks Nov 4, 2022
40fadb5
Provide device_id implementation
wphicks Nov 4, 2022
3aae4ce
Update compile-time identifier in CMakeLists.txt
wphicks Nov 4, 2022
920cbb8
Derive cuda_unsupported from raft::exception
wphicks Nov 4, 2022
57f6eed
Merge branch 'fea-cuda_free' into fea-device_id
wphicks Nov 7, 2022
2304b83
Correct error handling
wphicks Nov 7, 2022
5c096a6
Add missing constructor for exception
wphicks Nov 7, 2022
1663fe3
Merge branch 'fea-cuda_free' into fea-device_id
wphicks Nov 7, 2022
9d0ae21
Update style
wphicks Nov 7, 2022
797d6dc
Merge branch 'branch-23.04' into fea-device_type
wphicks Mar 6, 2023
0fec443
Move test file for consistency with new location
wphicks Mar 6, 2023
1530867
Merge branch 'fea-device_type' into fea-device_id
wphicks Mar 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ option(CUDA_STATIC_RUNTIME "Statically link the CUDA toolkit runtime and librari
# User-configurable build switches. Each option() declares a boolean cache
# entry with the given help text and default value.
option(DETECT_CONDA_ENV "Enable detection of conda environment for dependencies" ON)
option(DISABLE_DEPRECATION_WARNINGS "Disable deprecation warnings" ON)
option(DISABLE_OPENMP "Disable OpenMP" OFF)
# When ON, the RAFT_DISABLE_CUDA compile definition is added to the raft target
# (see the "CUDA-free build support" section of this file).
option(DISABLE_CUDA "Disable CUDA in supported RAFT code" OFF)
option(RAFT_NVTX "Enable nvtx markers" OFF)

set(RAFT_COMPILE_LIBRARIES_DEFAULT OFF)
Expand Down Expand Up @@ -276,6 +277,13 @@ target_compile_definitions(raft::raft INTERFACE $<$<BOOL:${RAFT_NVTX}>:NVTX_ENAB
)
endif()

##############################################################################
# - CUDA-free build support --------------------------------------------------

# Consumers of the raft target see RAFT_DISABLE_CUDA only when the
# DISABLE_CUDA option evaluates to true; otherwise the generator expression
# expands to nothing and no definition is added.
target_compile_definitions(raft INTERFACE $<$<BOOL:${DISABLE_CUDA}>:RAFT_DISABLE_CUDA>)

# ##################################################################################################
# * raft_distance ------------------------------------------------------------ TODO: Currently, this
# package also contains the 'random' namespace (for rmat logic) We couldn't get this to work
Expand Down
37 changes: 37 additions & 0 deletions cpp/include/raft/core/detail/device_id_base.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <raft/core/device_support.hpp>
#include <raft/core/device_type.hpp>

namespace raft {
namespace detail {
/*
 * Primary `device_id` template: the fallback used for any `device_type` that
 * has no explicit specialization. Both accessors always throw
 * `cuda_unsupported`, reporting that a GPU device was requested in a build
 * without CUDA support.
 */
template <device_type D>
struct device_id {
  using value_type = int;

  /* Accepts a device index for signature parity with the specializations. The
   * index is deliberately ignored (and therefore left unnamed) because this
   * fallback never represents a usable device.
   */
  device_id(value_type /* device_index */ = value_type{}) noexcept {}

  /* Always throws `cuda_unsupported`.
   *
   * The return type is spelled out as `value_type` instead of being deduced
   * (which would yield `void` from a throw-only body) so that generic code
   * calling `value()` on every kind of device ID, e.g. a `std::visit` over a
   * variant of device IDs, sees one consistent return type.
   */
  [[noreturn]] auto value() const noexcept(false) -> value_type
  {
    throw cuda_unsupported{"Attempting to use a GPU device in a non-CUDA build"};
  }

  /* Always throws `cuda_unsupported`; no RMM identifier exists for this
   * fallback, so nothing is ever returned.
   */
  [[noreturn]] auto rmm_id() const noexcept(false)
  {
    throw cuda_unsupported{"Attempting to use a GPU device in a non-CUDA build"};
  }
};
} // namespace detail
} // namespace raft
34 changes: 34 additions & 0 deletions cpp/include/raft/core/detail/device_id_cpu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <raft/core/detail/device_id_base.hpp>
#include <raft/core/device_type.hpp>

namespace raft {
namespace detail {
/*
 * CPU specialization of `device_id`: a thin wrapper around an integer index.
 * It has no RMM equivalent, so `rmm_id()` always throws `bad_device_type`.
 */
template <>
struct device_id<device_type::cpu> {
  using value_type = int;

 private:
  value_type id_;

 public:
  /* Wraps the given index; defaults to a value-initialized (zero) index. */
  device_id(value_type dev_id = value_type{}) noexcept : id_(dev_id) {}

  /* Returns the stored index. */
  auto value() const noexcept -> value_type { return id_; }

  /* Always throws `bad_device_type`; nothing is ever returned. */
  auto rmm_id() const noexcept(false) -> void
  {
    throw bad_device_type{"CPU devices have no RMM ID"};
  }
};
} // namespace detail
} // namespace raft
50 changes: 50 additions & 0 deletions cpp/include/raft/core/detail/device_id_gpu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <raft/core/detail/device_id_base.hpp>
#include <raft/core/device_type.hpp>
#include <raft/util/cudart_utils.hpp>
#include <rmm/cuda_device.hpp>

namespace raft {
namespace detail {
template <>
struct device_id<device_type::gpu> {
using value_type = typename rmm::cuda_device_id::value_type;
device_id() noexcept(false)
: id_{[]() {
auto raw_id = value_type{};
RAFT_CUDA_TRY(cudaGetDevice(&raw_id));
return raw_id;
}()} {};
/* We do not mark this constructor as explicit to allow public API
* functions to accept `device_id` arguments without requiring
* downstream consumers to explicitly construct a device_id. Thus,
* consumers can use the type they expect to use when specifying a device
* (int), but once we are inside the public API, the device type remains
* attached to this value and we can easily convert to the strongly-typed
* rmm::cuda_device_id if desired.
*/
device_id(value_type dev_id) noexcept : id_{dev_id} {};

auto value() const noexcept { return id_.value(); }
auto rmm_id() const noexcept { return id_; }

private:
rmm::cuda_device_id id_;
};
} // namespace detail
} // namespace raft
31 changes: 31 additions & 0 deletions cpp/include/raft/core/device_id.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/detail/device_id_base.hpp>
#include <raft/core/detail/device_id_cpu.hpp>
#ifndef RAFT_DISABLE_CUDA
#include <raft/core/detail/device_id_gpu.hpp>
#endif
#include <raft/core/device_type.hpp>
#include <variant>

namespace raft {
/* Public alias exposing the per-device-type ID implementations that are
 * defined in the `detail` namespace. */
template <device_type D>
using device_id = detail::device_id<D>;

/* Variant type able to hold either a CPU device ID or a GPU device ID. */
using device_id_variant = std::variant<device_id<device_type::cpu>, device_id<device_type::gpu>>;
} // namespace raft
32 changes: 32 additions & 0 deletions cpp/include/raft/core/device_support.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <raft/core/error.hpp>
#include <string>

namespace raft {
/* Compile-time flag telling whether this build was configured with CUDA
 * support: false when RAFT_DISABLE_CUDA is defined, true otherwise.
 *
 * Declared `inline constexpr` so that every translation unit including this
 * header refers to one and the same entity, rather than each getting its own
 * internal-linkage copy.
 */
#ifdef RAFT_DISABLE_CUDA
inline constexpr bool CUDA_ENABLED = false;
#else
inline constexpr bool CUDA_ENABLED = true;
#endif

struct cuda_unsupported : raft::exception {
explicit cuda_unsupported(std::string const& msg) : raft::exception{msg} {}
cuda_unsupported() : cuda_unsupported{"CUDA functionality invoked in non-CUDA build"} {}
};

} // namespace raft
35 changes: 35 additions & 0 deletions cpp/include/raft/core/device_type.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <raft/core/error.hpp>
#include <raft/core/memory_type.hpp>
#include <string>

namespace raft {
/* The compute-device kinds distinguished by this header: `cpu` and `gpu`. */
enum class device_type { cpu, gpu };

/*
 * Reports whether memory of kind `mem_type` can be used by a device of kind
 * `dev_type`: a GPU requires device-accessible memory and a CPU requires
 * host-accessible memory. Any other `dev_type` value yields false.
 */
auto constexpr is_compatible(device_type dev_type, memory_type mem_type)
{
  switch (dev_type) {
    case device_type::gpu: return static_cast<bool>(is_device_accessible(mem_type));
    case device_type::cpu: return static_cast<bool>(is_host_accessible(mem_type));
  }
  // Reached only for values outside the declared enumerators.
  return false;
}

struct bad_device_type : raft::exception {
bad_device_type(std::string const& msg) : raft::exception{msg} {}
bad_device_type() : bad_device_type{"Incorrect device type for this operation"} {}
};

} // end namespace raft
2 changes: 2 additions & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ if(BUILD_TESTS)
CORE_TEST
PATH
test/core/logger.cpp
test/core/device_id.cpp
test/core/device_type.cpp
test/core/math_device.cu
test/core/math_host.cpp
test/core/operators_device.cu
Expand Down
39 changes: 39 additions & 0 deletions cpp/test/core/device_id.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <raft/core/device_id.hpp>
#include <raft/core/device_support.hpp>
#include <raft/core/device_type.hpp>

namespace raft {
// A default-constructed CPU device ID must report index 0, and asking it for
// an RMM ID must raise bad_device_type.
TEST(DeviceID, CPU)
{
  device_id<device_type::cpu> const cpu_id{};
  ASSERT_EQ(cpu_id.value(), 0);
  ASSERT_THROW(cpu_id.rmm_id(), bad_device_type);
}

// With CUDA available, the raw value and the RMM ID of a default-constructed
// GPU device ID must agree. In a RAFT_DISABLE_CUDA build, both accessors must
// raise cuda_unsupported instead.
TEST(DeviceID, GPU)
{
  device_id<device_type::gpu> const gpu_id{};
#ifndef RAFT_DISABLE_CUDA
  ASSERT_EQ(gpu_id.value(), gpu_id.rmm_id().value());
#else
  ASSERT_THROW(gpu_id.rmm_id(), cuda_unsupported);
  ASSERT_THROW(gpu_id.value(), cuda_unsupported);
#endif
}
} // namespace raft
35 changes: 35 additions & 0 deletions cpp/test/core/device_type.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <raft/core/device_type.hpp>

namespace raft {
// Compile-time checks of which memory types a CPU is compatible with.
TEST(DeviceType, CPU)
{
  // Compatible with the CPU: host, managed and pinned memory.
  static_assert(is_compatible(device_type::cpu, memory_type::host));
  static_assert(is_compatible(device_type::cpu, memory_type::managed));
  static_assert(is_compatible(device_type::cpu, memory_type::pinned));

  // Not compatible with the CPU: device memory.
  static_assert(!is_compatible(device_type::cpu, memory_type::device));
}

// Compile-time checks of which memory types a GPU is compatible with.
TEST(DeviceType, GPU)
{
  // Compatible with the GPU: device and managed memory.
  static_assert(is_compatible(device_type::gpu, memory_type::device));
  static_assert(is_compatible(device_type::gpu, memory_type::managed));

  // Not compatible with the GPU: host and pinned memory.
  static_assert(!is_compatible(device_type::gpu, memory_type::host));
  static_assert(!is_compatible(device_type::gpu, memory_type::pinned));
}
} // namespace raft