Skip to content

Commit

Permalink
Improve consistency. Update some comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
skottmckay committed Dec 23, 2024
1 parent dece8b8 commit 351d12d
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5304,7 +5304,7 @@ struct OrtModelBuilderApi {
*
* \since Version 1.21.
*/
ORT_API2_STATUS(ApplyModelToSession, _In_ OrtSession* session, _In_ OrtModel* model);
ORT_API2_STATUS(ApplyModelToModelBuilderSession, _In_ OrtSession* session, _In_ OrtModel* model);

/** \brief Finalize the Model Builder session.
*
Expand Down
2 changes: 1 addition & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1159,7 +1159,7 @@ inline void SessionImpl<T>::SetEpDynamicOptions(const char* const* keys, const c
template <typename T>
inline void SessionImpl<T>::FinalizeModelBuilderSession(const ModelBuilderAPI::Model& model, const SessionOptions& options,
OrtPrepackedWeightsContainer* prepacked_weights_container) {
ThrowOnError(GetModelBuilderApi().ApplyModelToSession(this->p_, model));
ThrowOnError(GetModelBuilderApi().ApplyModelToModelBuilderSession(this->p_, model));
ThrowOnError(GetModelBuilderApi().FinalizeModelBuilderSession(this->p_, options, prepacked_weights_container));
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/model_builder_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ ORT_API_STATUS_IMPL(CreateModelBuilderSessionFromArray, _In_ const OrtEnv* env,
_In_ const OrtSessionOptions* options,
_Outptr_ OrtSession** out);

ORT_API_STATUS_IMPL(ApplyModelToSession, _In_ OrtSession* session, _In_ OrtModel* model);
ORT_API_STATUS_IMPL(ApplyModelToModelBuilderSession, _In_ OrtSession* session, _In_ OrtModel* model);

ORT_API_STATUS_IMPL(FinalizeModelBuilderSession, _In_ OrtSession* session, _In_ const OrtSessionOptions* options,
_Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container);
Expand Down
9 changes: 5 additions & 4 deletions onnxruntime/core/session/model_builder_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateSessionFromModel, _In_ const OrtEn
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateModelBuilderSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
_In_ const OrtSessionOptions* options,
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateModelBuilderSession,
_In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options,
_Outptr_ OrtSession** out) {
API_IMPL_BEGIN
std::unique_ptr<onnxruntime::InferenceSession> session;
Expand Down Expand Up @@ -288,7 +288,8 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateModelBuilderSessionFromArray, _In_
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtModelBuilderAPI::ApplyModelToSession, _In_ OrtSession* session, _In_ OrtModel* model) {
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::ApplyModelToModelBuilderSession,
_In_ OrtSession* session, _In_ OrtModel* model) {
API_IMPL_BEGIN
auto sess = reinterpret_cast<onnxruntime::InferenceSession*>(session);
ORT_API_RETURN_IF_STATUS_NOT_OK(sess->ApplyUpdates(*model));
Expand Down Expand Up @@ -332,7 +333,7 @@ static constexpr OrtModelBuilderApi ort_graph_api = {

&OrtModelBuilderAPI::CreateModelBuilderSession,
&OrtModelBuilderAPI::CreateModelBuilderSessionFromArray,
&OrtModelBuilderAPI::ApplyModelToSession,
&OrtModelBuilderAPI::ApplyModelToModelBuilderSession,
&OrtModelBuilderAPI::FinalizeModelBuilderSession,
};

Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/test/shared_lib/test_model_builder_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
SessionOptions so;

// Set this to save the model if you want to debug.
so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_output.onnx"));
// so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx"));

Session session = Session::CreateModelBuilderSession(*ort_env, TSTR("testdata/mnist.onnx"), so);

Expand All @@ -351,7 +351,7 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
// the original graph is unchanged. nodes can be added before/after it. initializers can be added.
// new nodes must conform to the original domain:opset of the model.
// additional operator domain:opset pairs can be added.
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets;
std::vector<ModelBuilderAPI::Model::DomainOpsetPair> opsets; // no additional opsets required
ModelBuilderAPI::Model model(opsets);

std::vector<std::string> input_names = session.GetInputNames();
Expand Down

0 comments on commit 351d12d

Please sign in to comment.