Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow cuda custom ops allocate deferred cpu mem #17893

Merged
merged 7 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
cudaStream_t cuda_stream = {};
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};
OrtAllocator* deferred_cpu_allocator = {};

void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
Expand All @@ -44,6 +45,36 @@
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);

Check warning on line 50 in include/onnxruntime/core/providers/cuda/cuda_context.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] include/onnxruntime/core/providers/cuda/cuda_context.h#L50

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
include/onnxruntime/core/providers/cuda/cuda_context.h:50:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
if (status) {
ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
deferred_cpu_allocator = reinterpret_cast<OrtAllocator*>(resource);
}

void* AllocDeferredCpuMem(size_t size) const {
if (0 == size) {
return {};
}
const auto& ort_api = Ort::GetApi();
void* mem = {};
auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
if (status) {
ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
return mem;
}

void FreeDeferredCpuMem(void* mem) const {
if (mem) {
const auto& ort_api = Ort::GetApi();
auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
if (status) {
ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
}
}
};

Expand Down
5 changes: 3 additions & 2 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

#include "core/providers/resource.h"

#define ORT_CUDA_RESOUCE_VERSION 1
#define ORT_CUDA_RESOUCE_VERSION 2

enum CudaResource : int {
cuda_stream_t = cuda_resource_offset,
cudnn_handle_t,
cublas_handle_t
cublas_handle_t,
deferred_cpu_allocator_t,
};
25 changes: 24 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_stream_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@

namespace onnxruntime {

DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) {
OrtAllocator::version = ORT_API_VERSION;
OrtAllocator::Alloc =
[](OrtAllocator* this_, size_t size) {
auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
return self->cuda_stream_.GetCpuAllocator()->Alloc(size);
};
OrtAllocator::Free =
[](OrtAllocator* this_, void* p) {
auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
self->cuda_stream_.EnqueDeferredCPUBuffer(p);
};
OrtAllocator::Info =
[](const OrtAllocator* this_) {
auto self = reinterpret_cast<const DeferredCpuAllocator*>(this_);
return &self->cuda_stream_.GetCpuAllocator()->Info();
};
}

struct CudaNotification : public synchronize::Notification {
CudaNotification(Stream& s) : Notification(s) {
CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
Expand Down Expand Up @@ -46,7 +65,8 @@
cublasHandle_t external_cublas_handle) : Stream(stream, device),
own_stream_(own_flag),
cpu_allocator_(cpu_allocator),
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) {
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),

Check warning on line 68 in onnxruntime/core/providers/cuda/cuda_stream_handle.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/cuda/cuda_stream_handle.cc#L68

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/cuda/cuda_stream_handle.cc:68:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
deferred_cpu_allocator_(*this) {
if (own_flag) {
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
Expand Down Expand Up @@ -162,6 +182,9 @@
case CudaResource::cublas_handle_t:
return reinterpret_cast<void*>(cublas_handle_);
break;
case CudaResource::deferred_cpu_allocator_t:
return const_cast<DeferredCpuAllocator*>(&deferred_cpu_allocator_);
break;
default:
break;
}
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@

namespace onnxruntime {

struct CudaStream;

struct DeferredCpuAllocator : public OrtAllocator {
DeferredCpuAllocator(CudaStream&);

Check warning on line 15 in onnxruntime/core/providers/cuda/cuda_stream_handle.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/cuda/cuda_stream_handle.h#L15

Single-parameter constructors should be marked explicit. [runtime/explicit] [5]
Raw output
onnxruntime/core/providers/cuda/cuda_stream_handle.h:15:  Single-parameter constructors should be marked explicit.  [runtime/explicit] [5]
CudaStream& cuda_stream_;
};

struct CudaStream : Stream {
CudaStream(cudaStream_t stream,
const OrtDevice& device,
Expand Down Expand Up @@ -36,10 +43,13 @@

void* GetResource(int version, int id) const override;

onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }

private:
std::vector<void*> deferred_cpu_buffers_;
AllocatorPtr cpu_allocator_;
bool release_cpu_buffer_on_cuda_stream_{true};
DeferredCpuAllocator deferred_cpu_allocator_;
};

void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
Expand Down
9 changes: 4 additions & 5 deletions onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUDA
#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)

#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
Expand Down Expand Up @@ -32,6 +32,9 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream");
CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle");
CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle");
void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));
CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator");
cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem);
auto z_raw = Z.Allocate(input_shape);
cuda_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), cuda_ctx.cuda_stream);
}
Expand All @@ -43,8 +46,4 @@ void RegisterOps(Ort::CustomOpDomain& domain) {

} // namespace Cuda

#else

void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}

#endif
10 changes: 9 additions & 1 deletion onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@

namespace Cuda {

#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)

void RegisterOps(Ort::CustomOpDomain& domain);

}
#else

void RegisterOps(Ort::CustomOpDomain&) {}

#endif

} // namespace Cuda

Check warning on line 18 in onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h#L18

Could not find a newline character at the end of the file. [whitespace/ending_newline] [5]
Raw output
onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h:18:  Could not find a newline character at the end of the file.  [whitespace/ending_newline] [5]
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "core/framework/ortdevice.h"
#include "core/framework/ortmemoryinfo.h"
#include "cpu/cpu_ops.h"
#include "cuda/cuda_ops.h"
#include "rocm/rocm_ops.h"
#include "onnxruntime_lite_custom_op.h"

static const char* c_OpDomain = "test.customop";
Expand All @@ -31,10 +33,15 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA
ORT_TRY {
Ort::CustomOpDomain domain{c_OpDomain};
Cpu::RegisterOps(domain);

Ort::CustomOpDomain domain_v2{"v2"};
Cpu::RegisterOps(domain_v2);

Cuda::RegisterOps(domain);
Cuda::RegisterOps(domain_v2);

Rocm::RegisterOps(domain);
Rocm::RegisterOps(domain_v2);

Ort::UnownedSessionOptions session_options(options);
session_options.Add(domain);
session_options.Add(domain_v2);
Expand Down
8 changes: 2 additions & 6 deletions onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using namespace Ort::Custom;
throw std::runtime_error(msg); \
}

namespace Cuda {
namespace Rocm {

void KernelOne(const Ort::Custom::RocmContext& rocm_ctx,
const Ort::Custom::Tensor<float>& X,
Expand All @@ -38,10 +38,6 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
domain.Add(c_CustomOpOne.get());
}

} // namespace Cuda

#else

void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}
} // namespace Rocm

#endif
10 changes: 9 additions & 1 deletion onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@

namespace Rocm {

#ifdef USE_ROCM

void RegisterOps(Ort::CustomOpDomain& domain);

}
#else

inline void RegisterOps(Ort::CustomOpDomain&) {}

#endif

} // namespace Rocm
Loading