Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add QNN EP HTP shared memory allocator #23136

Draft
wants to merge 35 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
110a3bc
save work
edgchen1 Nov 5, 2024
0ba3a2f
save work
edgchen1 Nov 9, 2024
8436b14
add logging for setting QNN tensor memory, update comment
edgchen1 Nov 11, 2024
c9826f4
add option to enable HTP shared memory allocator to onnxruntime_perf_…
edgchen1 Nov 11, 2024
c07c35e
hack - try to cache mem handles in QnnModel
edgchen1 Nov 12, 2024
60dc837
Remove duplicate include.
edgchen1 Nov 13, 2024
24e072f
hack, continued - move cache out to SharedContext
edgchen1 Nov 14, 2024
e66cbef
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Nov 14, 2024
8c515da
move mem handle registration to allocator
edgchen1 Nov 15, 2024
18e2780
hook up some test code
edgchen1 Nov 15, 2024
09ddce5
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Nov 19, 2024
a65bb71
rename to RpcMemAllocator to HtpSharedMemoryAllocator
edgchen1 Nov 27, 2024
bfb135e
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Dec 2, 2024
f179a0d
remove onnx protobuf dependency from allocator.h, add shared provider…
edgchen1 Dec 3, 2024
7645ef4
remove unused CPUAllocator::TensorAlloc declaration
edgchen1 Dec 5, 2024
1043732
Check for nullptr when trying to free
baijumeswani Dec 5, 2024
022f4bc
move mem handle management to QNN backend manager
edgchen1 Dec 10, 2024
c527dee
remove IAllocator::TensorAlloc()
edgchen1 Dec 10, 2024
e4f72b3
document IAllocator::Free
edgchen1 Dec 10, 2024
39ff901
remove IAllocator__TensorAlloc
edgchen1 Dec 10, 2024
1bed5a4
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Dec 10, 2024
d70db84
fix android build warning
edgchen1 Dec 10, 2024
45ef883
remove shared mem handles from shared context
edgchen1 Dec 11, 2024
d2e7b3c
remove allocation clean up callback removal, use weak_ptrs in allocat…
edgchen1 Dec 16, 2024
c892c18
some clean up
edgchen1 Dec 17, 2024
b295eef
more clean up
edgchen1 Dec 17, 2024
13f5e30
add helper to get qnn error message
edgchen1 Dec 17, 2024
d5eace1
use make_shared for QnnBackendManager
edgchen1 Dec 17, 2024
bacbcdc
add test to qnn_basic_test.cc, document allocator parameter.
edgchen1 Dec 17, 2024
30cd9ed
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Dec 17, 2024
b29ab61
rename variables
edgchen1 Dec 18, 2024
67a54b8
revert changes to onnxruntime/test/providers/qnn/max_min_op_test.cc
edgchen1 Dec 18, 2024
c0569e2
fix formatting
edgchen1 Dec 19, 2024
dd45c84
skip test if not android and not windows
edgchen1 Dec 19, 2024
959d8df
update comment
edgchen1 Dec 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ constexpr const char* OpenVINO_CPU = "OpenVINO_CPU";
constexpr const char* OpenVINO_GPU = "OpenVINO_GPU";
constexpr const char* OpenVINO_RT = "OpenVINO_RT";
constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU";
constexpr const char* QNN_HTP_SHARED = "QnnHtpShared";
constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer";
constexpr const char* WEBNN_TENSOR = "WebNN_Tensor";

