From f545b7ebce3591920fdfe38e2215b38519240f53 Mon Sep 17 00:00:00 2001 From: jslhcl Date: Sun, 17 Mar 2024 08:27:01 -0700 Subject: [PATCH 1/3] refactor ORT-Extension for the coming GroupQueryAttention work --- base/ortx_common.h | 117 +++++++++++++++++ base/string_utils.h | 2 +- docs/development.md | 2 +- includes/custom_op_lite.h | 18 +++ includes/onnxruntime_cpp_api_legacy.hpp | 2 +- includes/onnxruntime_customop.hpp | 118 ++--------------- includes/onnxruntime_no_customop.h | 122 ++++++++++++++++++ operators/contrib/cuda/cuda_type.h | 17 +++ operators/contrib/cuda/device_prop.cuh | 26 ++++ operators/contrib/cuda/fast_gelu.h | 12 +- operators/contrib/cuda/fast_gelu_impl.cu | 2 +- operators/contrib/cuda/utils.cuh | 19 +-- operators/tokenizer/bpe_decoder.hpp | 1 + operators/tokenizer/bpe_kernels.cc | 1 + operators/tokenizer/sentencepiece_decoder.hpp | 1 + .../tokenizer/sentencepiece_tokenizer.cc | 1 + operators/tokenizer/trie_tokenizer.hpp | 1 + 17 files changed, 327 insertions(+), 135 deletions(-) create mode 100644 base/ortx_common.h create mode 100644 includes/onnxruntime_no_customop.h create mode 100644 operators/contrib/cuda/cuda_type.h create mode 100644 operators/contrib/cuda/device_prop.cuh diff --git a/base/ortx_common.h b/base/ortx_common.h new file mode 100644 index 000000000..61cd7c0ce --- /dev/null +++ b/base/ortx_common.h @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include +#include +#include +#include +#include "string_utils.h" +#ifdef _WIN32 +#include +#endif + +#define ORTX_RETURN_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +template +bool TryParseStringWithClassicLocale(std::string_view str, T& value) { + if constexpr (std::is_integral::value && std::is_unsigned::value) { + // if T is unsigned integral type, reject negative values which will wrap + if (!str.empty() && str[0] == '-') { + return false; + } + } + + // don't allow leading whitespace + if (!str.empty() && std::isspace(str[0], std::locale::classic())) { + return false; + } + + std::istringstream is{std::string{str}}; + is.imbue(std::locale::classic()); + T parsed_value{}; + + const bool parse_successful = + is >> parsed_value && + is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters + if (!parse_successful) { + return false; + } + + value = std::move(parsed_value); + return true; +} + +inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) { + value = str; + return true; +} + +inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) { + if (str == "0" || str == "False" || str == "false") { + value = false; + return true; + } + + if (str == "1" || str == "True" || str == "true") { + value = true; + return true; + } + + return false; +} + +template +std::optional ParseEnvironmentVariable(const std::string& name) { + std::string buffer; +#ifdef _WIN32 + constexpr size_t kBufferSize = 32767; + + // Create buffer to hold the result + buffer.resize(kBufferSize, '\0'); + + // The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters. + // If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character. + // Therefore, If the function succeeds, kBufferSize should be larger than char_count. + auto char_count = GetEnvironmentVariableA(name.c_str(), buffer.data(), kBufferSize); + + if (kBufferSize > char_count) { + buffer.resize(char_count); + } else { + // Else either the call was failed, or the buffer wasn't large enough. + // TODO: Understand the reason for failure by calling GetLastError(). + // If it is due to the specified environment variable being found in the environment block, + // GetLastError() returns ERROR_ENVVAR_NOT_FOUND. + // For now, we assume that the environment variable is not found. + buffer.clear(); + } +#else + char* val = getenv(name.c_str()); + buffer = (val == nullptr) ? std::string() : std::string(val); +#endif + T parsed_value; + if (!TryParseStringWithClassicLocale(buffer, parsed_value)) { + OrtW::Exception(MakeString("Failed to parse environment variable - name: ", name, ", value: ", buffer), OrtErrorCode::ORT_FAIL); + } + return parsed_value; +} + +template +T ParseEnvironmentVariableWithDefault(const std::string& name, const T& default_value) { + const auto parsed = ParseEnvironmentVariable(name); + if (parsed.has_value()) { + return *parsed; + } + + return default_value; +} + +inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size) { + if (num_dimensions == 0 || (num_dimensions == 1 && shape_size == 1)) return true; + return false; +} diff --git a/base/string_utils.h b/base/string_utils.h index 5b2078b52..eaeec8687 100644 --- a/base/string_utils.h +++ b/base/string_utils.h @@ -3,7 +3,7 @@ #pragma once #include #include -#include "ocos.h" +#include "onnxruntime_cpp_api_legacy.hpp" template inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept { diff --git a/docs/development.md b/docs/development.md index 167e7f27d..86e8610b9 100644 --- a/docs/development.md +++ b/docs/development.md @@ -17,7 +17,7 @@ The package contains all custom operators and some Python scripts to manipulate - no-opencv: disable operators based on OpenCV in build. - cc-debug: Generate debug info for extensions binaries and disable C/C++ compiler optimization. - For example:`pip install --config-settings "ortx-user-option=use-cuda,cc-debug" `, This command builds CUDA kernels into the package and installs it, accompanied by the generation of debug information. + For example:`pip install . --config-settings "ortx-user-option=use-cuda,cc-debug" `, This command builds CUDA kernels into the package and installs it, accompanied by the generation of debug information. Test: diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h index 365b020a7..e226f48ea 100644 --- a/includes/custom_op_lite.h +++ b/includes/custom_op_lite.h @@ -585,6 +585,11 @@ struct Variadic : public TensorBase { enum CudaResource { cuda_handle_t = 10000, + cudnn_handle_t, + cublas_handle_t, + deferred_cpu_allocator_t, + // below are cuda ep options + device_id_t, }; struct CudaContext { @@ -595,8 +600,21 @@ struct CudaContext { if (!cuda_stream) { ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION); } + ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas); + if (!cublas) { + ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION); + } + void* resource = nullptr; + OrtStatusPtr result = ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource); + if (result) { + ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION); + } + memcpy(&device_id, &resource, sizeof(int)); } void* cuda_stream = {}; + void* cublas = {}; + int device_id = 0; + } }; #endif diff --git a/includes/onnxruntime_cpp_api_legacy.hpp b/includes/onnxruntime_cpp_api_legacy.hpp index 99f452282..f967b0b46 100644 --- a/includes/onnxruntime_cpp_api_legacy.hpp +++ b/includes/onnxruntime_cpp_api_legacy.hpp @@ -3,7 +3,7 @@ #pragma once #include -#include "onnxruntime_c_api.h" +#include "exceptions.h" // // DEPRECATED: All new custom OPs should not use any class/struct/functions from this file. diff --git a/includes/onnxruntime_customop.hpp b/includes/onnxruntime_customop.hpp index 5cf2a8dcb..6144338a2 100644 --- a/includes/onnxruntime_customop.hpp +++ b/includes/onnxruntime_customop.hpp @@ -15,118 +15,16 @@ #include #include #include +#include -#include "onnxruntime_c_api.h" #include "exceptions.h" +#include "onnxruntime_no_customop.h" #include "onnxruntime_cpp_api_legacy.hpp" #include "onnxruntime_extensions.h" #include "custom_op_lite.h" #define MIN_ORT_VERSION_SUPPORTED 11 -// namespace of ORT ABI Wrapper -namespace OrtW { - -class API { - // To use ONNX C ABI in a way like OrtW::API::CreateStatus. - public: - static API& instance(const OrtApi* ort_api = nullptr) noexcept { - static API self(ort_api); - return self; - } - - static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept { - return instance()->CreateStatus(code, msg); - } - - static void ReleaseStatus(OrtStatusPtr ptr) noexcept { - instance()->ReleaseStatus(ptr); - } - - template - static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept; - - static void ThrowOnError(OrtStatusPtr ptr) { - OrtW::ThrowOnError(instance().api_, ptr); - } - - private: - const OrtApi* operator->() const { - return &api_; - } - - API(const OrtApi* api) : api_(*api) { - if (api == nullptr) { - ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION); - } - } - - const OrtApi& api_; -}; - -template <> -inline OrtStatusPtr API::KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept { - return instance()->KernelInfoGetAttribute_int64(&info, name, &value); -} - -template <> -inline OrtStatusPtr API::KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, float& value) noexcept { - return instance()->KernelInfoGetAttribute_float(&info, name, &value); -} - -template <> -inline OrtStatusPtr API::KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, std::string& value) noexcept { - size_t size = 0; - std::string out; - // Feed nullptr for the data buffer to query the true size of the string attribute - OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size); - if (status == nullptr) { - out.resize(size); - status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size); - out.resize(size - 1); // remove the terminating character '\0' - } - - if (status == nullptr) { - value = std::move(out); - } - - return status; -} - -template -inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept { - if (auto status = API::KernelInfoGetAttribute(info, name, value); status) { - // Ideally, we should know which kind of error code can be ignored, but it is not available now. - // Just ignore all of them. - API::ReleaseStatus(status); - } - - return nullptr; -} - -inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) { - return API::CreateStatus(code, msg); -} - -inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) { - return API::CreateStatus(code, msg.c_str()); -} - -inline void ReleaseStatus(OrtStatusPtr& status) { - API::ReleaseStatus(status); - status = nullptr; -} - -} // namespace OrtW - -#define ORTX_RETURN_IF_ERROR(expr) \ - do { \ - auto _status = (expr); \ - if (_status != nullptr) { \ - return _status; \ - } \ - } while (0) - namespace Ort { namespace Custom { @@ -164,6 +62,12 @@ struct ComputeArgsList { using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const; }; +template +struct CustomOp_defined_getInputMemoryType : std::false_type {}; + +template +struct CustomOp_defined_getInputMemoryType> : std::true_type {}; + template struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { using ComputeFunction = decltype(&CustomOpKernel::Compute); @@ -236,6 +140,12 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = nullptr; OrtCustomOp::KernelCompute = nullptr; + if constexpr (CustomOp_defined_getInputMemoryType::value) { + OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* /*this_*/, size_t index) -> OrtMemType { + return CustomOpKernel::GetInputMemoryType(index); + }; + } + OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr { if (api == nullptr) { diff --git a/includes/onnxruntime_no_customop.h b/includes/onnxruntime_no_customop.h new file mode 100644 index 000000000..a9477455a --- /dev/null +++ b/includes/onnxruntime_no_customop.h @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file defines API which depends on ONNXRuntime, but not including Custom Op and related facilities +// Custom Op and related classes, functions and macros are in onnxruntime_customop.hpp +#pragma once +#include "exceptions.h" + +// namespace of ORT ABI Wrapper +namespace OrtW { + +class API { + // To use ONNX C ABI in a way like OrtW::API::CreateStatus. + public: + static API& instance(const OrtApi* ort_api = nullptr) noexcept { + static API self(ort_api); + return self; + } + + static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept { + return instance()->CreateStatus(code, msg); + } + + static void ReleaseStatus(OrtStatusPtr ptr) noexcept { + instance()->ReleaseStatus(ptr); + } + + template + static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept; + + static void ThrowOnError(OrtStatusPtr ptr) { + OrtW::ThrowOnError(instance().api_, ptr); + } + + // Caller is responsible for releasing OrtMemoryInfo object + static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept { + return instance()->CreateMemoryInfo(name, type, id, mem_type, out); + } + + // Caller is responsible for releasing OrtAllocator object: delete static_cast (allocator) + static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) { + return instance()->KernelContext_GetAllocator(context, mem_info, out); + } + + private: + const OrtApi* operator->() const { + return &api_; + } + + API(const OrtApi* api) : api_(*api) { + if (api == nullptr) { + ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION); + } + } + + const OrtApi& api_; +}; + +template <> +inline OrtStatusPtr API::KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept { + return instance()->KernelInfoGetAttribute_int64(&info, name, &value); +} + +template <> +inline OrtStatusPtr API::KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, float& value) noexcept { + return instance()->KernelInfoGetAttribute_float(&info, name, &value); +} + +template <> +inline OrtStatusPtr API::KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, std::string& value) noexcept { + size_t size = 0; + std::string out; + // Feed nullptr for the data buffer to query the true size of the string attribute + OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size); + if (status == nullptr) { + out.resize(size); + status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size); + out.resize(size - 1); // remove the terminating character '\0' + } + + if (status == nullptr) { + value = std::move(out); + } + + return status; +} + +template +inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept { + if (auto status = API::KernelInfoGetAttribute(info, name, value); status) { + // Ideally, we should know which kind of error code can be ignored, but it is not available now. + // Just ignore all of them. + API::ReleaseStatus(status); + } + + return nullptr; +} + +template +inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept { + T ret; + if (API::KernelInfoGetAttribute(info, name, ret)) { + ret = default_value; + } + return ret; +} + +inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) { + return API::CreateStatus(code, msg); +} + +inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) { + return API::CreateStatus(code, msg.c_str()); +} + +inline void ReleaseStatus(OrtStatusPtr& status) { + API::ReleaseStatus(status); + status = nullptr; +} + +} // namespace OrtW + diff --git a/operators/contrib/cuda/cuda_type.h b/operators/contrib/cuda/cuda_type.h new file mode 100644 index 000000000..525f4febe --- /dev/null +++ b/operators/contrib/cuda/cuda_type.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "onnxruntime_f16.h" +namespace contrib { +template +struct CudaT { + using MappedType = T; +}; + +template <> +struct CudaT { + using MappedType = half; +}; +} diff --git a/operators/contrib/cuda/device_prop.cuh b/operators/contrib/cuda/device_prop.cuh new file mode 100644 index 000000000..9f89d4b2a --- /dev/null +++ b/operators/contrib/cuda/device_prop.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include + +using namespace Ort::Custom; + +struct DeviceProp { + static cudaDeviceProp& GetCudaDeviceProp() { + static DeviceProp device_prop; + return device_prop.prop_; + } + static int GetCapability() { + return GetCudaDeviceProp().major; + } + + private: + DeviceProp() { + auto err = cudaGetDeviceProperties(&prop_, 0); + if (err != cudaError::cudaSuccess) { + throw std::runtime_error((std::string{"Failed to get device property, err code: "} + std::to_string(err)).c_str()); + } + } + cudaDeviceProp prop_; +}; + diff --git a/operators/contrib/cuda/fast_gelu.h b/operators/contrib/cuda/fast_gelu.h index db3ab1ccc..462595a08 100644 --- a/operators/contrib/cuda/fast_gelu.h +++ b/operators/contrib/cuda/fast_gelu.h @@ -4,20 +4,10 @@ #pragma once #include "ocos.h" #include "fast_gelu_impl.cuh" -#include "cuda_fp16.h" +#include "cuda_type.h" namespace contrib { -template -struct CudaT { - using MappedType = T; -}; - -template <> -struct CudaT { - using MappedType = half; -}; - template struct FastGelu { OrtStatusPtr OnModelAttach(const OrtApi& /*api*/, diff --git a/operators/contrib/cuda/fast_gelu_impl.cu b/operators/contrib/cuda/fast_gelu_impl.cu index 4c879fff6..d670076cf 100644 --- a/operators/contrib/cuda/fast_gelu_impl.cu +++ b/operators/contrib/cuda/fast_gelu_impl.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "utils.cuh" +#include "device_prop.cuh" #include "fast_gelu_impl.cuh" using namespace Ort::Custom; diff --git a/operators/contrib/cuda/utils.cuh b/operators/contrib/cuda/utils.cuh index 9da6aca93..a40bf8f39 100644 --- a/operators/contrib/cuda/utils.cuh +++ b/operators/contrib/cuda/utils.cuh @@ -1,7 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#pragma once #include "onnxruntime_f16.h" +#include "string_utils.h" +#include "onnxruntime_no_customop.h" #include #include #include @@ -189,19 +192,3 @@ __device__ __inline__ half2 _Tanh(half2 a) { template <> __device__ __inline__ BFloat16 _Tanh(BFloat16 a) { return tanhf(static_cast(a)); } - -struct DeviceProp { - static int GetCapability() { - static DeviceProp device_prop; - return device_prop.prop_.major; - } - - private: - DeviceProp() { - auto err = cudaGetDeviceProperties(&prop_, 0); - if (err != cudaError::cudaSuccess) { - throw std::runtime_error((std::string{"Failed to get device property, err code: "} + std::to_string(err)).c_str()); - } - } - cudaDeviceProp prop_; -}; \ No newline at end of file diff --git a/operators/tokenizer/bpe_decoder.hpp b/operators/tokenizer/bpe_decoder.hpp index 32f6f29df..a3c59336b 100644 --- a/operators/tokenizer/bpe_decoder.hpp +++ b/operators/tokenizer/bpe_decoder.hpp @@ -6,6 +6,7 @@ #include "ocos.h" #include "ustring.h" #include "narrow.h" +#include "ortx_common.h" #include #include #include diff --git a/operators/tokenizer/bpe_kernels.cc b/operators/tokenizer/bpe_kernels.cc index fefe7a040..cfca75d3f 100644 --- a/operators/tokenizer/bpe_kernels.cc +++ b/operators/tokenizer/bpe_kernels.cc @@ -3,6 +3,7 @@ #include "bpe_tokenizer.hpp" #include "bpe_kernels.h" +#include "ortx_common.h" #include diff --git a/operators/tokenizer/sentencepiece_decoder.hpp b/operators/tokenizer/sentencepiece_decoder.hpp index b7e64873d..e4eae8a3a 100644 --- a/operators/tokenizer/sentencepiece_decoder.hpp +++ b/operators/tokenizer/sentencepiece_decoder.hpp @@ -4,6 +4,7 @@ #pragma once #include "ocos.h" +#include "ortx_common.h" #include "string_utils.h" #include "string_tensor.h" #include "sentencepiece_processor.h" diff --git a/operators/tokenizer/sentencepiece_tokenizer.cc b/operators/tokenizer/sentencepiece_tokenizer.cc index e0cbbe679..46c9c45a7 100644 --- a/operators/tokenizer/sentencepiece_tokenizer.cc +++ b/operators/tokenizer/sentencepiece_tokenizer.cc @@ -8,6 +8,7 @@ #include "string_tensor.h" #include "base64.h" #include "narrow.h" +#include "ortx_common.h" OrtStatusPtr KernelSentencepieceTokenizer::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { std::string model_as_string; diff --git a/operators/tokenizer/trie_tokenizer.hpp b/operators/tokenizer/trie_tokenizer.hpp index 21a8156ff..2ebf5ca4c 100644 --- a/operators/tokenizer/trie_tokenizer.hpp +++ b/operators/tokenizer/trie_tokenizer.hpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once #include "ocos.h" +#include "ortx_common.h" #include "narrow.h" #include From 05f5d252df7a19f0ad6953680dbf3edde073c998 Mon Sep 17 00:00:00 2001 From: jslhcl Date: Mon, 18 Mar 2024 11:30:14 -0700 Subject: [PATCH 2/3] fix typo and add #if ORT_API_VERSION >= 15 for GetOrtAllocator --- includes/custom_op_lite.h | 1 - includes/onnxruntime_no_customop.h | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h index e226f48ea..784e2b2bd 100644 --- a/includes/custom_op_lite.h +++ b/includes/custom_op_lite.h @@ -614,7 +614,6 @@ struct CudaContext { void* cuda_stream = {}; void* cublas = {}; int device_id = 0; - } }; #endif diff --git a/includes/onnxruntime_no_customop.h b/includes/onnxruntime_no_customop.h index a9477455a..008980be4 100644 --- a/includes/onnxruntime_no_customop.h +++ b/includes/onnxruntime_no_customop.h @@ -36,12 +36,12 @@ class API { static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept { return instance()->CreateMemoryInfo(name, type, id, mem_type, out); } - +#if ORT_API_VERSION >= 15 // Caller is responsible for releasing OrtAllocator object: delete static_cast (allocator) static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) { return instance()->KernelContext_GetAllocator(context, mem_info, out); } - +#endif private: const OrtApi* operator->() const { return &api_; From 463924907af94a5f2480389ba37d968dcf7ffa8c Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 18 Mar 2024 13:06:11 -0700 Subject: [PATCH 3/3] fix cuda build --- operators/contrib/cuda/device_prop.cuh | 4 ++-- operators/contrib/cuda/fast_gelu_impl.cu | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/operators/contrib/cuda/device_prop.cuh b/operators/contrib/cuda/device_prop.cuh index 9f89d4b2a..d33895b2d 100644 --- a/operators/contrib/cuda/device_prop.cuh +++ b/operators/contrib/cuda/device_prop.cuh @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once +#include +#include #include -using namespace Ort::Custom; - struct DeviceProp { static cudaDeviceProp& GetCudaDeviceProp() { static DeviceProp device_prop; diff --git a/operators/contrib/cuda/fast_gelu_impl.cu b/operators/contrib/cuda/fast_gelu_impl.cu index d670076cf..7d0096b95 100644 --- a/operators/contrib/cuda/fast_gelu_impl.cu +++ b/operators/contrib/cuda/fast_gelu_impl.cu @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "device_prop.cuh" +#include "utils.cuh" #include "fast_gelu_impl.cuh" using namespace Ort::Custom;