Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EP context for custom op #16454

Merged
merged 37 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
be791ad
implement CudaContext
RandyShuai Jun 16, 2023
c458fdd
separate header
RandyShuai Jun 21, 2023
ea4b673
Simplify get-resource API
RandyShuai Jun 26, 2023
7082502
deprecate obsoletes
RandyShuai Jun 26, 2023
10a8ef6
implement dml custom op
RandyShuai Jun 30, 2023
02936d7
implement dml stream
RandyShuai Jun 30, 2023
eaa9883
invoker dml identity op
RandyShuai Jul 10, 2023
fa684e6
refactor dml calls
RandyShuai Jul 10, 2023
704642f
fetch recorder
RandyShuai Jul 11, 2023
66d6409
tune stream handles
RandyShuai Jul 12, 2023
9375508
moving ep context
RandyShuai Jul 23, 2023
107f58d
resolve conflict
RandyShuai Jul 23, 2023
e36371c
polish headers
RandyShuai Aug 1, 2023
6ba4f5a
restructure parent class
RandyShuai Aug 1, 2023
a377d93
format code
RandyShuai Aug 1, 2023
34c13e9
format
RandyShuai Aug 2, 2023
4b52391
Add ROCM support (#17024)
RandySheriffH Aug 6, 2023
9aa6dbe
Merge branch 'main' into rashuai/EpContext
RandySheriffH Aug 7, 2023
4447f36
remove header
RandyShuai Aug 7, 2023
4c488b5
switch header
RandyShuai Aug 7, 2023
7626fe4
register all ops
RandyShuai Aug 7, 2023
51ce9d8
typo
RandyShuai Aug 7, 2023
22e658f
format
RandyShuai Aug 8, 2023
cdfa332
address comments
RandyShuai Aug 9, 2023
92de13a
fix build
RandyShuai Aug 9, 2023
ec27703
resolve conflict
RandyShuai Aug 9, 2023
cc17d1d
fix build
RandyShuai Aug 9, 2023
484dae2
move static (#17090)
RandySheriffH Aug 10, 2023
6ad9873
reset mem type
RandyShuai Aug 11, 2023
3148ff2
Merge branch 'rashuai/EpContext' of https://github.com/microsoft/onnx…
RandyShuai Aug 11, 2023
9bf07a4
Merge branch 'main' into rashuai/EpContext
RandyShuai Aug 11, 2023
469fc8a
return default
RandyShuai Aug 12, 2023
3f311cc
Merge branch 'main' into rashuai/EpContext
RandyShuai Aug 12, 2023
44af883
Merge branch 'main' into rashuai/EpContext
RandyShuai Aug 12, 2023
c07e76a
resolve conflict
RandyShuai Aug 14, 2023
002ef6f
revert dml changes
RandyShuai Aug 15, 2023
463f047
clean up dml macros
RandyShuai Aug 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1450,19 +1450,40 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
endif()

if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")

set(custom_op_src_patterns
"${TEST_SRC_DIR}/testdata/custom_op_library/*.h"
"${TEST_SRC_DIR}/testdata/custom_op_library/*.cc"
"${TEST_SRC_DIR}/testdata/custom_op_library/cpu/cpu_ops.*"
)

set(custom_op_lib_include ${REPO_ROOT}/include)
set(custom_op_lib_option)
set(custom_op_lib_link ${GSL_TARGET})

if (onnxruntime_USE_CUDA)
onnxruntime_add_shared_library(custom_op_library ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu
${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
target_include_directories(custom_op_library PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
list(APPEND custom_op_src_patterns
"${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu"
"${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*")
list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include)
if (HAS_QSPECTRE)
target_compile_options(custom_op_library PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /Qspectre>")
list(APPEND custom_op_lib_option "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /Qspectre>")
endif()
else()
onnxruntime_add_shared_library(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
endif()

target_include_directories(custom_op_library PRIVATE ${REPO_ROOT}/include)
target_link_libraries(custom_op_library PRIVATE ${GSL_TARGET})
if (onnxruntime_USE_ROCM)
list(APPEND custom_op_src_patterns
"${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/rocm_ops.hip"
"${TEST_SRC_DIR}/testdata/custom_op_library/rocm/rocm_ops.*")
list(APPEND custom_op_lib_include ${onnxruntime_ROCM_HOME}/include)
list(APPEND custom_op_lib_option "-D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1")
endif()

file(GLOB custom_op_src ${custom_op_src_patterns})
onnxruntime_add_shared_library(custom_op_library ${custom_op_src})
target_compile_options(custom_op_library PRIVATE ${custom_op_lib_option})
target_include_directories(custom_op_library PRIVATE ${REPO_ROOT}/include ${custom_op_lib_include})
target_link_libraries(custom_op_library PRIVATE ${GSL_TARGET} ${custom_op_lib_link})

if(UNIX)
if (APPLE)
Expand Down
4 changes: 4 additions & 0 deletions include/onnxruntime/core/framework/stream_handles.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class Stream {
}
}

virtual void* GetResource(int /*version*/, int /*id*/) const {
return nullptr;
}

private:
StreamHandle handle_;
const OrtDevice& device_;
Expand Down
51 changes: 51 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.
#pragma once

#define ORT_CUDA_CTX

#include "cuda_resource.h"
#include "core/providers/custom_op_context.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cudnn.h>

namespace Ort {

namespace Custom {

struct CudaContext : public CustomOpContext {
cudaStream_t cuda_stream = {};
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};

void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = nullptr;

status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cuda_stream_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cuda stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cuda_stream = reinterpret_cast<cudaStream_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cudnn_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cudnn handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cudnn_handle = reinterpret_cast<cudnnHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cublas_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);
}
};

} // namespace Custom
} // namespace Ort
12 changes: 12 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/resource.h"

#define ORT_CUDA_RESOUCE_VERSION 1

enum CudaResource : int {
cuda_stream_t = cuda_resource_offset,
cudnn_handle_t,
cublas_handle_t
};
13 changes: 13 additions & 0 deletions include/onnxruntime/core/providers/custom_op_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#pragma once

#include <core/session/onnxruntime_cxx_api.h>

// CustomOpContext defines an interface allowing a custom op to access ep-specific resources.
struct CustomOpContext {
CustomOpContext() = default;
virtual ~CustomOpContext(){};
virtual void Init(const OrtKernelContext&){};
};
14 changes: 14 additions & 0 deletions include/onnxruntime/core/providers/resource.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#pragma once

enum ResourceOffset {
cpu_resource_offset = 0,
cuda_resource_offset = 10000,
dml_resource_offset = 20000,
rocm_resource_offset = 30000,
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
// offsets for other ort eps
custom_ep_resource_offset = 10000000,
// offsets for customized eps
};
49 changes: 49 additions & 0 deletions include/onnxruntime/core/providers/rocm/rocm_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#define ORT_ROCM_CTX

#include "rocm_resource.h"
#include "core/providers/custom_op_context.h"
#include <hip/hip_runtime.h>
#include <miopen/miopen.h>
#include <rocblas/rocblas.h>

namespace Ort {

namespace Custom {

struct RocmContext : public CustomOpContext {
hipStream_t hip_stream = {};
miopenHandle_t miopen_handle = {};
rocblas_handle rblas_handle = {};

void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = nullptr;

status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::hip_stream_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch hip stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
hip_stream = reinterpret_cast<hipStream_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::miopen_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch miopen handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
miopen_handle = reinterpret_cast<miopenHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::rocblas_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch rocblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
rblas_handle = reinterpret_cast<rocblas_handle>(resource);
}
};

} // namespace Custom
} // namespace Ort
12 changes: 12 additions & 0 deletions include/onnxruntime/core/providers/rocm/rocm_resource.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#include "core/providers/resource.h"

#define ORT_ROCM_RESOUCE_VERSION 1

enum RocmResource : int {
hip_stream_t = rocm_resource_offset,
miopen_handle_t,
rocblas_handle_t
};
15 changes: 13 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4308,8 +4308,6 @@ struct OrtApi {
*/
void(ORT_API_CALL* ReleaseROCMProviderOptions)(_Frees_ptr_opt_ OrtROCMProviderOptions* input);

/// @}

/** \brief Create an allocator with specific type and register it with the ::OrtEnv
* This API enhance CreateAndRegisterAllocator that it can create an allocator with specific type, not just CPU allocator
* Enables sharing the allocator between multiple sessions that use the same env instance.
Expand Down Expand Up @@ -4398,6 +4396,19 @@ struct OrtApi {
* \since Version 1.16.
*/
ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);

/**
* Get a EP resoure.
* E.g. a cuda stream or a cublas handle
*
* \param context - Kernel context
* \param resouce_version - Version of the resource
* \param resource_id - Type of resource
* \param resource - A pointer to returned resource
*
* \since Version 1.16.
*/
ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resouce_version, _In_ int resource_id, _Outptr_ void** resource);
};