Expand Down Expand Up @@ -81,6 +82,10 @@ class IAllocator {
*/
virtual void* Alloc(size_t size) = 0;

/**
* Free memory at p.
* If p is nullptr, do nothing.
*/
virtual void Free(void* p) = 0;

// Reserve() is an interface exposed for an implementation of IAllocator
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/ortdevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ struct OrtDevice {
static const MemoryType CUDA_PINNED = 1;
static const MemoryType HIP_PINNED = 2;
static const MemoryType CANN_PINNED = 3;
static const MemoryType QNN_HTP_SHARED = 4;
};

constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/ortmemoryinfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <string_view>

#include "core/common/hash_combine.h"
#include "core/framework/ortdevice.h"

struct OrtMemoryInfo {
OrtMemoryInfo() = default; // to allow default construction of Tensor
Expand Down
4 changes: 2 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2130,10 +2130,10 @@ struct KernelContext {
explicit KernelContext(OrtKernelContext* context);
size_t GetInputCount() const;
size_t GetOutputCount() const;
// If input is optional and is not present, the method returns en empty ConstValue
// If input is optional and is not present, the method returns an empty ConstValue
// which can be compared to nullptr.
ConstValue GetInput(size_t index) const;
// If outout is optional and is not present, the method returns en empty UnownedValue
// If outout is optional and is not present, the method returns an empty UnownedValue
// which can be compared to nullptr.
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
Expand Down
11 changes: 9 additions & 2 deletions onnxruntime/core/framework/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,18 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) {
*out = new OrtMemoryInfo(
onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
onnxruntime::CUDA_PINNED, type,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
id1, mem_type1);
} else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) {
*out = new OrtMemoryInfo(
onnxruntime::HIP_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
onnxruntime::HIP_PINNED, type,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
id1, mem_type1);
} else if (strcmp(name1, onnxruntime::QNN_HTP_SHARED) == 0) {
*out = new OrtMemoryInfo(
onnxruntime::QNN_HTP_SHARED, type,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, static_cast<OrtDevice::DeviceId>(id1)),
id1, mem_type1);
} else {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported.");
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ SessionState::SessionState(Graph& graph,
for (auto& ep : execution_providers_) {
auto allocators = ep->CreatePreferredAllocators();
for (auto& alloc : allocators) {
allocators_->insert({alloc->Info().device, alloc}); // DONT overwrite existing key
allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key
}
}
}
Expand Down
79 changes: 69 additions & 10 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,22 @@
#include <fstream>
#include <string>
#include "QnnOpDef.h"
#include "HTP/QnnHtpPerfInfrastructure.h"
#include "HTP/QnnHtpSystemContext.h"
#include "CPU/QnnCpuCommon.h"
// TODO: not exist for Windows yet
// #include "GPU/QnnGpuCommon.h"
#include "DSP/QnnDspCommon.h"
#include "HTP/QnnHtpCommon.h"
#include "HTP/QnnHtpContext.h"
#include "HTP/QnnHtpPerfInfrastructure.h"
#include "HTP/QnnHtpSystemContext.h"
#include "Saver/QnnSaver.h"
#include <gsl/gsl>
#include "core/framework/endian_utils.h"
#include "core/common/logging/capture.h"
#include "core/providers/qnn/qnn_allocator.h"
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
#include "core/providers/qnn/builder/qnn_configs_helper.h"
#include "core/providers/qnn/builder/qnn_utils.h"

#ifdef _WIN32
#include <winmeta.h>
Expand Down Expand Up @@ -550,10 +552,11 @@
device_handle_,
context_configs,
&context);
contexts_.push_back(context);

ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result));

ORT_RETURN_IF_ERROR(AddQnnContext(context));

context_created_ = true;
return Status::OK();
}
Expand All @@ -563,6 +566,9 @@
return Status::OK();
}

// release context mem handles
context_mem_handles_.clear();

bool failed = false;
for (auto context : contexts_) {
Qnn_ErrorHandle_t result = qnn_interface_.contextFree(context, nullptr);
Expand Down Expand Up @@ -771,7 +777,7 @@
&context,
profile_backend_handle_);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
contexts_.push_back(context);
ORT_RETURN_IF_ERROR(AddQnnContext(context));
if (1 == graph_count) {
// in case the EPContext node is generated from script
// the graph name from the context binary may not match the EPContext node name
Expand Down Expand Up @@ -1413,12 +1419,7 @@
}

const char* QnnBackendManager::QnnErrorHandleToString(Qnn_ErrorHandle_t error) {
// From QNN SDK: The memory is statically owned and should not be freed by the caller.
const char* error_msg = nullptr;
if (QNN_SUCCESS == qnn_interface_.errorGetMessage(error, &error_msg)) {
return error_msg;
}
return "Unknown";
return utils::GetQnnErrorMessage(qnn_interface_, error);
}

const std::string QnnBackendManager::ExtractQnnScalarValue(const Qnn_Scalar_t& scalar) {
Expand Down Expand Up @@ -1651,5 +1652,63 @@
#endif
}

