Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EP context for custom op #16454

Merged
merged 37 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
be791ad
implement CudaContext
RandyShuai Jun 16, 2023
c458fdd
separate header
RandyShuai Jun 21, 2023
ea4b673
Simplify get-resource API
RandyShuai Jun 26, 2023
7082502
deprecate obsoletes
RandyShuai Jun 26, 2023
10a8ef6
implement dml custom op
RandyShuai Jun 30, 2023
02936d7
implement dml stream
RandyShuai Jun 30, 2023
eaa9883
invoker dml identity op
RandyShuai Jul 10, 2023
fa684e6
refactor dml calls
RandyShuai Jul 10, 2023
704642f
fetch recorder
RandyShuai Jul 11, 2023
66d6409
tune stream handles
RandyShuai Jul 12, 2023
9375508
moving ep context
RandyShuai Jul 23, 2023
107f58d
resolve conflict
RandyShuai Jul 23, 2023
e36371c
polish headers
RandyShuai Aug 1, 2023
6ba4f5a
restructure parent class
RandyShuai Aug 1, 2023
a377d93
format code
RandyShuai Aug 1, 2023
34c13e9
format
RandyShuai Aug 2, 2023
4b52391
Add ROCM support (#17024)
RandySheriffH Aug 6, 2023
9aa6dbe
Merge branch 'main' into rashuai/EpContext
RandySheriffH Aug 7, 2023
4447f36
remove header
RandyShuai Aug 7, 2023
4c488b5
switch header
RandyShuai Aug 7, 2023
7626fe4
register all ops
RandyShuai Aug 7, 2023
51ce9d8
typo
RandyShuai Aug 7, 2023
22e658f
format
RandyShuai Aug 8, 2023
cdfa332
address comments
RandyShuai Aug 9, 2023
92de13a
fix build
RandyShuai Aug 9, 2023
ec27703
resolve conflict
RandyShuai Aug 9, 2023
cc17d1d
fix build
RandyShuai Aug 9, 2023
484dae2
move static (#17090)
RandySheriffH Aug 10, 2023
6ad9873
reset mem type
RandyShuai Aug 11, 2023
3148ff2
Merge branch 'rashuai/EpContext' of https://github.com/microsoft/onnx…
RandyShuai Aug 11, 2023
9bf07a4
Merge branch 'main' into rashuai/EpContext
RandyShuai Aug 11, 2023
469fc8a
return default
RandyShuai Aug 12, 2023
3f311cc
Merge branch 'main' into rashuai/EpContext
RandyShuai Aug 12, 2023
44af883
Merge branch 'main' into rashuai/EpContext
RandyShuai Aug 12, 2023
c07e76a
resolve conflict
RandyShuai Aug 14, 2023
002ef6f
revert dml changes
RandyShuai Aug 15, 2023
463f047
clean up dml macros
RandyShuai Aug 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1453,14 +1453,24 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
# Create the custom_op_library shared library. The source list depends on which
# GPU provider is enabled: CUDA adds cuda_ops.cu, ROCm adds rocm_ops.hip, and the
# fallback builds only custom_op_library.cc.
if (onnxruntime_USE_CUDA)
onnxruntime_add_shared_library(custom_op_library ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu
${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
target_include_directories(custom_op_library PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# NOTE(review): this call repeats the CUDA toolkit include directories from the call
# above and additionally adds ${onnxruntime_CUDNN_HOME}/include; the two lines look
# like the before/after sides of a diff - confirm only one is meant to remain.
target_include_directories(custom_op_library PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include)
if (HAS_QSPECTRE)
# Forward /Qspectre to the host compiler for CUDA sources only.
target_compile_options(custom_op_library PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /Qspectre>")
endif()
elseif (onnxruntime_USE_ROCM)
onnxruntime_add_shared_library(custom_op_library ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/rocm_ops.hip
${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
target_include_directories(custom_op_library PRIVATE ${onnxruntime_ROCM_HOME}/include)
# Define the HIP platform macros for the ROCm build of the library.
target_compile_options(custom_op_library PRIVATE -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1)
else()
onnxruntime_add_shared_library(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
endif()

# Extra settings for custom_op_library when the DML provider is enabled.
if (onnxruntime_USE_DML)
# NOTE(review): WIL::WIL looks like an imported target name rather than a directory;
# confirm whether target_link_libraries was intended here.
target_include_directories(custom_op_library PRIVATE WIL::WIL)
target_link_libraries(custom_op_library PRIVATE dxguid.lib d3d12.lib dxgi.lib)
endif()

# Settings applied regardless of provider: repository public headers and the GSL target.
target_include_directories(custom_op_library PRIVATE ${REPO_ROOT}/include)
target_link_libraries(custom_op_library PRIVATE ${GSL_TARGET})

Expand Down
4 changes: 4 additions & 0 deletions include/onnxruntime/core/framework/stream_handles.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class Stream {
}
}

// Returns a resource associated with this stream, selected by a version and an id.
// This base implementation ignores both arguments and always returns nullptr.
// NOTE(review): presumably provider-specific streams override this to hand out
// their handles - confirm against the derived classes.
virtual void* GetResource(int /*version*/, int /*id*/) const {
return nullptr;
}

private:
StreamHandle handle_;
const OrtDevice& device_;
Expand Down
10 changes: 10 additions & 0 deletions include/onnxruntime/core/providers/context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

// Guard against multiple inclusion: more than one provider context header
// includes this file, and without the guard `struct Context` is redefined
// when two of them are used in the same translation unit.
#pragma once

#include <core/session/onnxruntime_cxx_api.h>

// Polymorphic base for the per-provider context objects that can be passed to
// custom op kernels.
struct Context {
  Context() = default;
  virtual ~Context() = default;

  // Populates the context from the given kernel context.
  // The base implementation does nothing; derived contexts override it.
  virtual void Init(const OrtKernelContext&) {}
};
50 changes: 50 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.
#pragma once

#include "cuda_resource.h"
#include "core/providers/context.h"
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cudnn.h>

#define ORT_CUDA_CTX

namespace Ort {

namespace Custom {

struct CudaContext : public Context {
cudaStream_t cuda_stream = {};
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};

void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = nullptr;

status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cuda_stream_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cuda stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cuda_stream = reinterpret_cast<cudaStream_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cudnn_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cudnn handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cudnn_handle = reinterpret_cast<cudnnHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cublas_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);
}
};

} // namespace Custom
} // namespace Ort
12 changes: 12 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/resource.h"

#define ORT_CUDA_RESOUCE_VERSION 1

// Resource ids understood by the CUDA provider. Numbering starts at
// cuda_resource_offset and continues sequentially.
enum CudaResource : int {
// Id used to fetch the CUDA stream.
cuda_stream_t = cuda_resource_offset,
// Id used to fetch the cuDNN handle.
cudnn_handle_t,
// Id used to fetch the cuBLAS handle.
cublas_handle_t
};
56 changes: 56 additions & 0 deletions include/onnxruntime/core/providers/dml/dml_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#define ORT_DML_CTX

#include "dml_resource.h"
#include "core/providers/context.h"
#include <DirectML.h>
#include <d3d12.h>

namespace Ort {

namespace Custom {

struct DmlContext : public Context {
IDMLDevice* dml_device = {};
ID3D12Device* d3d12_device = {};
ID3D12GraphicsCommandList* cmd_list = {};
IDMLCommandRecorder* cmd_recorder = {};

void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = nullptr;

status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_DML_RESOUCE_VERSION, DmlResource::dml_device_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch dml device", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
dml_device = reinterpret_cast<IDMLDevice*>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_DML_RESOUCE_VERSION, DmlResource::d3d12_device_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch dml d3d12 device", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
d3d12_device = reinterpret_cast<ID3D12Device*>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_DML_RESOUCE_VERSION, DmlResource::cmd_list_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch command list", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cmd_list = reinterpret_cast<ID3D12GraphicsCommandList*>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_DML_RESOUCE_VERSION, DmlResource::cmd_recorder_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch command recorder", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cmd_recorder = reinterpret_cast<IDMLCommandRecorder*>(resource);
}
};

} // namespace Custom
} // namespace Ort
13 changes: 13 additions & 0 deletions include/onnxruntime/core/providers/dml/dml_resource.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/resource.h"

#define ORT_DML_RESOUCE_VERSION 1

// Resource ids understood by the DML provider. Numbering starts at
// dml_resource_offset and continues sequentially.
enum DmlResource : int {
// Id used to fetch the IDMLDevice.
dml_device_t = dml_resource_offset,
// Id used to fetch the ID3D12Device.
d3d12_device_t,
// Id used to fetch the ID3D12GraphicsCommandList.
cmd_list_t,
// Id used to fetch the IDMLCommandRecorder.
cmd_recorder_t
};
9 changes: 9 additions & 0 deletions include/onnxruntime/core/providers/resource.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

// Guard against multiple inclusion: every provider resource header includes
// this file, and without the guard the enum is redefined when two of them are
// used in the same translation unit.
#pragma once

// First resource id of each execution provider. A provider's own resource enum
// starts at its offset, which keeps ids of different providers from overlapping.
enum ResourceOffset {
  cpu_resource_offset = 0,
  cuda_resource_offset = 10000,
  dml_resource_offset = 20000,
  rocm_resource_offset = 30000,
};
48 changes: 48 additions & 0 deletions include/onnxruntime/core/providers/rocm/rocm_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#define ORT_ROCM_CTX

#include "rocm_resource.h"
#include "core/providers/context.h"

namespace Ort {

namespace Custom {

struct RocmContext : public Context {

hipStream_t hip_stream = {};
miopenHandle_t miopen_handle = {};
rocblas_handle rblas_handle = {};

void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = nullptr;

status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::hip_stream_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch hip stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
hip_stream = reinterpret_cast<hipStream_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::miopen_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch miopen handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
miopen_handle = reinterpret_cast<miopenHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_ROCM_RESOUCE_VERSION, RocmResource::rocblas_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch rocblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
rblas_handle = reinterpret_cast<rocblas_handle>(resource);
}
};

} // namespace Custom
} // namespace Ort

13 changes: 13 additions & 0 deletions include/onnxruntime/core/providers/rocm/rocm_resource.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#include "core/providers/resource.h"

#define ORT_ROCM_RESOUCE_VERSION 1

// Resource ids understood by the ROCm provider. Numbering starts at
// rocm_resource_offset and continues sequentially.
enum RocmResource : int {
// Id used to fetch the HIP stream.
hip_stream_t = rocm_resource_offset,
// Id used to fetch the MIOpen handle.
miopen_handle_t,
// Id used to fetch the rocBLAS handle.
rocblas_handle_t
};

15 changes: 13 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4308,8 +4308,6 @@ struct OrtApi {
*/
void(ORT_API_CALL* ReleaseROCMProviderOptions)(_Frees_ptr_opt_ OrtROCMProviderOptions* input);

/// @}

/** \brief Create an allocator with specific type and register it with the ::OrtEnv
* This API enhance CreateAndRegisterAllocator that it can create an allocator with specific type, not just CPU allocator
* Enables sharing the allocator between multiple sessions that use the same env instance.
Expand Down Expand Up @@ -4372,6 +4370,19 @@ struct OrtApi {
* \since Version 1.16.
*/
ORT_API2_STATUS(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr);

/** \brief Get an EP (execution provider) resource, e.g. a CUDA stream or a cuBLAS handle.
 *
 * \param[in] context Kernel context
 * \param[in] resouce_version Version of the resource (spelling matches the declared parameter name)
 * \param[in] resource_id Type of resource
 * \param[out] resource A pointer to the returned resource
 *
 * \since Version 1.16.
 */
ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resouce_version, _In_ int resource_id, _Outptr_ void** resource);
};

/*
Expand Down
74 changes: 74 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,50 @@ struct OrtLiteCustomOp : public OrtCustomOp {
return std::tuple_cat(current, next);
}

// Argument binder for a kernel parameter of type `OrtKernelContext&`.
// The raw kernel context becomes the leading tuple element; the remaining
// parameter types are bound by the recursive call. The input/output cursors
// are forwarded unchanged.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  std::tuple<T> head{*context};
  auto tail = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
  return std::tuple_cat(head, tail);
}

#ifdef ORT_CUDA_CTX
// Argument binder for a kernel parameter of type `const CudaContext&`.
// A thread-local CudaContext is re-initialized from the kernel context on every
// call and a reference to it becomes the leading tuple element; the remaining
// parameter types are bound by the recursive call. The input/output cursors
// are forwarded unchanged.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  thread_local CudaContext per_thread_ctx;
  per_thread_ctx.Init(*context);
  std::tuple<T> head{per_thread_ctx};
  auto tail = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
  return std::tuple_cat(head, tail);
}
#endif

#ifdef ORT_DML_CTX
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
// Argument binder for a kernel parameter of type `const DmlContext&`.
// A thread-local DmlContext is re-initialized from the kernel context on every
// call and a reference to it becomes the leading tuple element; the remaining
// parameter types are bound by the recursive call. The input/output cursors
// are forwarded unchanged.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const DmlContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  thread_local DmlContext per_thread_ctx;
  per_thread_ctx.Init(*context);
  std::tuple<T> head{per_thread_ctx};
  auto tail = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
  return std::tuple_cat(head, tail);
}
#endif

#ifdef ORT_ROCM_CTX
// Argument binder for a kernel parameter of type `const RocmContext&`.
// A thread-local RocmContext is re-initialized from the kernel context on every
// call and a reference to it becomes the leading tuple element; the remaining
// parameter types are bound by the recursive call. The input/output cursors
// are forwarded unchanged.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  thread_local RocmContext per_thread_ctx;
  per_thread_ctx.Init(*context);
  std::tuple<T> head{per_thread_ctx};
  auto tail = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
  return std::tuple_cat(head, tail);
}
#endif

#define CREATE_TUPLE_INPUT(data_type) \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
Expand Down Expand Up @@ -437,6 +481,36 @@ struct OrtLiteCustomOp : public OrtCustomOp {
ParseArgs<Ts...>(input_types, output_types);
}

// Type-list parser step for an `OrtKernelContext&` parameter.
// Nothing is appended to input_types or output_types for this parameter;
// parsing simply continues with the remaining parameter types.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}

#ifdef ORT_CUDA_CTX
// Type-list parser step for a `const CudaContext&` parameter.
// Nothing is appended to input_types or output_types for this parameter;
// parsing simply continues with the remaining parameter types.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#endif

#ifdef ORT_DML_CTX
template <typename T, typename... Ts>
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const DmlContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#endif

#ifdef ORT_ROCM_CTX
// Type-list parser step for a `const RocmContext&` parameter.
// Nothing is appended to input_types or output_types for this parameter;
// parsing simply continues with the remaining parameter types.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#endif

#define PARSE_INPUT_BASE(pack_type, onnx_type) \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
Expand Down
Loading