/*
Expand Down
54 changes: 54 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,38 @@ struct OrtLiteCustomOp : public OrtCustomOp {
return std::tuple_cat(current, next);
}

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
souptc marked this conversation as resolved.
Show resolved Hide resolved
std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
auto next = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}

#ifdef ORT_CUDA_CTX
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
thread_local CudaContext cuda_context;
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
cuda_context.Init(*context);
std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
auto next = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
#endif

#ifdef ORT_ROCM_CTX
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
thread_local RocmContext rocm_context;
rocm_context.Init(*context);
std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
auto next = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
#endif

#define CREATE_TUPLE_INPUT(data_type) \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
Expand Down Expand Up @@ -437,6 +469,28 @@ struct OrtLiteCustomOp : public OrtCustomOp {
ParseArgs<Ts...>(input_types, output_types);
}

template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}

#ifdef ORT_CUDA_CTX
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#endif

#ifdef ORT_ROCM_CTX
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#endif

#define PARSE_INPUT_BASE(pack_type, onnx_type) \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
Expand Down
21 changes: 20 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_stream_handle.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_resource.h"
#include "core/providers/cuda/cuda_stream_handle.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/common/spin_pause.h"
Expand Down Expand Up @@ -149,6 +149,25 @@ Status CudaStream::CleanUpOnRunEnd() {
return Status::OK();
}

void* CudaStream::GetResource(int version, int id) const {
ORT_ENFORCE(version <= ORT_CUDA_RESOUCE_VERSION, "resource version unsupported!");
void* resource{};
switch (id) {
case CudaResource::cuda_stream_t:
return reinterpret_cast<void*>(GetHandle());
break;
case CudaResource::cudnn_handle_t:
return reinterpret_cast<void*>(cudnn_handle_);
break;
case CudaResource::cublas_handle_t:
return reinterpret_cast<void*>(cublas_handle_);
break;
default:
break;
}
return resource;
}

// CPU Stream command handles
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification) {
static_cast<CudaNotification*>(&notification)->wait_on_device(stream);
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ struct CudaStream : Stream {

cublasHandle_t cublas_handle_{};

void* GetResource(int version, int id) const override;

private:
std::vector<void*> deferred_cpu_buffers_;
AllocatorPtr cpu_allocator_;
Expand Down
21 changes: 20 additions & 1 deletion onnxruntime/core/providers/rocm/rocm_stream_handle.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "core/providers/rocm/rocm_stream_handle.h"
#include "core/providers/rocm/rocm_common.h"
// #include "core/common/spin_pause.h"
#include "core/providers/rocm/rocm_resource.h"

namespace onnxruntime {

Expand Down Expand Up @@ -129,7 +130,25 @@ Status RocmStream::CleanUpOnRunEnd() {
return Status::OK();
}

// CPU Stream command handles
void* RocmStream::GetResource(int version, int type) const {
ORT_ENFORCE(version <= ORT_ROCM_RESOUCE_VERSION, "resource version unsupported!");
void* resource{};
switch (type) {
case RocmResource::hip_stream_t:
return reinterpret_cast<void*>(GetHandle());
break;
case RocmResource::miopen_handle_t:
return reinterpret_cast<void*>(miopen_handle_);
break;
case RocmResource::rocblas_handle_t:
return reinterpret_cast<void*>(rocblas_handle_);
break;
default:
break;
}
return resource;
}

void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification) {
static_cast<RocmNotification*>(&notification)->wait_on_device(stream);
}
Expand Down
Loading