diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 4201ef3c473f7..de55c5399c4cb 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -53,10 +53,10 @@ static NativeTrainingMethods() DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); // TODO: Make this save the pointer, and not copy the whole structure across - api_ = (OrtApi)OrtGetApi(4 /*ORT_API_VERSION*/); + api_ = (OrtApi)OrtGetApi(13 /*ORT_API_VERSION*/); OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi)); - trainingApiPtr = OrtGetTrainingApi(4 /*ORT_API_VERSION*/); + trainingApiPtr = OrtGetTrainingApi(13 /*ORT_API_VERSION*/); if (trainingApiPtr != IntPtr.Zero) { trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi)); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index f06b42cbcd6e4..37aa6a7cd87b0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -30,7 +30,7 @@ * * This value is used by some API functions to behave as this version of the header expects. */ -#define ORT_API_VERSION 14 +#define ORT_API_VERSION 15 #ifdef __cplusplus extern "C" { @@ -3544,6 +3544,13 @@ struct OrtApi { ORT_CLASS_RELEASE(KernelInfo); /* \brief: Get the training C Api + * + * \param[in] version Must be ::ORT_API_VERSION + * \return The ::OrtTrainingApi for the version requested. + * nullptr will be returned and no error message will be printed if the training api is not supported with + * this build. + * nullptr will be returned and an error message will be printed if the provided version is unsupported, for + * example when using a runtime older than the version created with this header file. * * \since Version 1.13 */ @@ -3890,7 +3897,7 @@ struct OrtApi { * * \snippet{doc} snippets.dox OrtStatus Return Value * - * \since Version 1.14. + * \since Version 1.15. */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, _In_ const OrtDnnlProviderOptions* dnnl_options); @@ -3901,7 +3908,7 @@ struct OrtApi { * * \snippet{doc} snippets.dox OrtStatus Return Value * - * \since Version 1.14. + * \since Version 1.15. */ ORT_API2_STATUS(CreateDnnlProviderOptions, _Outptr_ OrtDnnlProviderOptions** out); @@ -3919,7 +3926,7 @@ struct OrtApi { * * \snippet{doc} snippets.dox OrtStatus Return Value * - * \since Version 1.14. + * \since Version 1.15. */ ORT_API2_STATUS(UpdateDnnlProviderOptions, _Inout_ OrtDnnlProviderOptions* dnnl_options, _In_reads_(num_keys) const char* const* provider_options_keys, @@ -3938,13 +3945,13 @@ struct OrtApi { * * \snippet{doc} snippets.dox OrtStatus Return Value * - * \since Version 1.14. + * \since Version 1.15. */ ORT_API2_STATUS(GetDnnlProviderOptionsAsString, _In_ const OrtDnnlProviderOptions* dnnl_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); /** \brief Release an ::OrtDnnlProviderOptions * - * \since Version 1.14. + * \since Version 1.15. */ void(ORT_API_CALL* ReleaseDnnlProviderOptions)(_Frees_ptr_opt_ OrtDnnlProviderOptions* input); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 2085d675ec438..c8e1fad05a90b 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2351,13 +2351,15 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSes ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #ifdef ENABLE_TRAINING_APIS - return OrtTrainingApis::GetTrainingApi(version); + if (version >= 13 && version <= ORT_API_VERSION) + return OrtTrainingApis::GetTrainingApi(version); + + fprintf(stderr, "The given version [%u] is not supported. Training api only supports version 13 to %u.\n", + version, ORT_API_VERSION); + return nullptr; #else ORT_UNUSED_PARAMETER(version); - fprintf(stderr, - "Training APIs are not supported with this build. Please build onnxruntime " - "from source with the build flags enable_training_apis to retrieve the training APIs.\n"); return nullptr; #endif @@ -2407,7 +2409,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_14 = { +static constexpr OrtApi ort_api_1_to_15 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. // Shipped as version 1 - DO NOT MODIFY (see above text for more information) @@ -2731,7 +2733,7 @@ static_assert(std::string_view(ORT_VERSION) == "1.15.0", ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { if (version >= 1 && version <= ORT_API_VERSION) - return &ort_api_1_to_14; + return &ort_api_1_to_15; fprintf(stderr, "The given version [%u] is not supported, only version 1 to %u is supported in this build.\n", version, ORT_API_VERSION); diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 90fc3370e3c04..3debc856daec2 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -443,4 +443,3 @@ ORT_API_STATUS_IMPL(GetDnnlProviderOptionsAsString, _In_ const OrtDnnlProviderOp ORT_API(void, ReleaseDnnlProviderOptions, _Frees_ptr_opt_ OrtDnnlProviderOptions*); } // namespace OrtApis -