Skip to content

Commit

Permalink
StreamsByThread
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Apr 29, 2024
1 parent 6908b36 commit 439a304
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
60 changes: 59 additions & 1 deletion cpp/include/kvikio/posix_io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
#include <unistd.h>
#include <cstddef>
#include <cstdlib>
#include <map>
#include <mutex>
#include <stack>
#include <thread>

#include <cstring>
#include <kvikio/error.hpp>
Expand All @@ -33,7 +35,60 @@ inline constexpr std::size_t posix_bounce_buffer_size = 2 << 23; // 16 MiB
namespace detail {

/**
* @brief Class to retain host memory allocations
* @brief Singleton class to retrieve a CUDA stream for device-host copying
*
* Call `StreamsByThread::get` to get the CUDA stream assigned to the current
* CUDA context and thread.
*/
class StreamsByThread {
 private:
  // One stream per (CUDA context, thread id) pair. Guarded by `_mutex`.
  std::map<std::pair<CUcontext, std::thread::id>, CUstream> _streams;
  // Serializes every access to `_streams`: `get` is keyed by thread id and is
  // therefore expected to be called concurrently from several threads.
  std::mutex _mutex;

 public:
  StreamsByThread() = default;

  /**
   * @brief Destroy every stream that was created through `get`.
   *
   * A `CUfileException` raised while destroying a stream is printed to stderr
   * and the remaining streams are still processed.
   */
  ~StreamsByThread() noexcept
  {
    for (auto& [_, stream] : _streams) {
      try {
        CUDA_DRIVER_TRY(cudaAPI::instance().StreamDestroy(stream));
      } catch (const CUfileException& e) {
        std::cerr << e.what() << std::endl;
      }
    }
  }

  /**
   * @brief Get the stream associated with a CUDA context and a thread.
   *
   * The first call for a given `(ctx, thd_id)` pair creates a new stream with
   * `CU_STREAM_DEFAULT`; later calls with the same pair return that same stream.
   *
   * @param ctx CUDA context the stream belongs to. If null, no stream is created.
   * @param thd_id ID of the thread the stream is assigned to.
   * @return The stream for `(ctx, thd_id)`, or `nullptr` (the null/default stream)
   * when `ctx` is null.
   */
  static CUstream get(CUcontext ctx, std::thread::id thd_id)
  {
    static StreamsByThread _instance;

    // If no current context, we return the null/default stream
    if (ctx == nullptr) { return nullptr; }

    // Hold the lock across lookup and insertion so that concurrent callers
    // never observe or modify the map while another thread is changing it.
    const std::lock_guard<std::mutex> lock(_instance._mutex);

    const auto key = std::make_pair(ctx, thd_id);
    if (const auto it = _instance._streams.find(key); it != _instance._streams.end()) {
      return it->second;
    }

    // No stream yet for this (context, thread) pair: create and remember one.
    // The map is only updated after the stream was created successfully.
    CUstream stream{};
    CUDA_DRIVER_TRY(cudaAPI::instance().StreamCreate(&stream, CU_STREAM_DEFAULT));
    _instance._streams.emplace(key, stream);
    return stream;
  }

  /**
   * @brief Get the stream associated with the current CUDA context and the calling thread.
   *
   * @return The stream for the current context and thread, or `nullptr` when the
   * calling thread has no current CUDA context.
   */
  static CUstream get()
  {
    CUcontext ctx{nullptr};
    CUDA_DRIVER_TRY(cudaAPI::instance().CtxGetCurrent(&ctx));
    return get(ctx, std::this_thread::get_id());
  }

  // The instance owns driver resources and lives as a function-local static;
  // it is neither copyable nor movable.
  StreamsByThread(const StreamsByThread&) = delete;
  StreamsByThread& operator=(StreamsByThread const&) = delete;
  StreamsByThread(StreamsByThread&& o) = delete;
  StreamsByThread& operator=(StreamsByThread&& o) = delete;
};

/**
* @brief Singleton class to retain host memory allocations
*
* Call `AllocRetain::get` to get an allocation that will be retained when it
* goes out of scope (RAII). The size of all allocations are `posix_bounce_buffer_size`.
Expand Down Expand Up @@ -179,6 +234,9 @@ std::size_t posix_device_io(int fd,
off_t byte_remaining = convert_size2off(size);
const off_t chunk_size2 = convert_size2off(posix_bounce_buffer_size);

// Get a stream if none were given by the caller
if (stream == nullptr) { stream = StreamsByThread::get(); }

while (byte_remaining > 0) {
const off_t nbytes_requested = std::min(chunk_size2, byte_remaining);
ssize_t nbytes_got = nbytes_requested;
Expand Down
4 changes: 4 additions & 0 deletions cpp/include/kvikio/shim/cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class cudaAPI {
decltype(cuDevicePrimaryCtxRetain)* DevicePrimaryCtxRetain{nullptr};
decltype(cuDevicePrimaryCtxRelease)* DevicePrimaryCtxRelease{nullptr};
decltype(cuStreamSynchronize)* StreamSynchronize{nullptr};
decltype(cuStreamCreate)* StreamCreate{nullptr};
decltype(cuStreamDestroy)* StreamDestroy{nullptr};

private:
cudaAPI()
Expand All @@ -72,6 +74,8 @@ class cudaAPI {
get_symbol(DevicePrimaryCtxRetain, lib, KVIKIO_STRINGIFY(cuDevicePrimaryCtxRetain));
get_symbol(DevicePrimaryCtxRelease, lib, KVIKIO_STRINGIFY(cuDevicePrimaryCtxRelease));
get_symbol(StreamSynchronize, lib, KVIKIO_STRINGIFY(cuStreamSynchronize));
get_symbol(StreamCreate, lib, KVIKIO_STRINGIFY(cuStreamCreate));
get_symbol(StreamDestroy, lib, KVIKIO_STRINGIFY(cuStreamDestroy));
}

public:
Expand Down

0 comments on commit 439a304

Please sign in to comment.