diff --git a/winml/api/Windows.AI.MachineLearning.idl b/winml/api/Windows.AI.MachineLearning.idl index 2b55fa8c7a95c..2ca2145aacda6 100644 --- a/winml/api/Windows.AI.MachineLearning.idl +++ b/winml/api/Windows.AI.MachineLearning.idl @@ -33,7 +33,7 @@ import "windows.storage.idl"; namespace ROOT_NS.AI.MachineLearning { - [contractversion(5)] + [contractversion(6)] apicontract MachineLearningContract{}; //! Forward declarations @@ -104,22 +104,42 @@ namespace ROOT_NS.AI.MachineLearning //! Loads an ONNX model from a stream asynchronously. [remote_async] static Windows.Foundation.IAsyncOperation LoadFromStreamAsync(Windows.Storage.Streams.IRandomAccessStreamReference modelStream); + //! Loads an ONNX model from a buffer asynchronously. + [contract(MachineLearningContract, 6)] + { + [remote_async] + static Windows.Foundation.IAsyncOperation LoadFromBufferAsync(Windows.Storage.Streams.IBuffer modelBuffer); + } //! Loads an ONNX model from a file on disk. static LearningModel LoadFromFilePath(String filePath); //! Loads an ONNX model from a stream. static LearningModel LoadFromStream(Windows.Storage.Streams.IRandomAccessStreamReference modelStream); - + //! Loads an ONNX model from a buffer. + [contract(MachineLearningContract, 6)] + { + static LearningModel LoadFromBuffer(Windows.Storage.Streams.IBuffer modelBuffer); + } //! Loads an ONNX model from a StorageFile asynchronously. [remote_async] [method_name("LoadFromStorageFileWithOperatorProviderAsync")] static Windows.Foundation.IAsyncOperation LoadFromStorageFileAsync(Windows.Storage.IStorageFile modelFile, ILearningModelOperatorProvider operatorProvider); //! Loads an ONNX model from a stream asynchronously. [remote_async] [method_name("LoadFromStreamWithOperatorProviderAsync")] static Windows.Foundation.IAsyncOperation LoadFromStreamAsync(Windows.Storage.Streams.IRandomAccessStreamReference modelStream, ILearningModelOperatorProvider operatorProvider); - //! Loads an ONNX model from a file on disk. + //! Loads an ONNX model from a buffer asynchronously. + [contract(MachineLearningContract, 6)] + { + [remote_async] + [method_name("LoadFromBufferWithOperatorProviderAsync")] static Windows.Foundation.IAsyncOperation LoadFromBufferAsync(Windows.Storage.Streams.IBuffer modelBuffer, ILearningModelOperatorProvider operatorProvider); + } + //! Loads an ONNX model from a file on disk. [method_name("LoadFromFilePathWithOperatorProvider")] static LearningModel LoadFromFilePath(String filePath, ILearningModelOperatorProvider operatorProvider); //! Loads an ONNX model from a stream. [method_name("LoadFromStreamWithOperatorProvider")] static LearningModel LoadFromStream(Windows.Storage.Streams.IRandomAccessStreamReference modelStream, ILearningModelOperatorProvider operatorProvider); - + //! Loads an ONNX model from a buffer. + [contract(MachineLearningContract, 6)] + { + [method_name("LoadFromBufferWithOperatorProvider")] static LearningModel LoadFromBuffer(Windows.Storage.Streams.IBuffer modelBuffer, ILearningModelOperatorProvider operatorProvider); + } //! The name of the model author. String Author{ get; }; //! The name of the model. diff --git a/winml/lib/Api/LearningModel.cpp b/winml/lib/Api/LearningModel.cpp index b0b988adf8a0e..787c2afd62b68 100644 --- a/winml/lib/Api/LearningModel.cpp +++ b/winml/lib/Api/LearningModel.cpp @@ -165,6 +165,30 @@ LearningModel::LearningModel( } WINML_CATCH_ALL +static HRESULT CreateModelFromBuffer( + _winml::IEngineFactory* engine_factory, + const wss::IBuffer buffer, + _winml::IModel** model) { + + size_t len = buffer.Length(); + if (FAILED(engine_factory->CreateModel((void*)buffer.data(), len, model))) { + WINML_THROW_HR(E_INVALIDARG); + } + + return S_OK; +} + +LearningModel::LearningModel( + const wss::IBuffer buffer, + const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { + _winmlt::TelemetryEvent loadModel_event(_winmlt::EventCategory::kModelLoad); + + WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put())); + WINML_THROW_IF_FAILED(CreateModelFromBuffer(engine_factory_.get(), buffer, model_.put())); + WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put())); +} +WINML_CATCH_ALL + hstring LearningModel::Author() try { const char* out; @@ -293,6 +317,20 @@ LearningModel::LoadFromStreamAsync( return make(model_stream, provider); } +wf::IAsyncOperation +LearningModel::LoadFromBufferAsync( + wss::IBuffer const model_buffer) { + return LoadFromBufferAsync(model_buffer, nullptr); +} + +wf::IAsyncOperation +LearningModel::LoadFromBufferAsync( + wss::IBuffer const model_buffer, + winml::ILearningModelOperatorProvider const provider) { + co_await resume_background(); + return make(model_buffer, provider); +} + winml::LearningModel LearningModel::LoadFromFilePath( hstring const& path) try { @@ -323,6 +361,21 @@ LearningModel::LoadFromStream( } WINML_CATCH_ALL +winml::LearningModel +LearningModel::LoadFromBuffer( + wss::IBuffer const model_buffer) try { + return LoadFromBuffer(model_buffer, nullptr); +} +WINML_CATCH_ALL + +winml::LearningModel +LearningModel::LoadFromBuffer( + wss::IBuffer const model_buffer, + winml::ILearningModelOperatorProvider const provider) try { + return make(model_buffer, provider); +} +WINML_CATCH_ALL + _winml::IModel* LearningModel::DetachModel() { com_ptr<_winml::IModel> detached_model; diff --git a/winml/lib/Api/LearningModel.h b/winml/lib/Api/LearningModel.h index 5510521e84e1c..0a7975a144709 100644 --- a/winml/lib/Api/LearningModel.h +++ b/winml/lib/Api/LearningModel.h @@ -24,6 +24,10 @@ struct LearningModel : LearningModelT { LearningModel( const wss::IRandomAccessStreamReference stream, const winml::ILearningModelOperatorProvider operator_provider); + + LearningModel( + const wss::IBuffer stream, + const winml::ILearningModelOperatorProvider operator_provider); LearningModel( _winml::IEngineFactory* engine_factory, @@ -77,6 +81,15 @@ struct LearningModel : LearningModelT { wss::IRandomAccessStreamReference const stream, winml::ILearningModelOperatorProvider const operator_provider); + static wf::IAsyncOperation + LoadFromBufferAsync( + wss::IBuffer const buffer); + + static wf::IAsyncOperation + LoadFromBufferAsync( + wss::IBuffer const buffer, + winml::ILearningModelOperatorProvider const operator_provider); + static winml::LearningModel LoadFromFilePath( hstring const& path); @@ -89,11 +102,20 @@ struct LearningModel : LearningModelT { static winml::LearningModel LoadFromStream( wss::IRandomAccessStreamReference const stream); + + static winml::LearningModel + LoadFromBuffer( + wss::IBuffer const buffer); static winml::LearningModel LoadFromStream( wss::IRandomAccessStreamReference const stream, winml::ILearningModelOperatorProvider const operator_provider); + + static winml::LearningModel + LoadFromBuffer( + wss::IBuffer const buffer, + winml::ILearningModelOperatorProvider const operator_provider); public: /* Non-ABI methods */ diff --git a/winml/test/api/LearningModelAPITest.cpp b/winml/test/api/LearningModelAPITest.cpp index 447b98137bb91..1354baec83b4b 100644 --- a/winml/test/api/LearningModelAPITest.cpp +++ b/winml/test/api/LearningModelAPITest.cpp @@ -90,6 +90,20 @@ static void CreateModelFromIStream() { WINML_EXPECT_EQUAL(L"onnx-caffe2", author); } +static void CreateModelFromIBuffer() { + std::wstring path = FileHelpers::GetModulePath() + L"squeezenet_modifiedforruntimestests.onnx"; + auto storageFile = ws::StorageFile::GetFileFromPathAsync(path).get(); + IBuffer buffer = FileIO::ReadBufferAsync(storageFile).get(); + + LearningModel learningModel = nullptr; + WINML_EXPECT_NO_THROW(learningModel = LearningModel::LoadFromBufferAsync(buffer).get()); + WINML_EXPECT_TRUE(learningModel != nullptr); + + // check the author so we know the model was populated correctly. + std::wstring author(learningModel.Author()); + WINML_EXPECT_EQUAL(L"onnx-caffe2", author); +} + static void ModelGetAuthor() { LearningModel learningModel = nullptr; WINML_EXPECT_NO_THROW(APITest::LoadModel(L"squeezenet_modifiedforruntimestests.onnx", learningModel)); @@ -323,6 +337,7 @@ const LearningModelApiTestsApi& getapi() { CreateModelFromIStorage, CreateModelFromIStorageOutsideCwd, CreateModelFromIStream, + CreateModelFromIBuffer, ModelGetAuthor, ModelGetName, ModelGetDomain, diff --git a/winml/test/api/LearningModelAPITest.h b/winml/test/api/LearningModelAPITest.h index c87bdd6144511..df23461908dbe 100644 --- a/winml/test/api/LearningModelAPITest.h +++ b/winml/test/api/LearningModelAPITest.h @@ -11,6 +11,7 @@ struct LearningModelApiTestsApi VoidTest CreateModelFromIStorage; VoidTest CreateModelFromIStorageOutsideCwd; VoidTest CreateModelFromIStream; + VoidTest CreateModelFromIBuffer; VoidTest ModelGetAuthor; VoidTest ModelGetName; VoidTest ModelGetDomain; @@ -36,6 +37,7 @@ WINML_TEST(LearningModelAPITests, CreateModelFileNotFound) WINML_TEST(LearningModelAPITests, CreateModelFromIStorage) WINML_TEST(LearningModelAPITests, CreateModelFromIStorageOutsideCwd) WINML_TEST(LearningModelAPITests, CreateModelFromIStream) +WINML_TEST(LearningModelAPITests, CreateModelFromIBuffer) WINML_TEST(LearningModelAPITests, ModelGetAuthor) WINML_TEST(LearningModelAPITests, ModelGetName) WINML_TEST(LearningModelAPITests, ModelGetDomain)