From 53b0116eea819d2e97b02a8bbcab56777d0883d6 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 17 Jan 2023 12:18:57 +0000 Subject: [PATCH 01/18] Add TuningResultsValidator --- onnxruntime/core/framework/tuning_context.h | 29 ++++ .../core/framework/tuning_context_impl.h | 142 +++++++++++++++++- .../cuda/tunable/cuda_tuning_context.cc | 38 ++++- .../cuda/tunable/cuda_tuning_context.h | 15 ++ .../rocm/tunable/rocm_tuning_context.cc | 64 +++++++- .../rocm/tunable/rocm_tuning_context.h | 17 +++ onnxruntime/test/framework/tunable_op_test.cc | 13 ++ 7 files changed, 315 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h index 77c15d65b58ca..a9d8e4faa32e7 100644 --- a/onnxruntime/core/framework/tuning_context.h +++ b/onnxruntime/core/framework/tuning_context.h @@ -25,6 +25,8 @@ class ITuningContext { virtual TuningResultsManager& GetTuningResultsManager() = 0; virtual const TuningResultsManager& GetTuningResultsManager() const = 0; + + virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0; }; class TuningResultsManager { @@ -50,4 +52,31 @@ class TuningResultsManager { std::unordered_map results_; }; +class TuningResultsValidator { + public: + using GetFunc = std::function; + using ValidateFunc = std::function; + using GetValidateFuncs = std::unordered_map>; + + TuningResultsValidator(); + + std::unordered_map GetAllValidators() const; + Status ValidateAll(const std::unordered_map& to_validate) const; + + protected: + void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf); + + virtual std::string GetOrtVersion() const; + virtual Status ValidateOrtVersion(const std::string& value) const; + + virtual std::string GetOrtGitCommit() const; + virtual Status ValidateOrtGitCommit(const std::string& value) const; + + virtual std::string GetOrtBuildConfig() const; + virtual Status ValidateOrtBuildConfig(const std::string& value) const; + + 
private: + GetValidateFuncs validators_; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tuning_context_impl.h b/onnxruntime/core/framework/tuning_context_impl.h index 50aeee1c6707b..3e6826943c55f 100644 --- a/onnxruntime/core/framework/tuning_context_impl.h +++ b/onnxruntime/core/framework/tuning_context_impl.h @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// This file contains the implementation of TuningResultsManager. At the moment, there is no necessity to expose these +// This file contains the implementation of TuningContext. At the moment, there is no necessity to expose these // methods as OrtApis. This will cause missing symbols when loading provider dynamic libraries, because the libraries // are not whole-archive linked and these symbols are not referenced at framework level. To circumvent this problem, // the EP must has and only has one translation unit include this file. @@ -11,6 +11,10 @@ #pragma once +#include +#include +#include + #include "core/framework/tunable.h" #include "core/framework/tuning_context.h" #include "core/framework/tuning_results.h" @@ -106,4 +110,140 @@ void TuningResultsManager::Clear() { results_ = {}; } +Status CheckMandatoryKeys( + const TuningResultsValidator::GetValidateFuncs& gv_funcs, + const std::unordered_map& to_check) { + constexpr const std::array mandatory_keys{"ORT_VERSION", "ORT_GIT_COMMIT", "ORT_BUILD_CONFIG"}; + + bool passed = true; + std::ostringstream oss; + for (const auto& k : mandatory_keys) { + if (gv_funcs.find(k) == gv_funcs.end()) { + passed = false; + oss << "key=\"" << k << "\" is not registered for Get and Validate. "; + } + + if (to_check.find(k) == to_check.end()) { + passed = false; + oss << "key=\"" << k << "\" is not provided for validation. 
"; + } + } + ORT_RETURN_IF(!passed, oss.str()); + return Status::OK(); +} + +Status CheckKeysMatching( + const TuningResultsValidator::GetValidateFuncs& gv_funcs, + const std::unordered_map& to_check) { + auto get_keys = [](const auto& it) -> std::string { return it.first; }; + std::vector required_keys; + std::vector provided_keys; + std::transform(gv_funcs.cbegin(), gv_funcs.cend(), std::back_inserter(required_keys), get_keys); + std::transform(to_check.cbegin(), to_check.cend(), std::back_inserter(provided_keys), get_keys); + std::sort(required_keys.begin(), required_keys.end()); + std::sort(provided_keys.begin(), provided_keys.end()); + + std::unordered_set intersection; + std::set_intersection(required_keys.cbegin(), required_keys.cend(), + provided_keys.cbegin(), provided_keys.cend(), + std::inserter(intersection, intersection.end())); + bool matched = true; + std::ostringstream oss; + if (intersection.size() != required_keys.size()) { + matched = false; + for (const auto& k : required_keys) { + if (intersection.find(k) == intersection.end()) { + oss << "Unmatched validator: \"" << k << "\" is required, but the tuning results does not provide it. "; + } + } + } + if (intersection.size() != provided_keys.size()) { + matched = false; + for (const auto& k : provided_keys) { + if (intersection.find(k) == intersection.end()) { + oss << "Unmatched validator: \"" << k << "\" is provided, but onnxruntime is unable to consume it. 
"; + } + } + } + ORT_RETURN_IF(!matched, oss.str()); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtVersion() const { + return ORT_VERSION; +} + +Status TuningResultsValidator::ValidateOrtVersion(const std::string& value) const { + ORT_RETURN_IF(value != ORT_VERSION, "onnxruntime version mismatch"); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtGitCommit() const { + // TODO: + return ""; +} + +Status TuningResultsValidator::ValidateOrtGitCommit(const std::string& value) const { + // TODO: + ORT_UNUSED_PARAMETER(value); + return Status::OK(); +} + +std::string TuningResultsValidator::GetOrtBuildConfig() const { + return ""; +} + +Status TuningResultsValidator::ValidateOrtBuildConfig(const std::string& value) const { + auto current = GetOrtBuildConfig(); + ORT_RETURN_IF(current != value, + "onnxruntime building configuration mismatch: tuning results produced with library \"", + value, "\", current library built with \"", current, "\""); + return Status::OK(); +} + +TuningResultsValidator::TuningResultsValidator() { + RegisterValidator( + "ORT_VERSION", + [this]() { return GetOrtVersion(); }, + [this](auto&& k) { return ValidateOrtVersion(std::forward(k)); }); + + RegisterValidator( + "ORT_GIT_COMMIT", + [this]() { return GetOrtGitCommit(); }, + [this](auto&& k) { return ValidateOrtGitCommit(std::forward(k)); }); + + RegisterValidator( + "ORT_BUILD_CONFIG", + [this]() { return GetOrtBuildConfig(); }, + [this](auto&& k) { return ValidateOrtBuildConfig(std::forward(k)); }); +} + +Status TuningResultsValidator::ValidateAll(const std::unordered_map& to_validate) const { + ORT_RETURN_IF_ERROR(CheckMandatoryKeys(validators_, to_validate)); + ORT_RETURN_IF_ERROR(CheckKeysMatching(validators_, to_validate)); + + for (const auto& [key, value] : to_validate) { + const auto& it = validators_.find(key); + ORT_ENFORCE(it != validators_.cend()); + const ValidateFunc& validator = it->second.second; + 
ORT_RETURN_IF_ERROR(validator(value)); + } + + return Status::OK(); +} + +std::unordered_map TuningResultsValidator::GetAllValidators() const { + std::unordered_map ret; + for (const auto& [key, get_validate_func_pair] : validators_) { + const GetFunc& getter = get_validate_func_pair.first; + ret[key] = getter(); + } + return ret; +} + +void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) { + ORT_ENFORCE(validators_.find(key) == validators_.end()); + validators_[key] = std::make_pair(gf, vf); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc index 55c79273b1ced..aca418f2d7f8c 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc @@ -14,7 +14,39 @@ namespace onnxruntime { namespace cuda { namespace tunable { -CudaTuningContext::CudaTuningContext(CUDAExecutionProvider*, TunableOpInfo* info) : info_(info) {} +std::string GetCudaVersion() { + int version; + CUDA_CALL_THROW(cudaRuntimeGetVersion(&version)); + return std::to_string(version); +} + +Status ValidateCudaVersion(const std::string& value) { + auto current = GetCudaVersion(); + ORT_RETURN_IF(current != value, "CUDA runtime version mismatch: tuning results produced with CUDA ", value, + ", onnxruntime currently run with CUDA ", current); + return Status::OK(); +} + +std::string CudaTuningResultsValidator::GetDeviceModel() const { + return ep_->GetDeviceProp().name; +} + +Status CudaTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { + auto current = GetDeviceModel(); + ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, + ", onnxruntime currently run with device ", current); + return Status::OK(); +} + 
+CudaTuningResultsValidator::CudaTuningResultsValidator(CUDAExecutionProvider* ep) : ep_(ep) { + RegisterValidator("CUDA_VERSION", GetCudaVersion, ValidateCudaVersion); + RegisterValidator( + "DEVICE_MODEL", + [this]() { return GetDeviceModel(); }, + [this](const std::string& value) { return ValidateDeviceModel(value); }); +} + +CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info) : info_(info), validator_(ep) {} void CudaTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider"; @@ -38,6 +70,10 @@ const TuningResultsManager& CudaTuningContext::GetTuningResultsManager() const { return manager_; } +const TuningResultsValidator& CudaTuningContext::GetTuningResultsValidator() const { + return validator_; +} + } // namespace tunable } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h index df47c53c7a4b6..10d0782f5b10c 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h @@ -15,6 +15,18 @@ class CUDAExecutionProvider; namespace cuda { namespace tunable { +class CudaTuningResultsValidator : public TuningResultsValidator { + public: + CudaTuningResultsValidator(CUDAExecutionProvider* ep); + + protected: + std::string GetDeviceModel() const; + Status ValidateDeviceModel(const std::string& value) const; + + private: + CUDAExecutionProvider* ep_; // non-owning handle +}; + class CudaTuningContext : public ITuningContext { public: explicit CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info); @@ -26,9 +38,12 @@ class CudaTuningContext : public ITuningContext { TuningResultsManager& GetTuningResultsManager() override; const TuningResultsManager& GetTuningResultsManager() const override; + const TuningResultsValidator& GetTuningResultsValidator() const override; + 
private: TunableOpInfo* info_; // non-owning handle TuningResultsManager manager_; + CudaTuningResultsValidator validator_; }; } // namespace tunable diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc index c2888d76643f7..7b7c855c00715 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc @@ -14,7 +14,65 @@ namespace onnxruntime { namespace rocm { namespace tunable { -RocmTuningContext::RocmTuningContext(ROCMExecutionProvider*, TunableOpInfo* info) : info_(info) {} +std::string GetHipVersion() { + int version; + HIP_CALL_THROW(hipRuntimeGetVersion(&version)); + return std::to_string(version); +} + +Status ValidateHipVersion(const std::string& value) { + auto current = GetHipVersion(); + ORT_RETURN_IF(current != value, "HIP runtime version mismatch: tuning results produced with HIP ", value, + ", onnxruntime currently run with HIP ", current); + return Status::OK(); +} + +std::string GetRocBlasVersion() { + char buf[64]; + ROCBLAS_CALL_THROW(rocblas_get_version_string(buf, 256)); + buf[63] = '\0'; + return buf; +} + +Status ValidateRocBlasVersion(const std::string& value) { + auto current = GetRocBlasVersion(); + ORT_RETURN_IF(current != value, "rocblas runtime version mismatch: tuning results produced with rocblas ", value, + ", onnxruntime currently run with rocblas ", current); + return Status::OK(); +} + +std::string RocmTuningResultsValidator::GetDeviceModel() const { + return ep_->GetDeviceProp().name; +} + +Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { + auto current = GetDeviceModel(); + ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, + ", onnxruntime currently run with device ", current); + return Status::OK(); +} + 
+RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_{ep} { + RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion); + RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion); + RegisterValidator( + "DEVICE_MODEL", + [this]() { return GetDeviceModel(); }, + [this](const std::string& value) { return ValidateDeviceModel(value); }); +} + +std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { + std::ostringstream oss; + oss << "USE_CK=" << USE_COMPOSABLE_KERNEL << "|"; +#ifdef USE_ROCBLAS_EXTENSION_API + oss << "USE_ROCBLAS_EXTENSION_API=" << 1 << "|"; +#else + oss << "USE_ROCBLAS_EXTENSION_API=" << 0 << "|"; +#endif + return oss.str(); +} + +RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) : info_(info), validator_(ep) {} void RocmTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider"; @@ -38,6 +96,10 @@ const TuningResultsManager& RocmTuningContext::GetTuningResultsManager() const { return manager_; } +const TuningResultsValidator& RocmTuningContext::GetTuningResultsValidator() const { + return validator_; +} + } // namespace tunable } // namespace rocm } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h index d6eb0886ddee8..cffc0e5614f16 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h @@ -15,6 +15,20 @@ class ROCMExecutionProvider; namespace rocm { namespace tunable { +class RocmTuningResultsValidator : public TuningResultsValidator { + public: + RocmTuningResultsValidator(ROCMExecutionProvider* ep); + + protected: + std::string GetOrtBuildConfig() const override; + + std::string GetDeviceModel() const; + Status ValidateDeviceModel(const std::string& value) const; + + private: + 
ROCMExecutionProvider* ep_; // non-owning handle +}; + class RocmTuningContext : public ITuningContext { public: explicit RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info); @@ -26,9 +40,12 @@ class RocmTuningContext : public ITuningContext { TuningResultsManager& GetTuningResultsManager() override; const TuningResultsManager& GetTuningResultsManager() const override; + const TuningResultsValidator& GetTuningResultsValidator() const override; + private: TunableOpInfo* info_; // non-owning handle TuningResultsManager manager_; + RocmTuningResultsValidator validator_; }; } // namespace tunable diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 8d3535cabc045..0d62821278862 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -20,6 +20,16 @@ namespace { // test on CPU and it does not use stream using StreamT = void*; +class TestTuningResultsValidator : public TuningResultsValidator { + public: + TestTuningResultsValidator() = default; + + protected: + std::string GetOrtBuildConfig() const override { + return "TEST_BUILD"; + } +}; + class TestTuningContext : public ITuningContext { public: void EnableTunableOp() override { tuning_enabled_ = true; } @@ -29,9 +39,12 @@ class TestTuningContext : public ITuningContext { TuningResultsManager& GetTuningResultsManager() override { return manager_; } const TuningResultsManager& GetTuningResultsManager() const override { return manager_; } + const TuningResultsValidator& GetTuningResultsValidator() const override { return validator_; } + private: bool tuning_enabled_{false}; TuningResultsManager manager_{}; + TestTuningResultsValidator validator_{}; }; class TestEP : public IExecutionProvider { From ac844798f6e1c54bfc62929d557c8f6689ea0a5f Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Wed, 8 Feb 2023 05:53:34 +0000 Subject: [PATCH 02/18] Add session API for getting and setting tuning 
results. Add embedded tuning results auto loading on session init --- onnxruntime/core/framework/tunable.h | 12 ++++- onnxruntime/core/framework/tuning_context.h | 8 +++ .../core/framework/tuning_context_impl.h | 37 +++++++++++++ .../cuda/tunable/cuda_tuning_context.cc | 3 +- .../rocm/tunable/rocm_tuning_context.cc | 3 +- onnxruntime/core/session/inference_session.cc | 54 +++++++++++++++++++ onnxruntime/core/session/inference_session.h | 18 +++++++ .../core/session/inference_session_utils.cc | 30 +++++++++++ .../core/session/inference_session_utils.h | 5 ++ .../onnxruntime_inference_collection.py | 6 +++ .../python/onnxruntime_pybind_state.cc | 39 ++++++++++++++ onnxruntime/test/framework/tunable_op_test.cc | 7 ++- 12 files changed, 217 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 29b4a443e2e17..e36fb85266388 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -136,12 +136,20 @@ class TunableOp { ITuningContext* ctx = params->TuningContext(); if (ctx->IsTunableOpEnabled()) { auto& mgr = ctx->GetTuningResultsManager(); - id = mgr.Lookup(Signature(), params->Signature()); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + id = mgr.Lookup(op_sig, params_sig); + if (id > static_cast(ops_.size())) { + LOGS_DEFAULT(FATAL) << "Invalid TunableOp kernel id for " << op_sig + << ", id:" << id << ", registered op:" << ops_.size(); + mgr.Delete(op_sig, params_sig); + id = -1; + } if (id < 0) { auto maybe_proxy_params = PreTuning(params); id = FindFastest(maybe_proxy_params); PostTuning(maybe_proxy_params); - mgr.Add(Signature(), params->Signature(), id); + mgr.Add(op_sig, params_sig, id); } } ORT_RETURN_IF_ERROR(ops_[id](params)); diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h index a9d8e4faa32e7..b32f7f8b324a1 100644 --- a/onnxruntime/core/framework/tuning_context.h +++ 
b/onnxruntime/core/framework/tuning_context.h @@ -17,6 +17,7 @@ class TuningResultsValidator; class ITuningContext { public: + explicit ITuningContext(IExecutionProvider* ep) : ep_(ep) {} virtual ~ITuningContext() = default; virtual void EnableTunableOp() = 0; @@ -27,6 +28,12 @@ class ITuningContext { virtual const TuningResultsManager& GetTuningResultsManager() const = 0; virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0; + + virtual TuningResults SaveTuningResults() const; + virtual Status LoadTuningResults(const TuningResults& tr); + + protected: + IExecutionProvider* ep_; }; class TuningResultsManager { @@ -38,6 +45,7 @@ class TuningResultsManager { int Lookup(const std::string& op_signature, const std::string& params_signature) const; void Add(const std::string& op_signature, const std::string& params_signature, int best_id); + void Delete(const std::string& op_signature, const std::string& params_signature); void Load(const std::unordered_map& results_to_load); std::unordered_map Dump() const; diff --git a/onnxruntime/core/framework/tuning_context_impl.h b/onnxruntime/core/framework/tuning_context_impl.h index 3e6826943c55f..2de458765c703 100644 --- a/onnxruntime/core/framework/tuning_context_impl.h +++ b/onnxruntime/core/framework/tuning_context_impl.h @@ -21,6 +21,22 @@ namespace onnxruntime { +TuningResults ITuningContext::SaveTuningResults() const { + TuningResults tr; + tr.ep = ep_->Type(); + tr.validators = GetTuningResultsValidator().GetAllValidators(); + tr.results = GetTuningResultsManager().Dump(); + return tr; +} + +Status ITuningContext::LoadTuningResults(const TuningResults& tr) { + ORT_RETURN_IF(tr.ep != ep_->Type(), "EP mismatch"); + LOGS_DEFAULT(VERBOSE) << "Loading tuning results for " << tr.ep; + ORT_RETURN_IF_ERROR(GetTuningResultsValidator().ValidateAll(tr.validators)); + GetTuningResultsManager().Load(tr.results); + return Status::OK(); +} + KernelMap TuningResultsManager::Lookup(const std::string& op_signature) 
const { std::scoped_lock l{lock_}; auto it = results_.find(op_signature); @@ -74,6 +90,22 @@ void TuningResultsManager::Add(const std::string& op_signature, const std::strin AddImpl(op_signature, params_signature, best_id, it->second); } +// NOLINTNEXTLINE(bugprone-easily-swappable-parameters) +void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) { + std::scoped_lock l{lock_}; + + auto it = results_.find(op_signature); + if (it == results_.end()) { + return; + } + + auto it2 = it->second.find(params_signature); + if (it2 == it->second.end()) { + return; + } + it->second.erase(it2); +} + std::unordered_map TuningResultsManager::Dump() const { std::scoped_lock l{lock_}; return results_; @@ -95,6 +127,11 @@ void DisjointMergeImpl( } void TuningResultsManager::Load(const std::unordered_map& results_to_load) { + for(const auto& [op_sig, kernel_map]: results_to_load) { + for(const auto& [param_sig, kernel_id] : kernel_map) { + LOGS_DEFAULT(VERBOSE) << op_sig << " " << param_sig << " " << kernel_id; + } + } std::scoped_lock l{lock_}; for (const auto& [op_signature, kernel_map] : results_to_load) { DisjointMergeImpl(op_signature, kernel_map, results_); diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc index aca418f2d7f8c..e2a4bd694dd49 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc @@ -46,7 +46,8 @@ CudaTuningResultsValidator::CudaTuningResultsValidator(CUDAExecutionProvider* ep [this](const std::string& value) { return ValidateDeviceModel(value); }); } -CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info) : info_(info), validator_(ep) {} +CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info) + : ITuningContext(ep), info_(info), validator_(ep) {} void 
CudaTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider"; diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc index 7b7c855c00715..b2a8134c708c5 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc @@ -72,7 +72,8 @@ std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { return oss.str(); } -RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) : info_(info), validator_(ep) {} +RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) + : ITuningContext(ep), info_(info), validator_(ep) {} void RocmTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider"; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 79068a0271f28..d4674a717b8f6 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1550,6 +1550,16 @@ common::Status InferenceSession::Initialize() { } } + std::vector tuning_results; + ORT_RETURN_IF_ERROR(inference_session_utils::ParseTuningResultsFromModelMetadata(model_metadata_, tuning_results)); + if(!tuning_results.empty()) { + ORT_RETURN_IF_ERROR(SetTuningResults(tuning_results)); + } + else { + LOGS(*session_logger_, WARNING) << "Got empty tuning results."; + } + + return status; } #if defined(_MSC_VER) && !defined(__clang__) @@ -2182,6 +2192,50 @@ const profiling::Profiler& InferenceSession::GetProfiling() const { return session_profiler_; } +#if !defined(ORT_MINIMAL_BUILD) +std::vector InferenceSession::GetTuningResults() const { + std::vector ret; + for (const auto& provider : execution_providers_) { + const auto* tuning_ctx = provider->GetTuningContext(); + if (tuning_ctx != nullptr) { + 
ret.emplace_back(tuning_ctx->SaveTuningResults()); + } + } + return ret; +} + +Status InferenceSession::SetTuningResults(const std::vector& trs, bool error_on_invalid) { + std::string msg; + + for (size_t i = 0; i < trs.size(); i++) { + const auto& tr = trs[i]; + auto* provider = execution_providers_.Get(tr.ep); + if (provider == nullptr) { + msg = MakeString("Cannot find execution provider ", tr.ep); + LOGS(*session_logger_, WARNING) << msg; + ORT_RETURN_IF(error_on_invalid, msg); + continue; + } + + auto* tuning_ctx = provider->GetTuningContext(); + if (tuning_ctx == nullptr) { + msg = MakeString("Invalid TuningResults (index=", i, "). ", tr.ep, " does not support TunableOp."); + LOGS(*session_logger_, WARNING) << msg; + ORT_RETURN_IF(error_on_invalid, msg); + continue; + } + + auto status = tuning_ctx->LoadTuningResults(tr); + if (!status.IsOK()) { + msg = MakeString("Failed to load TuningResults (index=", i, "). Reason: ", status.ErrorMessage()); + LOGS(*session_logger_, WARNING) << msg; + ORT_RETURN_IF(error_on_invalid, msg); + } + } + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const { return session_state_->GetAllocator(mem_info); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index f01523c923385..95b0dde281cda 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -18,6 +18,7 @@ #include "core/framework/kernel_registry_manager.h" #include "core/framework/prepacked_weights_container.h" #include "core/framework/session_state.h" +#include "core/framework/tuning_results.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" @@ -448,6 +449,23 @@ class InferenceSession { */ const profiling::Profiler& GetProfiling() const; +#if !defined(ORT_MINIMAL_BUILD) + /** + * Get 
the TuningResults of TunableOp for every execution provider. + * @return The TuningResults of each execution provider. + */ + std::vector GetTuningResults() const; + + /** + * Set the TuningResults back to each execution provider. Mainly for offline tuning. + * @param trs is the list of TuningResults to be loaded. + * @param error_on_invalid if false, validation failure is not an error; only a warning log will be produced. + * @return OK on success. + */ + Status SetTuningResults(const std::vector& trs, bool error_on_invalid = false); +#endif + + #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) MemoryProfiler& GetMemoryProfiler() { return memory_profiler_; diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index d938228c3af6f..5033f1ed90bef 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -103,6 +103,13 @@ static Status SetEnableProfiling(SessionOptions& session_options, return Status::OK(); } +// This function is called by nlohmann/json +void from_json(const json& j, TuningResults& trs) { + j.at("ep").get_to(trs.ep); + j.at("results").get_to(trs.results); + j.at("validators").get_to(trs.validators); +} + //--------------------------------------------------- //--- end of session options related helpers --- //--------------------------------------------------- @@ -227,6 +234,29 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options "Parsing RunOptions from ModelProto is not supported yet"); } +Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, + std::vector& results) { + results.clear(); + auto it = metadata.custom_metadata_map.find(kTuningResultsKeys); + if (it == metadata.custom_metadata_map.end()) { + return Status::OK(); + } + + LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while running the model"; + + ORT_TRY { 
+ auto parsed_tuning_results_json = json::parse(it->second); + results = parsed_tuning_results_json.get>(); + } + ORT_CATCH(const std::exception& e) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, FAIL, + "Tuning results stored in the model file cannot be parsed. Error message: ", e.what(), ". Ignoring..."); + } + + return Status::OK(); +} + } // namespace inference_session_utils } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session_utils.h b/onnxruntime/core/session/inference_session_utils.h index 3b021bfd8642a..3d884cc1b991d 100644 --- a/onnxruntime/core/session/inference_session_utils.h +++ b/onnxruntime/core/session/inference_session_utils.h @@ -12,6 +12,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "core/framework/session_options.h" +#include "core/framework/tuning_results.h" #include "core/common/common.h" #include "nlohmann/json.hpp" using json = nlohmann::json; @@ -30,6 +31,7 @@ static constexpr const char* kOrtLoadConfigFromModelEnvVar = "ORT_LOAD_CONFIG_FR // static constexpr const char* kOrtConfigKey = "ort_config"; static constexpr const char* kSessionOptionsKey = "session_options"; +static constexpr const char* kTuningResultsKeys = "tuning_results"; class JsonConfigParser { public: @@ -56,6 +58,9 @@ class JsonConfigParser { bool is_ort_config_json_available_ = false; }; +Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, + std::vector& results); + #endif // !defined(ORT_MINIMAL_BUILD) } // namespace inference_session_utils diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index f733c13b6d085..5108f3245213a 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -286,6 +286,12 @@ def run_with_iobinding(self, iobinding, run_options=None): """ self._sess.run_with_iobinding(iobinding._iobinding, 
run_options) + def get_tuning_results(self): + return self._sess.get_tuning_results() + + def set_tuning_results(self, results, error_on_invalid=False): + return self._sess.set_tuning_results(results, error_on_invalid) + def run_with_ortvaluevector(self, run_options, feed_names, feeds, fetch_names, fetches, fetch_devices): """ Compute the predictions similar to other run_*() methods but with minimal C++/Python conversion overhead. diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 490eb92afcd2b..e4118e73a6164 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1664,6 +1664,45 @@ including arg name, arg type (contains both type and shape).)pbdoc") status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get()); if (!status.IsOK()) throw std::runtime_error("Error in execution: " + status.ErrorMessage()); + }) + .def("get_tuning_results", [](PyInferenceSession* sess) -> py::list { + py::list ret; + for (const auto& trs : sess->GetSessionHandle()->GetTuningResults()) { + py::dict py_trs; + py_trs["ep"] = trs.ep; + py_trs["results"] = trs.results; + py_trs["validators"] = trs.validators; + ret.append(std::move(py_trs)); + } + + return ret; + }) + .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void { + std::vector tuning_results; + for (auto handle: results) { + auto py_trs = handle.cast(); + TuningResults trs; + trs.ep = py_trs["ep"].cast(); + + for (const auto& [py_op_sig, py_kernel_map]: py_trs["results"].cast()) { + KernelMap kernel_map; + for (const auto& [py_params_sig, py_kernel_id]: py_kernel_map.cast()) { + kernel_map[py_params_sig.cast()] = py_kernel_id.cast(); + } + trs.results[py_op_sig.cast()] = kernel_map; + } + + for (const auto& [k, v]: py_trs["validators"].cast()) { + trs.validators[k.cast()] = v.cast(); + } + + tuning_results.emplace_back(std::move(trs)); + } + + 
Status status = sess->GetSessionHandle()->SetTuningResults(tuning_results, error_on_invalid); + if (!status.IsOK()) { + throw std::runtime_error("Error in execution: " + status.ErrorMessage()); + } }); py::enum_(m, "ArenaExtendStrategy", py::arithmetic()) diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 0d62821278862..5d837bd4e9f0c 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -32,6 +32,8 @@ class TestTuningResultsValidator : public TuningResultsValidator { class TestTuningContext : public ITuningContext { public: + using ITuningContext::ITuningContext; + void EnableTunableOp() override { tuning_enabled_ = true; } void DisableTunableOp() override { tuning_enabled_ = false; } bool IsTunableOpEnabled() const override { return tuning_enabled_; } @@ -41,6 +43,8 @@ class TestTuningContext : public ITuningContext { const TuningResultsValidator& GetTuningResultsValidator() const override { return validator_; } + void ClearCache() { manager_.Clear(); } + private: bool tuning_enabled_{false}; TuningResultsManager manager_{}; @@ -49,7 +53,7 @@ class TestTuningContext : public ITuningContext { class TestEP : public IExecutionProvider { static constexpr const char* kEPType = "TestEP"; - TestTuningContext tuning_ctx_{}; + TestTuningContext tuning_ctx_{this}; public: TestEP() : IExecutionProvider{kEPType, true} {} @@ -58,6 +62,7 @@ class TestEP : public IExecutionProvider { return const_cast(&tuning_ctx_); } + void ClearCache() { tuning_ctx_.ClearCache(); } }; class TestTimer : public ITimer { From 255ba248e86a90d8249d610ef83a2f55583100ee Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Fri, 3 Feb 2023 07:24:02 +0000 Subject: [PATCH 03/18] Add tool for offline tuning --- tools/python/offline_tuning.py | 165 +++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 tools/python/offline_tuning.py diff --git 
a/tools/python/offline_tuning.py b/tools/python/offline_tuning.py new file mode 100644 index 0000000000000..c2349b4706fb2 --- /dev/null +++ b/tools/python/offline_tuning.py @@ -0,0 +1,165 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import copy +import itertools +import json +import sys +from collections import OrderedDict +from pprint import pprint +from typing import Any, Dict, List + +import onnx + +TuningResults = Dict[str, Any] + +_tuning_results_key = "tuning_results" + + +def _find_tuning_results_in_props(metadata_props): + for idx, prop in enumerate(metadata_props): + if prop.key == _tuning_results_key: + return idx + return -1 + + +def extract(onnx: onnx.ModelProto): + idx = _find_tuning_results_in_props(onnx.metadata_props) + if idx < 0: + return None + + tuning_results_prop = onnx.metadata_props[idx] + return json.loads(tuning_results_prop.value) + + +def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite=False): + idx = _find_tuning_results_in_props(model.metadata_props) + assert overwrite or idx <= 0, "the supplied onnx file already have tuning results embedded!" 
+ + if idx >= 0: + model.metadata_props.pop(idx) + + entry = model.metadata_props.add() + entry.key = _tuning_results_key + entry.value = json.dumps(tuning_results) + return model + + +class Merger: + class EpAndValidators: + def __init__(self, ep: str, validators: Dict[str, str]): + self.ep = ep + self.validators = copy.deepcopy(validators) + self.key = (ep, tuple(sorted(validators.items()))) + + def __hash__(self): + return hash(self.key) + + def __eq__(self, other): + return self.ep == other.ep and self.key == other.key + + def __init__(self): + self.ev_to_results = OrderedDict() + + def merge(self, tuning_results: List[TuningResults]): + for trs in tuning_results: + self._merge_one(trs) + + def get_merged(self): + tuning_results = [] + for ev, flat_results in self.ev_to_results.items(): + results = {} + trs = { + "ep": ev.ep, + "validators": ev.validators, + "results": results, + } + for (op_sig, params_sig), kernel_id in flat_results.items(): + kernel_map = results.setdefault(op_sig, {}) + kernel_map[params_sig] = kernel_id + tuning_results.append(trs) + return tuning_results + + def _merge_one(self, trs: TuningResults): + ev = Merger.EpAndValidators(trs["ep"], trs["validators"]) + flat_results = self.ev_to_results.setdefault(ev, {}) + for op_sig, kernel_map in trs["results"].items(): + for params_sig, kernel_id in kernel_map.items(): + if (op_sig, params_sig) not in flat_results: + flat_results[(op_sig, params_sig)] = kernel_id + + +def parse_args(): + parser = argparse.ArgumentParser() + sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd") + + extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.") + extract_parser.add_argument("input_onnx") + extract_parser.add_argument("output_json") + + embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.") + embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning 
results if it existed.") + embed_parser.add_argument("output_onnx", help="Path of the output onnx file.") + embed_parser.add_argument("input_onnx", help="Path of the input onnx file.") + embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.") + + merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.") + merge_parser.add_argument("output_json", help="Path of the output tuning results file.") + merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.") + + pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.") + pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.") + + args = parser.parse_args() + if len(vars(args)) == 0: + parser.print_help() + exit(-1) + return args + + +if __name__ == "__main__": + args = parse_args() + if args.cmd == "extract": + tuning_results = extract(onnx.load_model(args.input_onnx)) + if tuning_results is None: + sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n") + exit(-1) + json.dump(tuning_results, open(args.output_json, "w")) + elif args.cmd == "embed": + model = onnx.load_model(args.input_onnx) + merger = Merger() + for tuning_results in [json.load(open(f)) for f in args.input_json]: + merger.merge(tuning_results) + model = embed(model, merger.get_merged(), args.force) + onnx.save_model(model, args.output_onnx) + elif args.cmd == "merge": + merger = Merger() + for tuning_results in [json.load(open(f)) for f in args.input_json]: + merger.merge(tuning_results) + json.dump(merger.get_merged(), open(args.output_json, "w")) + elif args.cmd == "pprint": + tuning_results = None + try: + tuning_results = json.load(open(args.json_or_onnx, "r")) + except: + pass + + if tuning_results is None: + try: + model = onnx.load_model(args.json_or_onnx) + tuning_results = extract(model) + if 
tuning_results is None: + sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n") + exit(-1) + except: + pass + + if tuning_results is None: + sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!") + exit(-1) + + pprint(tuning_results) + else: + # invalid choice will be handled by the parser + pass From 6cb4f663e8bc267738f69958a6489c45929b1080 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Wed, 8 Feb 2023 07:26:54 +0000 Subject: [PATCH 04/18] Address comment --- onnxruntime/core/framework/tunable.h | 2 +- onnxruntime/core/framework/tuning_context.h | 5 ++++- .../core/framework/tuning_context_impl.h | 21 +++++++++---------- .../cuda/tunable/cuda_tuning_context.cc | 4 ++-- .../rocm/tunable/rocm_tuning_context.cc | 8 +++---- onnxruntime/core/session/inference_session.cc | 2 +- 6 files changed, 22 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index e36fb85266388..50442268e797e 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -140,7 +140,7 @@ class TunableOp { auto params_sig = params->Signature(); id = mgr.Lookup(op_sig, params_sig); if (id > static_cast(ops_.size())) { - LOGS_DEFAULT(FATAL) << "Invalid TunableOp kernel id for " << op_sig + LOGS_DEFAULT(ERROR) << "Invalid TunableOp kernel id for " << op_sig << ", id:" << id << ", registered op:" << ops_.size(); mgr.Delete(op_sig, params_sig); id = -1; diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h index b32f7f8b324a1..81e357aaecb3a 100644 --- a/onnxruntime/core/framework/tuning_context.h +++ b/onnxruntime/core/framework/tuning_context.h @@ -29,7 +29,7 @@ class ITuningContext { virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0; - virtual TuningResults SaveTuningResults() const; + virtual TuningResults GetTuningResults() const; virtual Status 
LoadTuningResults(const TuningResults& tr); protected: @@ -83,6 +83,9 @@ class TuningResultsValidator { virtual std::string GetOrtBuildConfig() const; virtual Status ValidateOrtBuildConfig(const std::string& value) const; + public: + static constexpr const std::array mandatory_keys{"ORT_VERSION", "ORT_GIT_COMMIT", "ORT_BUILD_CONFIG"}; + private: GetValidateFuncs validators_; }; diff --git a/onnxruntime/core/framework/tuning_context_impl.h b/onnxruntime/core/framework/tuning_context_impl.h index 2de458765c703..c8b0583e3ea5f 100644 --- a/onnxruntime/core/framework/tuning_context_impl.h +++ b/onnxruntime/core/framework/tuning_context_impl.h @@ -21,7 +21,7 @@ namespace onnxruntime { -TuningResults ITuningContext::SaveTuningResults() const { +TuningResults ITuningContext::GetTuningResults() const { TuningResults tr; tr.ep = ep_->Type(); tr.validators = GetTuningResultsValidator().GetAllValidators(); @@ -76,6 +76,7 @@ inline void AddImpl(const std::string& op_signature, return; } + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ") -> " << best_id; kernel_map[params_signature] = best_id; } @@ -103,6 +104,8 @@ void TuningResultsManager::Delete(const std::string& op_signature, const std::st if (it2 == it->second.end()) { return; } + + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ")"; it->second.erase(it2); } @@ -117,6 +120,9 @@ void DisjointMergeImpl( /*out*/ std::unordered_map& results) { auto it = results.find(op_signature); if (it == results.end()) { + for(const auto& [param_sig, kernel_id] : kernel_map) { + LOGS_DEFAULT(VERBOSE) << op_signature << "(" << param_sig << ") -> " << kernel_id; + } results[op_signature] = kernel_map; return; } @@ -127,11 +133,6 @@ void DisjointMergeImpl( } void TuningResultsManager::Load(const std::unordered_map& results_to_load) { - for(const auto& [op_sig, kernel_map]: results_to_load) { - for(const auto& [param_sig, kernel_id] : kernel_map) { - LOGS_DEFAULT(VERBOSE) << op_sig << " " << param_sig 
<< " " << kernel_id; - } - } std::scoped_lock l{lock_}; for (const auto& [op_signature, kernel_map] : results_to_load) { DisjointMergeImpl(op_signature, kernel_map, results_); @@ -147,14 +148,12 @@ void TuningResultsManager::Clear() { results_ = {}; } -Status CheckMandatoryKeys( +static Status CheckMandatoryKeys( const TuningResultsValidator::GetValidateFuncs& gv_funcs, const std::unordered_map& to_check) { - constexpr const std::array mandatory_keys{"ORT_VERSION", "ORT_GIT_COMMIT", "ORT_BUILD_CONFIG"}; - bool passed = true; std::ostringstream oss; - for (const auto& k : mandatory_keys) { + for (const auto& k : TuningResultsValidator::mandatory_keys) { if (gv_funcs.find(k) == gv_funcs.end()) { passed = false; oss << "key=\"" << k << "\" is not registered for Get and Validate. "; @@ -169,7 +168,7 @@ Status CheckMandatoryKeys( return Status::OK(); } -Status CheckKeysMatching( +static Status CheckKeysMatching( const TuningResultsValidator::GetValidateFuncs& gv_funcs, const std::unordered_map& to_check) { auto get_keys = [](const auto& it) -> std::string { return it.first; }; diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc index e2a4bd694dd49..bea2889ca8834 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc @@ -14,13 +14,13 @@ namespace onnxruntime { namespace cuda { namespace tunable { -std::string GetCudaVersion() { +static std::string GetCudaVersion() { int version; CUDA_CALL_THROW(cudaRuntimeGetVersion(&version)); return std::to_string(version); } -Status ValidateCudaVersion(const std::string& value) { +static Status ValidateCudaVersion(const std::string& value) { auto current = GetCudaVersion(); ORT_RETURN_IF(current != value, "CUDA runtime version mismatch: tuning results produced with CUDA ", value, ", onnxruntime currently run with CUDA ", current); diff --git 
a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc index b2a8134c708c5..bead003e83679 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc @@ -14,27 +14,27 @@ namespace onnxruntime { namespace rocm { namespace tunable { -std::string GetHipVersion() { +static std::string GetHipVersion() { int version; HIP_CALL_THROW(hipRuntimeGetVersion(&version)); return std::to_string(version); } -Status ValidateHipVersion(const std::string& value) { +static Status ValidateHipVersion(const std::string& value) { auto current = GetHipVersion(); ORT_RETURN_IF(current != value, "HIP runtime version mismatch: tuning results produced with HIP ", value, ", onnxruntime currently run with HIP ", current); return Status::OK(); } -std::string GetRocBlasVersion() { +static std::string GetRocBlasVersion() { char buf[64]; ROCBLAS_CALL_THROW(rocblas_get_version_string(buf, 256)); buf[63] = '\0'; return buf; } -Status ValidateRocBlasVersion(const std::string& value) { +static Status ValidateRocBlasVersion(const std::string& value) { auto current = GetRocBlasVersion(); ORT_RETURN_IF(current != value, "rocblas runtime version mismatch: tuning results produced with rocblas ", value, ", onnxruntime currently run with rocblas ", current); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index d4674a717b8f6..3be978001c86f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2198,7 +2198,7 @@ std::vector InferenceSession::GetTuningResults() const { for (const auto& provider : execution_providers_) { const auto* tuning_ctx = provider->GetTuningContext(); if (tuning_ctx != nullptr) { - ret.emplace_back(tuning_ctx->SaveTuningResults()); + ret.emplace_back(tuning_ctx->GetTuningResults()); } } return ret; From 
1ba5c5dcb37582bf51c4052af8197d007b78d399 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Thu, 9 Feb 2023 06:57:54 +0000 Subject: [PATCH 05/18] Minor update InferenceSession API and logging behavior --- onnxruntime/core/session/inference_session.cc | 16 +++++++--------- .../core/session/inference_session_utils.cc | 5 ++++- .../core/session/inference_session_utils.h | 3 ++- .../python/onnxruntime_inference_collection.py | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 3be978001c86f..a74ee4f2d104a 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1551,14 +1551,12 @@ common::Status InferenceSession::Initialize() { } std::vector tuning_results; - ORT_RETURN_IF_ERROR(inference_session_utils::ParseTuningResultsFromModelMetadata(model_metadata_, tuning_results)); - if(!tuning_results.empty()) { + bool found_tuning_results = false; + ORT_RETURN_IF_ERROR(inference_session_utils::ParseTuningResultsFromModelMetadata( + model_metadata_, tuning_results, found_tuning_results)); + if (found_tuning_results) { ORT_RETURN_IF_ERROR(SetTuningResults(tuning_results)); } - else { - LOGS(*session_logger_, WARNING) << "Got empty tuning results."; - } - return status; } @@ -2212,24 +2210,24 @@ Status InferenceSession::SetTuningResults(const std::vector& trs, auto* provider = execution_providers_.Get(tr.ep); if (provider == nullptr) { msg = MakeString("Cannot find execution provider ", tr.ep); - LOGS(*session_logger_, WARNING) << msg; ORT_RETURN_IF(error_on_invalid, msg); + LOGS(*session_logger_, WARNING) << msg; continue; } auto* tuning_ctx = provider->GetTuningContext(); if (tuning_ctx == nullptr) { msg = MakeString("Invalid TuningResults (index=", i, "). 
", tr.ep, " does not support TunableOp."); - LOGS(*session_logger_, WARNING) << msg; ORT_RETURN_IF(error_on_invalid, msg); + LOGS(*session_logger_, WARNING) << msg; continue; } auto status = tuning_ctx->LoadTuningResults(tr); if (!status.IsOK()) { msg = MakeString("Failed to load TuningResults (index=", i, "). Reason: ", status.ErrorMessage()); - LOGS(*session_logger_, WARNING) << msg; ORT_RETURN_IF(error_on_invalid, msg); + LOGS(*session_logger_, WARNING) << msg; } } return Status::OK(); diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index 5033f1ed90bef..9ba6c946c6d36 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -235,13 +235,16 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options } Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, - std::vector& results) { + std::vector& results, + bool& key_found) { results.clear(); + key_found = false; auto it = metadata.custom_metadata_map.find(kTuningResultsKeys); if (it == metadata.custom_metadata_map.end()) { return Status::OK(); } + key_found = true; LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while running the model"; ORT_TRY { diff --git a/onnxruntime/core/session/inference_session_utils.h b/onnxruntime/core/session/inference_session_utils.h index 3d884cc1b991d..a0bcdb9013bf0 100644 --- a/onnxruntime/core/session/inference_session_utils.h +++ b/onnxruntime/core/session/inference_session_utils.h @@ -59,7 +59,8 @@ class JsonConfigParser { }; Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, - std::vector& results); + /*out*/ std::vector& results, + /*out*/ bool& key_found); #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 
5108f3245213a..0883c528c9f07 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -289,7 +289,7 @@ def run_with_iobinding(self, iobinding, run_options=None): def get_tuning_results(self): return self._sess.get_tuning_results() - def set_tuning_results(self, results, error_on_invalid=False): + def set_tuning_results(self, results, *, error_on_invalid=False): return self._sess.set_tuning_results(results, error_on_invalid) def run_with_ortvaluevector(self, run_options, feed_names, feeds, fetch_names, fetches, fetch_devices): From e012b603164aaf473070a33118c1204169754e00 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Thu, 9 Feb 2023 06:59:57 +0000 Subject: [PATCH 06/18] Add tests for the API --- onnxruntime/test/framework/tunable_op_test.cc | 67 ++++++++++++++- .../test/python/onnxruntime_test_python.py | 85 ++++++++++++++++++- 2 files changed, 149 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 5d837bd4e9f0c..ed95262e27321 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -20,9 +20,25 @@ namespace { // test on CPU and it does not use stream using StreamT = void*; +constexpr static const char* kTestKey = "THE_TEST_KEY"; +constexpr static const char* kValidTestValue = "THE_VALID_TEST_VALUE"; +constexpr static const char* kInvalidTestValue = "A_INVALID_TEST_VALUE"; + +static std::string GetTestValue() { + return kValidTestValue; +} + +static Status ValidateTestValue(const std::string& value) { + auto current = GetTestValue(); + ORT_RETURN_IF(current != value, "Only ", kValidTestValue, " is valid for key ", kTestKey); + return Status::OK(); +} + class TestTuningResultsValidator : public TuningResultsValidator { public: - TestTuningResultsValidator() = default; + TestTuningResultsValidator() { + RegisterValidator(kTestKey, GetTestValue, 
ValidateTestValue); + }; protected: std::string GetOrtBuildConfig() const override { @@ -602,12 +618,59 @@ TEST(TuningContext, TunableOpRespectTuningContext) { { ASSERT_EQ(mgr.Lookup(op.Signature()).size(), 0u); - // TunableOp(...), respect the existing entry + // TunableOp(...), respect the existing entry (manually loaded) if id in bound mgr.Add(op.Signature(), params.Signature(), tuning::TunableVecAddSelectFast::kSlowFullId); auto status = op(¶ms); ASSERT_TRUE(status.IsOK()); ASSERT_EQ(last_run, "SlowFull"); } + + last_run.clear(); + mgr.Clear(); + { + // TunableOp(...), must not respect the existing entry if id not in bound + // manually create an out of bound id + mgr.Add(op.Signature(), params.Signature(), 1000000); + auto status = op(¶ms); + ASSERT_TRUE(status.IsOK()) << "TunableOp should recover from an out of bound id"; + ASSERT_EQ(last_run, "FastFull"); + ASSERT_EQ(mgr.Lookup(op.Signature(), params.Signature()), tuning::TunableVecAddSelectFast::kFastFullId); + } +#endif +} + +TEST(TuningContext, GetAndLoadTuningResults) { +#ifdef ORT_NO_RTTI + GTEST_SKIP() << "TunableOp needs RTTI to work correctly"; +#else + constexpr const int a = 7500000; + constexpr const int b = 42; + int c{}; + tuning::VecAddParamsRecordLastRun params(&a, &b, &c, 1, 0); + std::string last_run; + params.last_run = &last_run; + + tuning::TunableVecAddSelectFast op{}; + auto* ctx = params.TuningContext(); + ctx->EnableTunableOp(); + + auto status = op(¶ms); + ASSERT_TRUE(status.IsOK()); + ASSERT_EQ(last_run, "FastFull"); + + auto trs = ctx->GetTuningResults(); + ASSERT_EQ(trs.ep, "TestEP"); + + ASSERT_EQ(trs.validators.size(), TestTuningResultsValidator::mandatory_keys.size() + 1); + for (const auto& key : TestTuningResultsValidator::mandatory_keys) { + ASSERT_THAT(trs.validators, ::testing::Contains(::testing::Key(key))); + } + ASSERT_THAT(trs.validators, ::testing::Contains(::testing::Key(kTestKey))); + + ASSERT_EQ(trs.results.size(), 1); + ASSERT_THAT(trs.results, 
::testing::Contains(::testing::Key(op.Signature()))); + ASSERT_THAT(trs.results[op.Signature()], ::testing::Contains(::testing::Key(params.Signature()))); + ASSERT_EQ(trs.results[op.Signature()][params.Signature()], tuning::TunableVecAddSelectFast::kFastFullId); #endif } diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 89fd90ad3a19e..b9bc593ac0c80 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. # pylint: disable=C0116,W0212,R1720,C0114 -# -*- coding: UTF-8 -*- +import copy import gc import os import platform @@ -387,6 +387,89 @@ def testSessionProviders(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"]) self.assertEqual(["CPUExecutionProvider"], sess.get_providers()) + def testGetAndSetTuningResults(self): + def getTuningResultsForEp(sess, ep): # without the outer list + tuning_results = sess.get_tuning_results() + self.assertGreaterEqual(len(tuning_results), 1) + tuning_results_for_this_ep = [t for t in tuning_results if t.get("ep") == ep] + self.assertEqual(len(tuning_results_for_this_ep), 1) + return tuning_results_for_this_ep[0] + + probe_op_sig = "probe_but_not_an_op_signature" + probe_params_sig = "probe_but_not_an_params_signature" + probe_value = 10000000 + + def copyTuningResultsWithProbe(tr): + tr = copy.deepcopy(tr) + tr["results"][probe_op_sig] = {probe_params_sig: probe_value} + return tr + + def assertTuningResultsLoaded(sess, ep): + tr = getTuningResultsForEp(sess, ep) + self.assertIn(probe_op_sig, tr["results"]) + self.assertEqual(tr["results"][probe_op_sig], {probe_params_sig: probe_value}) + + def assertTuningResultsNotLoaded(sess, ep): + tr = getTuningResultsForEp(sess, ep) + self.assertNotIn(probe_op_sig, tr["results"]) + + def doTestGetAndSetTuningResults(ep): + sess = 
onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=[ep]) + tuning_results = getTuningResultsForEp(sess, ep) + + self.assertIn("ep", tuning_results) + self.assertIn("results", tuning_results) + self.assertIn("validators", tuning_results) + self.assertIn("ORT_VERSION", tuning_results["validators"]) + self.assertNotIn("NOT_A_VALIDATOR_KEY", tuning_results["validators"]) + + # invalid EP will be rejected + invalid_unkonwn_ep = copyTuningResultsWithProbe(tuning_results) + invalid_unkonwn_ep["ep"] = "UnknownEP" + sess.set_tuning_results([invalid_unkonwn_ep]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([invalid_unkonwn_ep], error_on_invalid=True) + self.assertTrue("Cannot find execution provider UnknownEP" in str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + # missing validator key will be rejected + mismatched_validator_key_missing = copyTuningResultsWithProbe(tuning_results) + mismatched_validator_key_missing["validators"].pop("ORT_VERSION") + sess.set_tuning_results([mismatched_validator_key_missing]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([mismatched_validator_key_missing], error_on_invalid=True) + self.assertTrue("ORT_VERSION" in str(context.exception)) + self.assertTrue("is not provided for validation" in str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + mismatched_validator_key_extra = copyTuningResultsWithProbe(tuning_results) + mismatched_validator_key_extra["validators"]["NOT_A_VALIDATOR_KEY"] = "NOT_USED" + sess.set_tuning_results([mismatched_validator_key_extra]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([mismatched_validator_key_extra], error_on_invalid=True) + self.assertTrue("NOT_A_VALIDATOR_KEY" in str(context.exception)) + self.assertTrue("is unable to consume it" in str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + validation_faliure = copyTuningResultsWithProbe(tuning_results) + 
validation_faliure["validators"]["ORT_VERSION"] = "This is not a proper ORT_VERSION value!" + sess.set_tuning_results([validation_faliure]) + with self.assertRaises(RuntimeError) as context: + sess.set_tuning_results([validation_faliure], error_on_invalid=True) + self.assertTrue("Failed to load TuningResults" in str(context.exception)) + self.assertTrue("version mismatch" in str(context.exception)) + assertTuningResultsNotLoaded(sess, ep) + + loadable = copyTuningResultsWithProbe(tuning_results) + sess.set_tuning_results([loadable], error_on_invalid=True) + assertTuningResultsLoaded(sess, ep) + + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + doTestGetAndSetTuningResults("CUDAExecutionProvider") + + if "ROCMExecutionProvider" in onnxrt.get_available_providers(): + doTestGetAndSetTuningResults("ROCMExecutionProvider") + def testRunModel(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) From c27c991b0ca6c9994b5f8652aa5d732fc12756fb Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 13 Feb 2023 03:37:14 +0000 Subject: [PATCH 07/18] Avoid the using of rocm_tunable.h --- onnxruntime/core/providers/rocm/rocm_execution_provider.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 0178ac4f0b696..ac221d04b4f69 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -14,7 +14,7 @@ #include "core/providers/rocm/rocm_pch.h" #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/shared_inc/rocm_call.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "core/providers/rocm/tunable/rocm_tuning_context.h" namespace onnxruntime { From 2a50e09c1b00dcbad0a566b2f7c69196524b348f Mon Sep 17 
00:00:00 2001 From: Guangyun Han Date: Mon, 13 Feb 2023 05:55:11 +0000 Subject: [PATCH 08/18] Remove unused var --- onnxruntime/test/framework/tunable_op_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index ed95262e27321..39dd452e10ee1 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -22,7 +22,6 @@ using StreamT = void*; constexpr static const char* kTestKey = "THE_TEST_KEY"; constexpr static const char* kValidTestValue = "THE_VALID_TEST_VALUE"; -constexpr static const char* kInvalidTestValue = "A_INVALID_TEST_VALUE"; static std::string GetTestValue() { return kValidTestValue; From 935d1ea1378a65ea594bf2d0e57fc3bcc0ac4814 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 13 Feb 2023 06:08:49 +0000 Subject: [PATCH 09/18] Fix minimal build --- onnxruntime/core/session/inference_session.cc | 2 ++ onnxruntime/python/onnxruntime_pybind_state.cc | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index a74ee4f2d104a..cd25b631f5217 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1550,6 +1550,7 @@ common::Status InferenceSession::Initialize() { } } +#if !defined(ORT_MINIMAL_BUILD) std::vector tuning_results; bool found_tuning_results = false; ORT_RETURN_IF_ERROR(inference_session_utils::ParseTuningResultsFromModelMetadata( @@ -1557,6 +1558,7 @@ common::Status InferenceSession::Initialize() { if (found_tuning_results) { ORT_RETURN_IF_ERROR(SetTuningResults(tuning_results)); } +#endif // !defined(ORT_MINIMAL_BUILD) return status; } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index e4118e73a6164..efe295b4e0cdd 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc 
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1666,6 +1666,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") throw std::runtime_error("Error in execution: " + status.ErrorMessage()); }) .def("get_tuning_results", [](PyInferenceSession* sess) -> py::list { +#if !defined(ORT_MINIMAL_BUILD) py::list ret; for (const auto& trs : sess->GetSessionHandle()->GetTuningResults()) { py::dict py_trs; @@ -1676,8 +1677,13 @@ including arg name, arg type (contains both type and shape).)pbdoc") } return ret; +#else + ORT_UNUSED_PARAMETER(sess); + ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); +#endif }) .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void { +#if !defined(ORT_MINIMAL_BUILD) std::vector tuning_results; for (auto handle: results) { auto py_trs = handle.cast(); @@ -1703,6 +1709,12 @@ including arg name, arg type (contains both type and shape).)pbdoc") if (!status.IsOK()) { throw std::runtime_error("Error in execution: " + status.ErrorMessage()); } +#else + ORT_UNUSED_PARAMETER(sess); + ORT_UNUSED_PARAMETER(results); + ORT_UNUSED_PARAMETER(error_on_invalid); + ORT_THROW("TunableOp and set_tuning_results are not supported in this build."); +#endif }); py::enum_(m, "ArenaExtendStrategy", py::arithmetic()) From 6f6960aae00899bc3d5196dec69eef717348b672 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 13 Feb 2023 06:20:59 +0000 Subject: [PATCH 10/18] Address CodeQL check notice --- .../test/python/onnxruntime_test_python.py | 14 +++++----- tools/python/offline_tuning.py | 26 +++++++++++-------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index b9bc593ac0c80..7e4dd478d2e30 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -429,7 +429,7 @@ def 
doTestGetAndSetTuningResults(ep): sess.set_tuning_results([invalid_unkonwn_ep]) with self.assertRaises(RuntimeError) as context: sess.set_tuning_results([invalid_unkonwn_ep], error_on_invalid=True) - self.assertTrue("Cannot find execution provider UnknownEP" in str(context.exception)) + self.assertIn("Cannot find execution provider UnknownEP", str(context.exception)) assertTuningResultsNotLoaded(sess, ep) # missing validator key will be rejected @@ -438,8 +438,8 @@ def doTestGetAndSetTuningResults(ep): sess.set_tuning_results([mismatched_validator_key_missing]) with self.assertRaises(RuntimeError) as context: sess.set_tuning_results([mismatched_validator_key_missing], error_on_invalid=True) - self.assertTrue("ORT_VERSION" in str(context.exception)) - self.assertTrue("is not provided for validation" in str(context.exception)) + self.assertIn("ORT_VERSION", str(context.exception)) + self.assertIn("is not provided for validation", str(context.exception)) assertTuningResultsNotLoaded(sess, ep) mismatched_validator_key_extra = copyTuningResultsWithProbe(tuning_results) @@ -447,8 +447,8 @@ def doTestGetAndSetTuningResults(ep): sess.set_tuning_results([mismatched_validator_key_extra]) with self.assertRaises(RuntimeError) as context: sess.set_tuning_results([mismatched_validator_key_extra], error_on_invalid=True) - self.assertTrue("NOT_A_VALIDATOR_KEY" in str(context.exception)) - self.assertTrue("is unable to consume it" in str(context.exception)) + self.assertIn("NOT_A_VALIDATOR_KEY", str(context.exception)) + self.assertIn("is unable to consume it", str(context.exception)) assertTuningResultsNotLoaded(sess, ep) validation_faliure = copyTuningResultsWithProbe(tuning_results) @@ -456,8 +456,8 @@ def doTestGetAndSetTuningResults(ep): sess.set_tuning_results([validation_faliure]) with self.assertRaises(RuntimeError) as context: sess.set_tuning_results([validation_faliure], error_on_invalid=True) - self.assertTrue("Failed to load TuningResults" in str(context.exception)) - 
self.assertTrue("version mismatch" in str(context.exception)) + self.assertIn("Failed to load TuningResults", str(context.exception)) + self.assertIn("version mismatch", str(context.exception)) assertTuningResultsNotLoaded(sess, ep) loadable = copyTuningResultsWithProbe(tuning_results) diff --git a/tools/python/offline_tuning.py b/tools/python/offline_tuning.py index c2349b4706fb2..69bcdad382877 100644 --- a/tools/python/offline_tuning.py +++ b/tools/python/offline_tuning.py @@ -3,7 +3,6 @@ import argparse import copy -import itertools import json import sys from collections import OrderedDict @@ -14,22 +13,22 @@ TuningResults = Dict[str, Any] -_tuning_results_key = "tuning_results" +_TUNING_RESULTS_KEY = "tuning_results" def _find_tuning_results_in_props(metadata_props): for idx, prop in enumerate(metadata_props): - if prop.key == _tuning_results_key: + if prop.key == _TUNING_RESULTS_KEY: return idx return -1 -def extract(onnx: onnx.ModelProto): - idx = _find_tuning_results_in_props(onnx.metadata_props) +def extract(model: onnx.ModelProto): + idx = _find_tuning_results_in_props(model.metadata_props) if idx < 0: return None - tuning_results_prop = onnx.metadata_props[idx] + tuning_results_prop = model.metadata_props[idx] return json.loads(tuning_results_prop.value) @@ -41,7 +40,7 @@ def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite model.metadata_props.pop(idx) entry = model.metadata_props.add() - entry.key = _tuning_results_key + entry.key = _TUNING_RESULTS_KEY entry.value = json.dumps(tuning_results) return model @@ -118,13 +117,13 @@ def parse_args(): return args -if __name__ == "__main__": +def main(): args = parse_args() if args.cmd == "extract": tuning_results = extract(onnx.load_model(args.input_onnx)) if tuning_results is None: sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n") - exit(-1) + sys.exit(-1) json.dump(tuning_results, open(args.output_json, "w")) elif args.cmd == "embed": model = 
onnx.load_model(args.input_onnx) @@ -143,6 +142,7 @@ def parse_args(): try: tuning_results = json.load(open(args.json_or_onnx, "r")) except: + # it might be an onnx file otherwise, try it latter pass if tuning_results is None: @@ -151,15 +151,19 @@ def parse_args(): tuning_results = extract(model) if tuning_results is None: sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n") - exit(-1) + sys.exit(-1) except: pass if tuning_results is None: sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!") - exit(-1) + sys.exit(-1) pprint(tuning_results) else: # invalid choice will be handled by the parser pass + + +if __name__ == "__main__": + main() From c5fbe228b971e0cb0d2aa5d9dde5ce3bb564dca2 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 13 Feb 2023 06:27:05 +0000 Subject: [PATCH 11/18] Fix typo --- onnxruntime/test/python/onnxruntime_test_python.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 7e4dd478d2e30..ce57c3ac16d00 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -424,11 +424,11 @@ def doTestGetAndSetTuningResults(ep): self.assertNotIn("NOT_A_VALIDATOR_KEY", tuning_results["validators"]) # invalid EP will be rejected - invalid_unkonwn_ep = copyTuningResultsWithProbe(tuning_results) - invalid_unkonwn_ep["ep"] = "UnknownEP" - sess.set_tuning_results([invalid_unkonwn_ep]) + invalid_unknown_ep = copyTuningResultsWithProbe(tuning_results) + invalid_unknown_ep["ep"] = "UnknownEP" + sess.set_tuning_results([invalid_unknown_ep]) with self.assertRaises(RuntimeError) as context: - sess.set_tuning_results([invalid_unkonwn_ep], error_on_invalid=True) + sess.set_tuning_results([invalid_unknown_ep], error_on_invalid=True) self.assertIn("Cannot find execution provider UnknownEP", 
str(context.exception)) assertTuningResultsNotLoaded(sess, ep) From ed3680055b4e41a5c609b5686eee2ce2dc63b75e Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 13 Feb 2023 06:41:05 +0000 Subject: [PATCH 12/18] Workaround flake8 --- tools/python/offline_tuning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/python/offline_tuning.py b/tools/python/offline_tuning.py index 69bcdad382877..8dbae5efe8f9f 100644 --- a/tools/python/offline_tuning.py +++ b/tools/python/offline_tuning.py @@ -141,7 +141,7 @@ def main(): tuning_results = None try: tuning_results = json.load(open(args.json_or_onnx, "r")) - except: + except Exception: # it might be an onnx file otherwise, try it latter pass @@ -152,7 +152,7 @@ def main(): if tuning_results is None: sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n") sys.exit(-1) - except: + except Exception: pass if tuning_results is None: From 946e8e604c9873a8b4d60e9b3b8279c01bfa2fe3 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 13 Feb 2023 07:28:42 +0000 Subject: [PATCH 13/18] Workaround warning treated as error --- onnxruntime/test/framework/tunable_op_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 39dd452e10ee1..d68882f0a02fa 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -666,7 +666,7 @@ TEST(TuningContext, GetAndLoadTuningResults) { } ASSERT_THAT(trs.validators, ::testing::Contains(::testing::Key(kTestKey))); - ASSERT_EQ(trs.results.size(), 1); + ASSERT_EQ(trs.results.size(), 1u); ASSERT_THAT(trs.results, ::testing::Contains(::testing::Key(op.Signature()))); ASSERT_THAT(trs.results[op.Signature()], ::testing::Contains(::testing::Key(params.Signature()))); ASSERT_EQ(trs.results[op.Signature()][params.Signature()], tuning::TunableVecAddSelectFast::kFastFullId); From 
733125cc14e9c5d89413c9d17cebfc98751f500c Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 13 Feb 2023 08:14:32 +0000 Subject: [PATCH 14/18] Fix loop variable '...' is always a copy... --- onnxruntime/python/onnxruntime_pybind_state.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index efe295b4e0cdd..de14b9aff368f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1690,15 +1690,15 @@ including arg name, arg type (contains both type and shape).)pbdoc") TuningResults trs; trs.ep = py_trs["ep"].cast(); - for (const auto& [py_op_sig, py_kernel_map]: py_trs["results"].cast()) { + for (const auto [py_op_sig, py_kernel_map]: py_trs["results"].cast()) { KernelMap kernel_map; - for (const auto& [py_params_sig, py_kernel_id]: py_kernel_map.cast()) { + for (const auto [py_params_sig, py_kernel_id]: py_kernel_map.cast()) { kernel_map[py_params_sig.cast()] = py_kernel_id.cast(); } trs.results[py_op_sig.cast()] = kernel_map; } - for (const auto& [k, v]: py_trs["validators"].cast()) { + for (const auto [k, v]: py_trs["validators"].cast()) { trs.validators[k.cast()] = v.cast(); } From 69e3ab119566f065f8768532f018fdd12bf3a805 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 13 Feb 2023 09:51:33 +0000 Subject: [PATCH 15/18] Add virtual dtor for TuningResultsValidator --- onnxruntime/core/framework/tuning_context.h | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h index 81e357aaecb3a..6cd61931b8aaf 100644 --- a/onnxruntime/core/framework/tuning_context.h +++ b/onnxruntime/core/framework/tuning_context.h @@ -67,6 +67,7 @@ class TuningResultsValidator { using GetValidateFuncs = std::unordered_map>; TuningResultsValidator(); + virtual ~TuningResultsValidator() = default; std::unordered_map 
GetAllValidators() const; Status ValidateAll(const std::unordered_map& to_validate) const; From 0d54fa11a1bb29097c6ad4643b41473c62c54429 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 14 Feb 2023 01:11:12 +0000 Subject: [PATCH 16/18] Move embedded tuning result loading logic into try catch scope --- onnxruntime/core/session/inference_session.cc | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cd25b631f5217..79766bb7b8c9b 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1500,6 +1500,14 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath)); } } + + std::vector tuning_results; + bool found_tuning_results = false; + ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata( + model_metadata_, tuning_results, found_tuning_results)); + if (found_tuning_results) { + ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results)); + } #endif // !defined(ORT_MINIMAL_BUILD) // Resolve memory pattern flags of the main graph and subgraph session states @@ -1550,16 +1558,6 @@ common::Status InferenceSession::Initialize() { } } -#if !defined(ORT_MINIMAL_BUILD) - std::vector tuning_results; - bool found_tuning_results = false; - ORT_RETURN_IF_ERROR(inference_session_utils::ParseTuningResultsFromModelMetadata( - model_metadata_, tuning_results, found_tuning_results)); - if (found_tuning_results) { - ORT_RETURN_IF_ERROR(SetTuningResults(tuning_results)); - } -#endif // !defined(ORT_MINIMAL_BUILD) - return status; } #if defined(_MSC_VER) && !defined(__clang__) From 9605ddb3c16399981b255706ea9172254f0b0fd1 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Wed, 15 Feb 2023 00:46:42 +0000 Subject: [PATCH 17/18] Fix typo --- 
onnxruntime/test/python/onnxruntime_test_python.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index ce57c3ac16d00..8232044a29a59 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -451,11 +451,11 @@ def doTestGetAndSetTuningResults(ep): self.assertIn("is unable to consume it", str(context.exception)) assertTuningResultsNotLoaded(sess, ep) - validation_faliure = copyTuningResultsWithProbe(tuning_results) - validation_faliure["validators"]["ORT_VERSION"] = "This is not a proper ORT_VERSION value!" - sess.set_tuning_results([validation_faliure]) + validation_failure = copyTuningResultsWithProbe(tuning_results) + validation_failure["validators"]["ORT_VERSION"] = "This is not a proper ORT_VERSION value!" + sess.set_tuning_results([validation_failure]) with self.assertRaises(RuntimeError) as context: - sess.set_tuning_results([validation_faliure], error_on_invalid=True) + sess.set_tuning_results([validation_failure], error_on_invalid=True) self.assertIn("Failed to load TuningResults", str(context.exception)) self.assertIn("version mismatch", str(context.exception)) assertTuningResultsNotLoaded(sess, ep) From 36b39229ad578eddecbf68e19438eb498ec3ce78 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Wed, 15 Feb 2023 00:57:22 +0000 Subject: [PATCH 18/18] Use ORT_HANDLE_EXCEPTION --- onnxruntime/core/session/inference_session_utils.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index 9ba6c946c6d36..3e2d03a930f31 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -247,14 +247,18 @@ Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& met 
key_found = true; LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while running the model"; + Status status; ORT_TRY { auto parsed_tuning_results_json = json::parse(it->second); results = parsed_tuning_results_json.get>(); } ORT_CATCH(const std::exception& e) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, FAIL, - "Tuning results stored in the model file cannot be parsed. Error message: ", e.what(), ". Ignoring..."); + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS( + ONNXRUNTIME, FAIL, + "Tuning results stored in the model file cannot be parsed. Error message: ", e.what(), ". Ignoring..."); + }); + ORT_RETURN_IF_ERROR(status); } return Status::OK();