Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offline tuning #14558

Merged
merged 18 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions onnxruntime/core/framework/tunable.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,20 @@ class TunableOp {
ITuningContext* ctx = params->TuningContext();
if (ctx->IsTunableOpEnabled()) {
auto& mgr = ctx->GetTuningResultsManager();
id = mgr.Lookup(Signature(), params->Signature());
auto op_sig = Signature();
auto params_sig = params->Signature();
id = mgr.Lookup(op_sig, params_sig);
if (id > static_cast<int>(ops_.size())) {
LOGS_DEFAULT(ERROR) << "Invalid TunableOp kernel id for " << op_sig
<< ", id:" << id << ", registered op:" << ops_.size();
mgr.Delete(op_sig, params_sig);
id = -1;
}
if (id < 0) {
auto maybe_proxy_params = PreTuning(params);
id = FindFastest(maybe_proxy_params);
PostTuning(maybe_proxy_params);
mgr.Add(Signature(), params->Signature(), id);
mgr.Add(op_sig, params_sig, id);
}
}
ORT_RETURN_IF_ERROR(ops_[id](params));
Expand Down
40 changes: 40 additions & 0 deletions onnxruntime/core/framework/tuning_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TuningResultsValidator;

class ITuningContext {
public:
explicit ITuningContext(IExecutionProvider* ep) : ep_(ep) {}
virtual ~ITuningContext() = default;

virtual void EnableTunableOp() = 0;
Expand All @@ -25,6 +26,14 @@ class ITuningContext {

virtual TuningResultsManager& GetTuningResultsManager() = 0;
virtual const TuningResultsManager& GetTuningResultsManager() const = 0;

virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0;

virtual TuningResults GetTuningResults() const;
virtual Status LoadTuningResults(const TuningResults& tr);

protected:
IExecutionProvider* ep_;
};

class TuningResultsManager {
Expand All @@ -36,6 +45,7 @@ class TuningResultsManager {
int Lookup(const std::string& op_signature, const std::string& params_signature) const;

void Add(const std::string& op_signature, const std::string& params_signature, int best_id);
void Delete(const std::string& op_signature, const std::string& params_signature);

void Load(const std::unordered_map<std::string, KernelMap>& results_to_load);
std::unordered_map<std::string, KernelMap> Dump() const;
Expand All @@ -50,4 +60,34 @@ class TuningResultsManager {
std::unordered_map<std::string, KernelMap> results_;
};

class TuningResultsValidator {
public:
using GetFunc = std::function<std::string()>;
using ValidateFunc = std::function<Status(const std::string&)>;
using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;

TuningResultsValidator();

std::unordered_map<std::string, std::string> GetAllValidators() const;
Status ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;

protected:
void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);

virtual std::string GetOrtVersion() const;
virtual Status ValidateOrtVersion(const std::string& value) const;

virtual std::string GetOrtGitCommit() const;
virtual Status ValidateOrtGitCommit(const std::string& value) const;

virtual std::string GetOrtBuildConfig() const;
virtual Status ValidateOrtBuildConfig(const std::string& value) const;

public:
static constexpr const std::array mandatory_keys{"ORT_VERSION", "ORT_GIT_COMMIT", "ORT_BUILD_CONFIG"};

private:
GetValidateFuncs validators_;
};

} // namespace onnxruntime
178 changes: 177 additions & 1 deletion onnxruntime/core/framework/tuning_context_impl.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file contains the implementation of TuningResultsManager. At the moment, there is no necessity to expose these
// This file contains the implementation of TuningContext. At the moment, there is no necessity to expose these
// methods as OrtApis. This will cause missing symbols when loading provider dynamic libraries, because the libraries
// are not whole-archive linked and these symbols are not referenced at framework level. To circumvent this problem,
// the EP must has and only has one translation unit include this file.
Expand All @@ -11,12 +11,32 @@

#pragma once

#include <functional>
#include <unordered_set>
#include <utility>

#include "core/framework/tunable.h"
#include "core/framework/tuning_context.h"
#include "core/framework/tuning_results.h"

namespace onnxruntime {

TuningResults ITuningContext::GetTuningResults() const {
TuningResults tr;
tr.ep = ep_->Type();
tr.validators = GetTuningResultsValidator().GetAllValidators();
tr.results = GetTuningResultsManager().Dump();
return tr;
}

Status ITuningContext::LoadTuningResults(const TuningResults& tr) {
ORT_RETURN_IF(tr.ep != ep_->Type(), "EP mismatch");
LOGS_DEFAULT(VERBOSE) << "Loading tuning results for " << tr.ep;
ORT_RETURN_IF_ERROR(GetTuningResultsValidator().ValidateAll(tr.validators));
GetTuningResultsManager().Load(tr.results);
return Status::OK();
}

KernelMap TuningResultsManager::Lookup(const std::string& op_signature) const {
std::scoped_lock l{lock_};
auto it = results_.find(op_signature);
Expand Down Expand Up @@ -56,6 +76,7 @@ inline void AddImpl(const std::string& op_signature,
return;
}

LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ") -> " << best_id;
kernel_map[params_signature] = best_id;
}

Expand All @@ -70,6 +91,24 @@ void TuningResultsManager::Add(const std::string& op_signature, const std::strin
AddImpl(op_signature, params_signature, best_id, it->second);
}

// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
std::scoped_lock l{lock_};

auto it = results_.find(op_signature);
if (it == results_.end()) {
return;
}

auto it2 = it->second.find(params_signature);
if (it2 == it->second.end()) {
return;
}

LOGS_DEFAULT(VERBOSE) << op_signature << "(" << params_signature << ")";
it->second.erase(it2);
}

std::unordered_map<std::string, KernelMap> TuningResultsManager::Dump() const {
std::scoped_lock l{lock_};
return results_;
Expand All @@ -81,6 +120,9 @@ void DisjointMergeImpl(
/*out*/ std::unordered_map<std::string, KernelMap>& results) {
auto it = results.find(op_signature);
if (it == results.end()) {
for(const auto& [param_sig, kernel_id] : kernel_map) {
LOGS_DEFAULT(VERBOSE) << op_signature << "(" << param_sig << ") -> " << kernel_id;
}
results[op_signature] = kernel_map;
return;
}
Expand All @@ -106,4 +148,138 @@ void TuningResultsManager::Clear() {
results_ = {};
}

static Status CheckMandatoryKeys(
const TuningResultsValidator::GetValidateFuncs& gv_funcs,
const std::unordered_map<std::string, std::string>& to_check) {
bool passed = true;
std::ostringstream oss;
for (const auto& k : TuningResultsValidator::mandatory_keys) {
if (gv_funcs.find(k) == gv_funcs.end()) {
passed = false;
oss << "key=\"" << k << "\" is not registered for Get and Validate. ";
}

if (to_check.find(k) == to_check.end()) {
passed = false;
oss << "key=\"" << k << "\" is not provided for validation. ";
}
}
ORT_RETURN_IF(!passed, oss.str());
return Status::OK();
}

static Status CheckKeysMatching(
const TuningResultsValidator::GetValidateFuncs& gv_funcs,
cloudhan marked this conversation as resolved.
Show resolved Hide resolved
const std::unordered_map<std::string, std::string>& to_check) {
auto get_keys = [](const auto& it) -> std::string { return it.first; };
std::vector<std::string> required_keys;
std::vector<std::string> provided_keys;
std::transform(gv_funcs.cbegin(), gv_funcs.cend(), std::back_inserter(required_keys), get_keys);
std::transform(to_check.cbegin(), to_check.cend(), std::back_inserter(provided_keys), get_keys);
std::sort(required_keys.begin(), required_keys.end());
std::sort(provided_keys.begin(), provided_keys.end());

std::unordered_set<std::string> intersection;
std::set_intersection(required_keys.cbegin(), required_keys.cend(),
provided_keys.cbegin(), provided_keys.cend(),
std::inserter(intersection, intersection.end()));
bool matched = true;
std::ostringstream oss;
if (intersection.size() != required_keys.size()) {
matched = false;
for (const auto& k : required_keys) {
if (intersection.find(k) == intersection.end()) {
oss << "Unmatched validator: \"" << k << "\" is required, but the tuning results does not provide it. ";
}
}
}
if (intersection.size() != provided_keys.size()) {
matched = false;
for (const auto& k : provided_keys) {
if (intersection.find(k) == intersection.end()) {
oss << "Unmatched validator: \"" << k << "\" is provided, but onnxruntime is unable to consume it. ";
}
}
}
ORT_RETURN_IF(!matched, oss.str());
return Status::OK();
}

std::string TuningResultsValidator::GetOrtVersion() const {
return ORT_VERSION;
}

Status TuningResultsValidator::ValidateOrtVersion(const std::string& value) const {
ORT_RETURN_IF(value != ORT_VERSION, "onnxruntime version mismatch");
return Status::OK();
}

std::string TuningResultsValidator::GetOrtGitCommit() const {
// TODO:
return "";
}

Status TuningResultsValidator::ValidateOrtGitCommit(const std::string& value) const {
// TODO:
ORT_UNUSED_PARAMETER(value);
return Status::OK();
}

std::string TuningResultsValidator::GetOrtBuildConfig() const {
return "";
abudup marked this conversation as resolved.
Show resolved Hide resolved
}

Status TuningResultsValidator::ValidateOrtBuildConfig(const std::string& value) const {
auto current = GetOrtBuildConfig();
ORT_RETURN_IF(current != value,
"onnxruntime building configuration mismatch: tuning results produced with library \"",
value, "\", current library built with \"", current, "\"");
return Status::OK();
}

TuningResultsValidator::TuningResultsValidator() {
RegisterValidator(
"ORT_VERSION",
[this]() { return GetOrtVersion(); },
[this](auto&& k) { return ValidateOrtVersion(std::forward<decltype(k)>(k)); });

RegisterValidator(
"ORT_GIT_COMMIT",
[this]() { return GetOrtGitCommit(); },
[this](auto&& k) { return ValidateOrtGitCommit(std::forward<decltype(k)>(k)); });

RegisterValidator(
"ORT_BUILD_CONFIG",
[this]() { return GetOrtBuildConfig(); },
[this](auto&& k) { return ValidateOrtBuildConfig(std::forward<decltype(k)>(k)); });
}

Status TuningResultsValidator::ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const {
ORT_RETURN_IF_ERROR(CheckMandatoryKeys(validators_, to_validate));
ORT_RETURN_IF_ERROR(CheckKeysMatching(validators_, to_validate));

for (const auto& [key, value] : to_validate) {
const auto& it = validators_.find(key);
ORT_ENFORCE(it != validators_.cend());
const ValidateFunc& validator = it->second.second;
ORT_RETURN_IF_ERROR(validator(value));
}

return Status::OK();
}

std::unordered_map<std::string, std::string> TuningResultsValidator::GetAllValidators() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TuningResultsValidator::GetAllValidators()

The naming is a bit misleading: you're not getting the validators themselves, you're only getting the values to validate. Perhaps rename to GetAllKVPairsToValidate instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is weird. It is actually the function/closure to be serialized, but it is impossible to do so in C++. So we just serialized the data the validator function needs.

Also in tuning_results.h TuningResults have a member called validators, maybe we also need a new name for it.

std::unordered_map<std::string, std::string> ret;
for (const auto& [key, get_validate_func_pair] : validators_) {
const GetFunc& getter = get_validate_func_pair.first;
ret[key] = getter();
}
return ret;
}

void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) {
ORT_ENFORCE(validators_.find(key) == validators_.end());
validators_[key] = std::make_pair(gf, vf);
}

} // namespace onnxruntime
39 changes: 38 additions & 1 deletion onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,40 @@ namespace onnxruntime {
namespace cuda {
namespace tunable {

CudaTuningContext::CudaTuningContext(CUDAExecutionProvider*, TunableOpInfo* info) : info_(info) {}
static std::string GetCudaVersion() {
int version;
CUDA_CALL_THROW(cudaRuntimeGetVersion(&version));
return std::to_string(version);
}

static Status ValidateCudaVersion(const std::string& value) {
auto current = GetCudaVersion();
ORT_RETURN_IF(current != value, "CUDA runtime version mismatch: tuning results produced with CUDA ", value,
", onnxruntime currently run with CUDA ", current);
return Status::OK();
}

std::string CudaTuningResultsValidator::GetDeviceModel() const {
return ep_->GetDeviceProp().name;
}

Status CudaTuningResultsValidator::ValidateDeviceModel(const std::string& value) const {
auto current = GetDeviceModel();
ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value,
", onnxruntime currently run with device ", current);
return Status::OK();
}

CudaTuningResultsValidator::CudaTuningResultsValidator(CUDAExecutionProvider* ep) : ep_(ep) {
RegisterValidator("CUDA_VERSION", GetCudaVersion, ValidateCudaVersion);
RegisterValidator(
"DEVICE_MODEL",
[this]() { return GetDeviceModel(); },
[this](const std::string& value) { return ValidateDeviceModel(value); });
}

CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info)
: ITuningContext(ep), info_(info), validator_(ep) {}

void CudaTuningContext::EnableTunableOp() {
LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider";
Expand All @@ -38,6 +71,10 @@ const TuningResultsManager& CudaTuningContext::GetTuningResultsManager() const {
return manager_;
}

const TuningResultsValidator& CudaTuningContext::GetTuningResultsValidator() const {
return validator_;
}

} // namespace tunable
} // namespace cuda
} // namespace onnxruntime
15 changes: 15 additions & 0 deletions onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ class CUDAExecutionProvider;
namespace cuda {
namespace tunable {

class CudaTuningResultsValidator : public TuningResultsValidator {
public:
CudaTuningResultsValidator(CUDAExecutionProvider* ep);

protected:
std::string GetDeviceModel() const;
Status ValidateDeviceModel(const std::string& value) const;

private:
CUDAExecutionProvider* ep_; // non-owning handle
};

class CudaTuningContext : public ITuningContext {
public:
explicit CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info);
Expand All @@ -26,9 +38,12 @@ class CudaTuningContext : public ITuningContext {
TuningResultsManager& GetTuningResultsManager() override;
const TuningResultsManager& GetTuningResultsManager() const override;

const TuningResultsValidator& GetTuningResultsValidator() const override;

private:
TunableOpInfo* info_; // non-owning handle
TuningResultsManager manager_;
CudaTuningResultsValidator validator_;
};

} // namespace tunable
Expand Down
Loading