add load from buffer (#10162)
* Add LoadFromBuffer API
Jingqiao Fu authored Jan 10, 2022
1 parent edd1a2c commit 5cd57bb
Showing 5 changed files with 116 additions and 4 deletions.
28 changes: 24 additions & 4 deletions winml/api/Windows.AI.MachineLearning.idl
@@ -33,7 +33,7 @@ import "windows.storage.idl";

namespace ROOT_NS.AI.MachineLearning
{
[contractversion(5)]
[contractversion(6)]
apicontract MachineLearningContract{};

//! Forward declarations
@@ -104,22 +104,42 @@ namespace ROOT_NS.AI.MachineLearning
//! Loads an ONNX model from a stream asynchronously.
[remote_async]
static Windows.Foundation.IAsyncOperation<LearningModel> LoadFromStreamAsync(Windows.Storage.Streams.IRandomAccessStreamReference modelStream);
//! Loads an ONNX model from a buffer asynchronously.
[contract(MachineLearningContract, 6)]
{
[remote_async]
static Windows.Foundation.IAsyncOperation<LearningModel> LoadFromBufferAsync(Windows.Storage.Streams.IBuffer modelBuffer);
}
//! Loads an ONNX model from a file on disk.
static LearningModel LoadFromFilePath(String filePath);
//! Loads an ONNX model from a stream.
static LearningModel LoadFromStream(Windows.Storage.Streams.IRandomAccessStreamReference modelStream);

//! Loads an ONNX model from a buffer.
[contract(MachineLearningContract, 6)]
{
static LearningModel LoadFromBuffer(Windows.Storage.Streams.IBuffer modelBuffer);
}
//! Loads an ONNX model from a StorageFile asynchronously.
[remote_async]
[method_name("LoadFromStorageFileWithOperatorProviderAsync")] static Windows.Foundation.IAsyncOperation<LearningModel> LoadFromStorageFileAsync(Windows.Storage.IStorageFile modelFile, ILearningModelOperatorProvider operatorProvider);
//! Loads an ONNX model from a stream asynchronously.
[remote_async]
[method_name("LoadFromStreamWithOperatorProviderAsync")] static Windows.Foundation.IAsyncOperation<LearningModel> LoadFromStreamAsync(Windows.Storage.Streams.IRandomAccessStreamReference modelStream, ILearningModelOperatorProvider operatorProvider);
//! Loads an ONNX model from a file on disk.
//! Loads an ONNX model from a buffer asynchronously.
[contract(MachineLearningContract, 6)]
{
[remote_async]
[method_name("LoadFromBufferWithOperatorProviderAsync")] static Windows.Foundation.IAsyncOperation<LearningModel> LoadFromBufferAsync(Windows.Storage.Streams.IBuffer modelBuffer, ILearningModelOperatorProvider operatorProvider);
}
//! Loads an ONNX model from a file on disk.
[method_name("LoadFromFilePathWithOperatorProvider")] static LearningModel LoadFromFilePath(String filePath, ILearningModelOperatorProvider operatorProvider);
//! Loads an ONNX model from a stream.
[method_name("LoadFromStreamWithOperatorProvider")] static LearningModel LoadFromStream(Windows.Storage.Streams.IRandomAccessStreamReference modelStream, ILearningModelOperatorProvider operatorProvider);

//! Loads an ONNX model from a buffer.
[contract(MachineLearningContract, 6)]
{
[method_name("LoadFromBufferWithOperatorProvider")] static LearningModel LoadFromBuffer(Windows.Storage.Streams.IBuffer modelBuffer, ILearningModelOperatorProvider operatorProvider);
}
//! The name of the model author.
String Author{ get; };
//! The name of the model.
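For context, the buffer overloads gated on [contract(MachineLearningContract, 6)] above can be consumed from C++/WinRT roughly as sketched below. This is an illustrative caller only, not part of this commit; the helper name and includes are assumptions. The synchronous LoadFromBuffer and the operator-provider overloads follow the same pattern.

// Illustrative caller (not part of this commit): read an ONNX model file into
// an IBuffer and load it with the new contract-6 LoadFromBufferAsync overload.
#include <winrt/Windows.Foundation.h>
#include <winrt/Windows.Storage.h>
#include <winrt/Windows.Storage.Streams.h>
#include <winrt/Windows.AI.MachineLearning.h>

using namespace winrt;
using namespace winrt::Windows::Foundation;
using namespace winrt::Windows::Storage;
using namespace winrt::Windows::Storage::Streams;
using namespace winrt::Windows::AI::MachineLearning;

IAsyncOperation<LearningModel> LoadModelFromFileBytesAsync(StorageFile file) {
    // Read the whole model file into memory.
    IBuffer buffer = co_await FileIO::ReadBufferAsync(file);
    // Load directly from the buffer; no IRandomAccessStreamReference is needed.
    co_return co_await LearningModel::LoadFromBufferAsync(buffer);
}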
53 changes: 53 additions & 0 deletions winml/lib/Api/LearningModel.cpp
@@ -165,6 +165,30 @@ LearningModel::LearningModel(
}
WINML_CATCH_ALL

static HRESULT CreateModelFromBuffer(
_winml::IEngineFactory* engine_factory,
const wss::IBuffer buffer,
_winml::IModel** model) {

size_t len = buffer.Length();
if (FAILED(engine_factory->CreateModel((void*)buffer.data(), len, model))) {
WINML_THROW_HR(E_INVALIDARG);
}

return S_OK;
}

LearningModel::LearningModel(
const wss::IBuffer buffer,
const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) {
_winmlt::TelemetryEvent loadModel_event(_winmlt::EventCategory::kModelLoad);

WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put()));
WINML_THROW_IF_FAILED(CreateModelFromBuffer(engine_factory_.get(), buffer, model_.put()));
WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put()));
}
WINML_CATCH_ALL

hstring
LearningModel::Author() try {
const char* out;
@@ -293,6 +317,20 @@ LearningModel::LoadFromStreamAsync(
return make<LearningModel>(model_stream, provider);
}

wf::IAsyncOperation<winml::LearningModel>
LearningModel::LoadFromBufferAsync(
wss::IBuffer const model_buffer) {
return LoadFromBufferAsync(model_buffer, nullptr);
}

wf::IAsyncOperation<winml::LearningModel>
LearningModel::LoadFromBufferAsync(
wss::IBuffer const model_buffer,
winml::ILearningModelOperatorProvider const provider) {
co_await resume_background();
return make<LearningModel>(model_buffer, provider);
}

winml::LearningModel
LearningModel::LoadFromFilePath(
hstring const& path) try {
@@ -323,6 +361,21 @@ LearningModel::LoadFromStream(
}
WINML_CATCH_ALL

winml::LearningModel
LearningModel::LoadFromBuffer(
wss::IBuffer const model_buffer) try {
return LoadFromBuffer(model_buffer, nullptr);
}
WINML_CATCH_ALL

winml::LearningModel
LearningModel::LoadFromBuffer(
wss::IBuffer const model_buffer,
winml::ILearningModelOperatorProvider const provider) try {
return make<LearningModel>(model_buffer, provider);
}
WINML_CATCH_ALL

_winml::IModel*
LearningModel::DetachModel() {
com_ptr<_winml::IModel> detached_model;
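The constructors above route through CreateModelFromBuffer, which passes buffer.data() and buffer.Length() straight to the engine factory, so any IBuffer that already holds the ONNX bytes can be loaded without touching the file system. A minimal sketch, assuming the bytes sit in a std::vector<uint8_t> (helper name hypothetical, not part of this commit):

// Hypothetical sketch: wrap in-memory ONNX bytes in a WinRT Buffer and load
// them synchronously with the new LoadFromBuffer overload.
#include <winrt/Windows.Storage.Streams.h>
#include <winrt/Windows.AI.MachineLearning.h>
#include <cstdint>
#include <cstring>
#include <vector>

using winrt::Windows::Storage::Streams::Buffer;
using winrt::Windows::AI::MachineLearning::LearningModel;

LearningModel LoadModelFromMemory(const std::vector<uint8_t>& onnx_bytes) {
    // Allocate a WinRT buffer of the right capacity and copy the bytes in.
    Buffer buffer(static_cast<uint32_t>(onnx_bytes.size()));
    std::memcpy(buffer.data(), onnx_bytes.data(), onnx_bytes.size());
    buffer.Length(static_cast<uint32_t>(onnx_bytes.size()));
    // Buffer implicitly converts to IBuffer, which is what LoadFromBuffer takes.
    return LearningModel::LoadFromBuffer(buffer);
}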
22 changes: 22 additions & 0 deletions winml/lib/Api/LearningModel.h
@@ -24,6 +24,10 @@ struct LearningModel : LearningModelT<LearningModel> {
LearningModel(
const wss::IRandomAccessStreamReference stream,
const winml::ILearningModelOperatorProvider operator_provider);

LearningModel(
const wss::IBuffer stream,
const winml::ILearningModelOperatorProvider operator_provider);

LearningModel(
_winml::IEngineFactory* engine_factory,
@@ -77,6 +81,15 @@ struct LearningModel : LearningModelT<LearningModel> {
wss::IRandomAccessStreamReference const stream,
winml::ILearningModelOperatorProvider const operator_provider);

static wf::IAsyncOperation<winml::LearningModel>
LoadFromBufferAsync(
wss::IBuffer const buffer);

static wf::IAsyncOperation<winml::LearningModel>
LoadFromBufferAsync(
wss::IBuffer const buffer,
winml::ILearningModelOperatorProvider const operator_provider);

static winml::LearningModel
LoadFromFilePath(
hstring const& path);
@@ -89,11 +102,20 @@ struct LearningModel : LearningModelT<LearningModel> {
static winml::LearningModel
LoadFromStream(
wss::IRandomAccessStreamReference const stream);

static winml::LearningModel
LoadFromBuffer(
wss::IBuffer const buffer);

static winml::LearningModel
LoadFromStream(
wss::IRandomAccessStreamReference const stream,
winml::ILearningModelOperatorProvider const operator_provider);

static winml::LearningModel
LoadFromBuffer(
wss::IBuffer const buffer,
winml::ILearningModelOperatorProvider const operator_provider);

public:
/* Non-ABI methods */
15 changes: 15 additions & 0 deletions winml/test/api/LearningModelAPITest.cpp
@@ -90,6 +90,20 @@ static void CreateModelFromIStream() {
WINML_EXPECT_EQUAL(L"onnx-caffe2", author);
}

static void CreateModelFromIBuffer() {
std::wstring path = FileHelpers::GetModulePath() + L"squeezenet_modifiedforruntimestests.onnx";
auto storageFile = ws::StorageFile::GetFileFromPathAsync(path).get();
IBuffer buffer = FileIO::ReadBufferAsync(storageFile).get();

LearningModel learningModel = nullptr;
WINML_EXPECT_NO_THROW(learningModel = LearningModel::LoadFromBufferAsync(buffer).get());
WINML_EXPECT_TRUE(learningModel != nullptr);

// check the author so we know the model was populated correctly.
std::wstring author(learningModel.Author());
WINML_EXPECT_EQUAL(L"onnx-caffe2", author);
}

static void ModelGetAuthor() {
LearningModel learningModel = nullptr;
WINML_EXPECT_NO_THROW(APITest::LoadModel(L"squeezenet_modifiedforruntimestests.onnx", learningModel));
@@ -323,6 +337,7 @@ const LearningModelApiTestsApi& getapi() {
CreateModelFromIStorage,
CreateModelFromIStorageOutsideCwd,
CreateModelFromIStream,
CreateModelFromIBuffer,
ModelGetAuthor,
ModelGetName,
ModelGetDomain,
2 changes: 2 additions & 0 deletions winml/test/api/LearningModelAPITest.h
@@ -11,6 +11,7 @@ struct LearningModelApiTestsApi
VoidTest CreateModelFromIStorage;
VoidTest CreateModelFromIStorageOutsideCwd;
VoidTest CreateModelFromIStream;
VoidTest CreateModelFromIBuffer;
VoidTest ModelGetAuthor;
VoidTest ModelGetName;
VoidTest ModelGetDomain;
@@ -36,6 +37,7 @@ WINML_TEST(LearningModelAPITests, CreateModelFileNotFound)
WINML_TEST(LearningModelAPITests, CreateModelFromIStorage)
WINML_TEST(LearningModelAPITests, CreateModelFromIStorageOutsideCwd)
WINML_TEST(LearningModelAPITests, CreateModelFromIStream)
WINML_TEST(LearningModelAPITests, CreateModelFromIBuffer)
WINML_TEST(LearningModelAPITests, ModelGetAuthor)
WINML_TEST(LearningModelAPITests, ModelGetName)
WINML_TEST(LearningModelAPITests, ModelGetDomain)