Status QnnBackendManager::AddQnnContext(Qnn_ContextHandle_t context) {
ORT_RETURN_IF(logger_ == nullptr, "logger_ should be set.");

auto mem_handle_manager = std::make_shared<QnnContextMemHandleManager>(GetQnnInterface(), context, *logger_);

Check warning on line 1658 in onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_shared<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc:1658: Add #include <memory> for make_shared<> [build/include_what_you_use] [4]
const bool inserted = context_mem_handles_.try_emplace(context, std::move(mem_handle_manager)).second;
ORT_RETURN_IF_NOT(inserted, "QNN context was already added: ", context);

contexts_.push_back(context);

return Status::OK();
}

Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address,
const Qnn_Tensor_t& qnn_tensor,
Qnn_MemHandle_t& mem_handle) {
const auto context_mem_handles_it = context_mem_handles_.find(context);
ORT_RETURN_IF_NOT(context_mem_handles_it != context_mem_handles_.end(), "QNN context not found: ", context);

auto& context_mem_handle_manager = context_mem_handles_it->second;
bool did_register{};
ORT_RETURN_IF_ERROR(context_mem_handle_manager->GetOrRegister(shared_memory_address, qnn_tensor,
mem_handle, did_register));

if (did_register) {
HtpSharedMemoryAllocator::AllocationCleanUpFn allocation_clean_up =
[&logger = *logger_,
weak_backend_manager = weak_from_this(),
weak_context_mem_handle_manager = std::weak_ptr{context_mem_handle_manager}](
void* shared_memory_address) {
// get QnnBackendManager shared_ptr to ensure that qnn_interface is still valid
auto backend_manager = weak_backend_manager.lock();
if (!backend_manager) {
return;
}

auto context_mem_handle_manager = weak_context_mem_handle_manager.lock();
if (!context_mem_handle_manager) {
return;
}

// TODO should also ensure that the QNN context handle is still valid.

Check warning on line 1695 in onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc:1695: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// This *should* be true as long as the QNN contexts are not freed from anywhere other than
// ~QnnBackendManager(). If we are able to lock weak_backend_manager, we haven't gotten to the dtor yet.

auto unregister_status = context_mem_handle_manager->Unregister(shared_memory_address);
if (!unregister_status.IsOK()) {
LOGS(logger, ERROR) << "Failed to unregister shared memory mem handle for address: "
<< shared_memory_address << ", error: " << unregister_status.ErrorMessage();
}
};

ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::AddAllocationCleanUp(shared_memory_address,
std::move(allocation_clean_up)));

Check warning on line 1707 in onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc:1707: Add #include <utility> for move [build/include_what_you_use] [4]
}

return Status::OK();
}

} // namespace qnn
} // namespace onnxruntime
13 changes: 12 additions & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@
#include "core/common/status.h"
#include "core/common/logging/logging.h"
#include "core/common/path_string.h"
#include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h"
#include "core/providers/qnn/builder/qnn_def.h"

