Skip to content

Commit

Permalink
GetTrainingApi to not print to stderr when not an ort training build (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Feb 2, 2023
1 parent 68a402e commit 3d8fa4d
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ static NativeTrainingMethods()
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi));

// TODO: Make this save the pointer, and not copy the whole structure across
api_ = (OrtApi)OrtGetApi(4 /*ORT_API_VERSION*/);
api_ = (OrtApi)OrtGetApi(13 /*ORT_API_VERSION*/);

OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi));
trainingApiPtr = OrtGetTrainingApi(4 /*ORT_API_VERSION*/);
trainingApiPtr = OrtGetTrainingApi(13 /*ORT_API_VERSION*/);
if (trainingApiPtr != IntPtr.Zero)
{
trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi));
Expand Down
19 changes: 13 additions & 6 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
*
* This value is used by some API functions to behave as this version of the header expects.
*/
#define ORT_API_VERSION 14
#define ORT_API_VERSION 15

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -3544,6 +3544,13 @@ struct OrtApi {
ORT_CLASS_RELEASE(KernelInfo);

/* \brief: Get the training C Api
*
* \param[in] version Must be ::ORT_API_VERSION
* \return The ::OrtTrainingApi for the version requested.
* nullptr will be returned and no error message will be printed if the training api is not supported with
* this build.
* nullptr will be returned and an error message will be printed if the provided version is unsupported, for
* example when using a runtime older than the version created with this header file.
*
* \since Version 1.13
*/
Expand Down Expand Up @@ -3890,7 +3897,7 @@ struct OrtApi {
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.14.
* \since Version 1.15.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_Dnnl,
_In_ OrtSessionOptions* options, _In_ const OrtDnnlProviderOptions* dnnl_options);
Expand All @@ -3901,7 +3908,7 @@ struct OrtApi {
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.14.
* \since Version 1.15.
*/
ORT_API2_STATUS(CreateDnnlProviderOptions, _Outptr_ OrtDnnlProviderOptions** out);

Expand All @@ -3919,7 +3926,7 @@ struct OrtApi {
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.14.
* \since Version 1.15.
*/
ORT_API2_STATUS(UpdateDnnlProviderOptions, _Inout_ OrtDnnlProviderOptions* dnnl_options,
_In_reads_(num_keys) const char* const* provider_options_keys,
Expand All @@ -3938,13 +3945,13 @@ struct OrtApi {
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.14.
* \since Version 1.15.
*/
ORT_API2_STATUS(GetDnnlProviderOptionsAsString, _In_ const OrtDnnlProviderOptions* dnnl_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr);

/** \brief Release an ::OrtDnnlProviderOptions
*
* \since Version 1.14.
* \since Version 1.15.
*/
void(ORT_API_CALL* ReleaseDnnlProviderOptions)(_Frees_ptr_opt_ OrtDnnlProviderOptions* input);

Expand Down
14 changes: 8 additions & 6 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2351,13 +2351,15 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSes

ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) {
#ifdef ENABLE_TRAINING_APIS
return OrtTrainingApis::GetTrainingApi(version);
if (version >= 13 && version <= ORT_API_VERSION)
return OrtTrainingApis::GetTrainingApi(version);

fprintf(stderr, "The given version [%u] is not supported. Training api only supports version 13 to %u.\n",
version, ORT_API_VERSION);
return nullptr;
#else

ORT_UNUSED_PARAMETER(version);
fprintf(stderr,
"Training APIs are not supported with this build. Please build onnxruntime "
"from source with the build flags enable_training_apis to retrieve the training APIs.\n");

return nullptr;
#endif
Expand Down Expand Up @@ -2407,7 +2409,7 @@ Second example, if we wanted to add and remove some members, we'd do this:
In GetApi we now make it return ort_api_3 for version 3.
*/

static constexpr OrtApi ort_api_1_to_14 = {
static constexpr OrtApi ort_api_1_to_15 = {
// NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering.

// Shipped as version 1 - DO NOT MODIFY (see above text for more information)
Expand Down Expand Up @@ -2731,7 +2733,7 @@ static_assert(std::string_view(ORT_VERSION) == "1.15.0",

ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
if (version >= 1 && version <= ORT_API_VERSION)
return &ort_api_1_to_14;
return &ort_api_1_to_15;

fprintf(stderr, "The given version [%u] is not supported, only version 1 to %u is supported in this build.\n",
version, ORT_API_VERSION);
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -443,4 +443,3 @@ ORT_API_STATUS_IMPL(GetDnnlProviderOptionsAsString, _In_ const OrtDnnlProviderOp
ORT_API(void, ReleaseDnnlProviderOptions, _Frees_ptr_opt_ OrtDnnlProviderOptions*);

} // namespace OrtApis

0 comments on commit 3d8fa4d

Please sign in to comment.