-
Notifications
You must be signed in to change notification settings - Fork 916
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding
hostdevice_span
that is a span createable from `hostdevice_v…
…ector` (#12981) I ran into a need for a span-like view into a `hostdevice_vector`. I was chopping it up into pieces to pass into a function to process portions at a time, but it still wanted to do things like host to device on the spans. This class is a result of that need. Authors: - Mike Wilson (https://github.com/hyperbolic2346) - Nghia Truong (https://github.com/ttnghia) Approvers: - Nghia Truong (https://github.com/ttnghia) - Vukasin Milovanovic (https://github.com/vuule) URL: #12981
- Loading branch information
1 parent
d82f97c
commit e28c9c5
Showing
4 changed files
with
340 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cudf/utilities/span.hpp> | ||
|
||
#include <rmm/cuda_stream_view.hpp> | ||
|
||
template <typename T> | ||
class hostdevice_span { | ||
public: | ||
using value_type = T; | ||
|
||
hostdevice_span() = default; | ||
~hostdevice_span() = default; | ||
hostdevice_span(hostdevice_span const&) = default; ///< Copy constructor | ||
hostdevice_span(hostdevice_span&&) = default; ///< Move constructor | ||
|
||
hostdevice_span(T* cpu_data, T* gpu_data, size_t size) | ||
: _size(size), _host_data(cpu_data), _device_data(gpu_data) | ||
{ | ||
} | ||
|
||
/** | ||
* @brief Copy assignment operator. | ||
* | ||
* @return Reference to this hostdevice_span. | ||
*/ | ||
constexpr hostdevice_span& operator=(hostdevice_span const&) noexcept = default; | ||
|
||
/** | ||
* @brief Converts a hostdevice view into a device span. | ||
* | ||
* @tparam T The device span type. | ||
* @return A typed device span of the hostdevice view's data. | ||
*/ | ||
[[nodiscard]] operator cudf::device_span<T>() const | ||
{ | ||
return cudf::device_span(_device_data, size()); | ||
} | ||
|
||
/** | ||
* @brief Returns the underlying device data. | ||
* | ||
* @tparam T The type to cast to | ||
* @return T const* Typed pointer to underlying data | ||
*/ | ||
[[nodiscard]] T* device_ptr(size_t offset = 0) const noexcept { return _device_data + offset; } | ||
|
||
/** | ||
* @brief Return first element in device data. | ||
* | ||
* @tparam T The desired type | ||
* @return T const* Pointer to the first element | ||
*/ | ||
[[nodiscard]] T* device_begin() const noexcept { return device_ptr(); } | ||
|
||
/** | ||
* @brief Return one past the last element in device_data. | ||
* | ||
* @tparam T The desired type | ||
* @return T const* Pointer to one past the last element | ||
*/ | ||
[[nodiscard]] T* device_end() const noexcept { return device_begin() + size(); } | ||
|
||
/** | ||
* @brief Converts a hostdevice_span into a host span. | ||
* | ||
* @tparam T The host span type. | ||
* @return A typed host span of the hostdevice_span's data. | ||
*/ | ||
[[nodiscard]] operator cudf::host_span<T>() const noexcept | ||
{ | ||
return cudf::host_span<T>(_host_data, size()); | ||
} | ||
|
||
/** | ||
* @brief Returns the underlying host data. | ||
* | ||
* @tparam T The type to cast to | ||
* @return T* Typed pointer to underlying data | ||
*/ | ||
[[nodiscard]] T* host_ptr(size_t offset = 0) const noexcept { return _host_data + offset; } | ||
|
||
/** | ||
* @brief Return first element in host data. | ||
* | ||
* @tparam T The desired type | ||
* @return T const* Pointer to the first element | ||
*/ | ||
[[nodiscard]] T* host_begin() const noexcept { return host_ptr(); } | ||
|
||
/** | ||
* @brief Return one past the last elementin host data. | ||
* | ||
* @tparam T The desired type | ||
* @return T const* Pointer to one past the last element | ||
*/ | ||
[[nodiscard]] T* host_end() const noexcept { return host_begin() + size(); } | ||
|
||
/** | ||
* @brief Returns the number of elements in the view | ||
* | ||
* @return The number of elements in the view | ||
*/ | ||
[[nodiscard]] std::size_t size() const noexcept { return _size; } | ||
|
||
/** | ||
* @brief Returns true if `size()` returns zero, or false otherwise | ||
* | ||
* @return True if `size()` returns zero, or false otherwise | ||
*/ | ||
[[nodiscard]] bool is_empty() const noexcept { return size() == 0; } | ||
|
||
[[nodiscard]] size_t size_bytes() const noexcept { return sizeof(T) * size(); } | ||
|
||
[[nodiscard]] T& operator[](size_t i) { return _host_data[i]; } | ||
[[nodiscard]] T const& operator[](size_t i) const { return _host_data[i]; } | ||
|
||
/** | ||
* @brief Obtains a hostdevice_span that is a view over the `count` elements of this | ||
* hostdevice_span starting at offset | ||
* | ||
* @param offset The offset of the first element in the subspan | ||
* @param count The number of elements in the subspan | ||
* @return A subspan of the sequence, of requested count and offset | ||
*/ | ||
constexpr hostdevice_span<T> subspan(size_t offset, size_t count) const noexcept | ||
{ | ||
return hostdevice_span<T>(_host_data + offset, _device_data + offset, count); | ||
} | ||
|
||
void host_to_device(rmm::cuda_stream_view stream, bool synchronize = false) | ||
{ | ||
CUDF_CUDA_TRY( | ||
cudaMemcpyAsync(device_ptr(), host_ptr(), size_bytes(), cudaMemcpyDefault, stream.value())); | ||
if (synchronize) { stream.synchronize(); } | ||
} | ||
|
||
void device_to_host(rmm::cuda_stream_view stream, bool synchronize = false) | ||
{ | ||
CUDF_CUDA_TRY( | ||
cudaMemcpyAsync(host_ptr(), device_ptr(), size_bytes(), cudaMemcpyDefault, stream.value())); | ||
if (synchronize) { stream.synchronize(); } | ||
} | ||
|
||
private: | ||
size_t _size{}; ///< Number of elements | ||
T* _device_data{}; ///< Pointer to device memory containing elements | ||
T* _host_data{}; ///< Pointer to host memory containing elements | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.