namespace onnxruntime {
namespace qnn {

class QnnModel;

class QnnBackendManager {
class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager> {
public:
QnnBackendManager(std::string&& backend_path,
ProfilingLevel profiling_level_etw,
Expand Down Expand Up @@ -170,6 +171,10 @@ class QnnBackendManager {
uint64_t buffer_length,
uint64_t& max_spill_fill_buffer_size);

Status GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address,
const Qnn_Tensor_t& qnn_tensor,
Qnn_MemHandle_t& mem_handle);

private:
void* LoadLib(const char* file_name, int flags, std::string& error_msg);

Expand Down Expand Up @@ -240,6 +245,9 @@ class QnnBackendManager {
const char* eventIdentifier);
#endif

Status AddQnnContext(Qnn_ContextHandle_t context);
Status ReleaseQnnContextMemHandles();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete old declaration


private:
const std::string backend_path_;
std::mutex logger_mutex_;
Expand All @@ -253,6 +261,9 @@ class QnnBackendManager {
Qnn_LogHandle_t log_handle_ = nullptr;
Qnn_DeviceHandle_t device_handle_ = nullptr;
std::vector<Qnn_ContextHandle_t> contexts_;
// Note: Using shared_ptr<QnnContextMemHandleManager> so that we can refer to it with a weak_ptr from a
// HtpSharedMemoryAllocator allocation cleanup callback.
std::unordered_map<Qnn_ContextHandle_t, std::shared_ptr<QnnContextMemHandleManager>> context_mem_handles_;
ProfilingLevel profiling_level_etw_;
ProfilingLevel profiling_level_;
ProfilingLevel profiling_level_merge_;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h"

#include "HTP/QnnHtpMem.h"

#include "core/common/common.h"
#include "core/providers/qnn/builder/qnn_def.h"
#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/qnn/qnn_allocator.h"

namespace onnxruntime::qnn {

QnnContextMemHandleManager::QnnContextMemHandleManager(const QNN_INTERFACE_VER_TYPE& qnn_interface,
Qnn_ContextHandle_t context,
const logging::Logger& logger)
: qnn_interface_{qnn_interface},
context_{context},
logger_{logger} {
}

QnnContextMemHandleManager::~QnnContextMemHandleManager() {
Clear();
}

Status QnnContextMemHandleManager::GetOrRegister(void* shared_memory_address, const Qnn_Tensor_t& qnn_tensor,
Qnn_MemHandle_t& qnn_mem_handle, bool& did_register) {
const auto qnn_tensor_rank = GetQnnTensorRank(qnn_tensor);
auto* const qnn_tensor_dims = GetQnnTensorDims(qnn_tensor);
const auto qnn_tensor_data_type = GetQnnTensorDataType(qnn_tensor);

const size_t qnn_tensor_data_size =
utils::GetQnnTensorDataSize(gsl::span{qnn_tensor_dims, size_t{qnn_tensor_rank}}, qnn_tensor_data_type);

{
std::scoped_lock g{mem_handles_mutex_};

// find existing mem handle
if (const auto mem_handles_it = mem_handles_.find(shared_memory_address);
mem_handles_it != mem_handles_.end()) {
const auto& mem_handle_record = mem_handles_it->second;

// check that actual tensor size is less than or equal to registered tensor size
ORT_RETURN_IF_NOT(qnn_tensor_data_size <= mem_handle_record.registered_tensor_data_size,
"Actual tensor data size (", qnn_tensor_data_size,
") is larger than registered tensor data size (", mem_handle_record.registered_tensor_data_size,
").");

qnn_mem_handle = mem_handle_record.mem_handle.get();
did_register = false;
return Status::OK();
}

// register a new mem handle
HtpSharedMemoryAllocator::SharedMemoryInfo shared_memory_info{};
ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfo(shared_memory_address,
shared_memory_info));

Qnn_MemDescriptor_t mem_descriptor{};
mem_descriptor.memShape.dimSize = qnn_tensor_dims;
mem_descriptor.memShape.numDim = qnn_tensor_rank;
mem_descriptor.memShape.shapeConfig = nullptr;
mem_descriptor.dataType = qnn_tensor_data_type;
mem_descriptor.memType = QNN_MEM_TYPE_CUSTOM;

QnnMemHtp_Descriptor_t htp_mem_descriptor{};
htp_mem_descriptor.type = QNN_HTP_MEM_SHARED_BUFFER;
htp_mem_descriptor.size = shared_memory_info.total_size;
htp_mem_descriptor.sharedBufferConfig.fd = shared_memory_info.fd;
htp_mem_descriptor.sharedBufferConfig.offset = shared_memory_info.offset;

mem_descriptor.customInfo = &htp_mem_descriptor;

LOGS(logger_, VERBOSE) << "Registering QNN mem handle for context: " << context_
<< ", shared memory (address: " << shared_memory_address
<< ", offset: " << shared_memory_info.offset
<< ", fd: " << shared_memory_info.fd
<< ")";

Qnn_MemHandle_t raw_mem_handle{};
const auto register_result = qnn_interface_.memRegister(context_, &mem_descriptor, 1, &raw_mem_handle);
ORT_RETURN_IF_NOT(register_result == QNN_SUCCESS,
"qnn_interface.memRegister() failed: ",
utils::GetVerboseQnnErrorMessage(qnn_interface_, register_result));

LOGS(logger_, VERBOSE) << "Registered QNN mem handle. mem_handle: " << raw_mem_handle;

const auto unregister_mem_handle = [this](Qnn_MemHandle_t raw_mem_handle) {
LOGS(logger_, VERBOSE) << "Unregistering QNN mem handle. mem_handle: " << raw_mem_handle;

const auto unregister_result = qnn_interface_.memDeRegister(&raw_mem_handle, 1);
if (unregister_result != QNN_SUCCESS) {
LOGS(logger_, ERROR) << "qnn_interface.memDeRegister() failed: "
<< utils::GetVerboseQnnErrorMessage(qnn_interface_, unregister_result);
}
};

UniqueQnnMemHandle mem_handle(raw_mem_handle, unregister_mem_handle);
MemHandleRecord mem_handle_record{qnn_tensor_data_size, std::move(mem_handle)};
mem_handles_.emplace(shared_memory_address, std::move(mem_handle_record));

Check warning on line 101 in onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc:101: Add #include <utility> for move [build/include_what_you_use] [4]

qnn_mem_handle = raw_mem_handle;
did_register = true;
return Status::OK();
}
}

Status QnnContextMemHandleManager::Unregister(void* shared_memory_address) {
std::scoped_lock g{mem_handles_mutex_};

auto mem_handles_it = mem_handles_.find(shared_memory_address);
ORT_RETURN_IF_NOT(mem_handles_it != mem_handles_.end(),
"No mem handle found for address (", shared_memory_address, ").");

mem_handles_.erase(mem_handles_it);

return Status::OK();
}

void QnnContextMemHandleManager::Clear() {
std::scoped_lock g{mem_handles_mutex_};
mem_handles_.clear();
}

} // namespace onnxruntime::qnn
Loading
Loading