Skip to content

Commit

Permalink
Fix delay load for WebGPU EP and DML EP
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Dec 15, 2024
1 parent 3a0b958 commit 7a84bf1
Show file tree
Hide file tree
Showing 14 changed files with 369 additions and 49 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ if(WIN32)
# Sources of the onnxruntime shared library for this (Windows) branch.
# NOTE(review): ${SYMBOL_FILE} is presumably the generated export-definition file - confirm where it is set.
onnxruntime_add_shared_library(onnxruntime
${SYMBOL_FILE}
# DLL entry point (DllMain).
"${ONNXRUNTIME_ROOT}/core/dll/dllmain.cc"
# Delay-load notification hook. Its contents are compiled only with MSVC and when
# USE_WEBGPU or USE_DML is defined; otherwise the translation unit is empty.
"${ONNXRUNTIME_ROOT}/core/dll/delay_load_hook.cc"
"${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc"
)
elseif(onnxruntime_BUILD_APPLE_FRAMEWORK)
Expand Down
4 changes: 4 additions & 0 deletions cmake/onnxruntime_nodejs.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,8 @@ add_custom_target(nodejs_binding_wrapper ALL
# Build-order edges only: add_dependencies() orders targets, it does not link anything.
add_dependencies(js_common_npm_ci js_npm_ci)

# The wrapper target is built after js_common_npm_ci and after the onnxruntime target.
add_dependencies(nodejs_binding_wrapper
  js_common_npm_ci
  onnxruntime
)

# For WebGPU builds on Windows the wrapper additionally waits for the
# copy_dxil_dll and dxcompiler targets.
if(WIN32 AND onnxruntime_USE_WEBGPU)
  add_dependencies(nodejs_binding_wrapper
    copy_dxil_dll
    dxcompiler
  )
endif()
endif()
12 changes: 12 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,9 @@ set (onnxruntime_global_thread_pools_test_SRC
# Source list of the WebGPU external-Dawn test executable.
set(onnxruntime_webgpu_external_dawn_test_SRC
  "${TEST_SRC_DIR}/webgpu/external_dawn/main.cc"
)

# Source list of the WebGPU delay-load test executable.
set(onnxruntime_webgpu_delay_load_test_SRC
  "${TEST_SRC_DIR}/webgpu/delay_load/main.cc"
)

# tests from lowest level library up.
# the order of libraries should be maintained, with higher libraries being added first in the list

Expand Down Expand Up @@ -1864,4 +1867,13 @@ if (onnxruntime_USE_WEBGPU AND onnxruntime_USE_EXTERNAL_DAWN)
onnxruntime_add_include_to_target(onnxruntime_webgpu_external_dawn_test dawn::dawncpp_headers dawn::dawn_headers)
endif()

# Register the WebGPU delay-load test. It is only added when WebGPU is enabled, the
# target is Windows (and not Emscripten), onnxruntime is built as a shared library,
# and the build is not a minimal build.
if(onnxruntime_USE_WEBGPU
    AND WIN32
    AND onnxruntime_BUILD_SHARED_LIB
    AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten"
    AND NOT onnxruntime_MINIMAL_BUILD)
  AddTest(
    DYN
    TARGET onnxruntime_webgpu_delay_load_test
    SOURCES ${onnxruntime_webgpu_delay_load_test_SRC}
    LIBS ${SYS_PATH_LIB}
    DEPENDS ${all_dependencies}
  )
endif()

include(onnxruntime_fuzz_test.cmake)
8 changes: 8 additions & 0 deletions js/node/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ if (WIN32)
file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/DirectML.dll
DESTINATION ${dist_folder})
endif ()
if(USE_WEBGPU)
  # Copy each WebGPU-related DLL from the onnxruntime Windows binary directory into
  # the distribution folder, one file(COPY) per DLL, in the listed order.
  foreach(webgpu_runtime_dll IN ITEMS webgpu_dawn.dll dxil.dll dxcompiler.dll)
    file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/${webgpu_runtime_dll}
         DESTINATION ${dist_folder})
  endforeach()
endif()
elseif (APPLE)
file(COPY ${ONNXRUNTIME_BUILD_DIR}/libonnxruntime.dylib
DESTINATION ${dist_folder} FOLLOW_SYMLINK_CHAIN)
Expand Down
37 changes: 0 additions & 37 deletions js/node/src/directml_load_helper.cc

This file was deleted.

6 changes: 0 additions & 6 deletions js/node/src/directml_load_helper.h

This file was deleted.

4 changes: 0 additions & 4 deletions js/node/src/inference_session_wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "onnxruntime_cxx_api.h"

#include "common.h"
#include "directml_load_helper.h"
#include "inference_session_wrap.h"
#include "run_options_helper.h"
#include "session_options_helper.h"
Expand All @@ -19,9 +18,6 @@ Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() {
}

Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
#if defined(USE_DML) && defined(_WIN32)
LoadDirectMLDll(env);
#endif
// create ONNX runtime env
Ort::InitApi();
ORT_NAPI_THROW_ERROR_IF(
Expand Down
91 changes: 91 additions & 0 deletions onnxruntime/core/dll/delay_load_hook.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// == workaround for delay loading of dependencies of onnxruntime.dll ==
//
// Problem:
//
// When onnxruntime.dll uses delay loading for its dependencies, the dependencies are loaded using LoadLibraryEx,
// which searches the directory of the process (.exe) instead of the directory of this library (onnxruntime.dll).
// This is a problem for the Node.js binding and the Python binding, because Windows will try to find the
// dependencies in the directory of node.exe or python.exe, which is not the directory of onnxruntime.dll.
//
// Solution:
//
// By using the delay load hook `__pfnDliNotifyHook2`, we can intervene in the loading procedure and load the DLL from
// an absolute path. The absolute path is constructed by appending the name of the DLL to load to the directory of
// onnxruntime.dll. This way, we can ensure that the dependencies are loaded from the same directory as onnxruntime.dll.
//
// See also:
// - https://learn.microsoft.com/en-us/cpp/build/reference/understanding-the-helper-function?view=msvc-170#structure-and-constant-definitions
// - https://learn.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#alternate-search-order-for-unpackaged-apps
//
// The DLL DelayLoad hook is only enabled when:
// - The compiler is MSVC
// - at least one of USE_WEBGPU or USE_DML is defined
//
#if defined(_MSC_VER) && (defined(USE_WEBGPU) || defined(USE_DML))

#include <Windows.h>
#include <delayimp.h>
#include <stdlib.h>
#include <string>

namespace {

// Expands a bare DLL base name into the initializer of one `known_dlls` entry:
// the narrow literal "<name>.dll" followed by the wide literal L"<name>.dll".
#define DEFINE_KNOWN_DLL(name) {#name ".dll", L#name L".dll"}

// Table of the DLLs whose delay load is redirected by the hook in this file.
// Which entries are present depends on the USE_WEBGPU / USE_DML definitions; the enclosing
// #if guarantees that at least one of them is defined, so the array is never empty.
constexpr struct {
// Narrow file name; compared (case-insensitively) with the DLL name passed to the hook.
const char* str;
// Wide file name; appended to the directory of this module to build the path that is loaded.
const wchar_t* wstr;
} known_dlls[] = {
#if defined(USE_WEBGPU)
DEFINE_KNOWN_DLL(webgpu_dawn),
#endif
#if defined(USE_DML)
DEFINE_KNOWN_DLL(DirectML),
#endif
};
} // namespace

FARPROC WINAPI delay_load_hook(unsigned dliNotify, PDelayLoadInfo pdli) {
if (dliNotify == dliNotePreLoadLibrary) {
for (size_t i = 0; i < _countof(known_dlls); ++i) {
if (_stricmp(pdli->szDll, known_dlls[i].str) == 0) {
// Try to load the DLL from the same directory as onnxruntime.dll

// First, get the path to onnxruntime.dll
DWORD pathLen = MAX_PATH;
std::wstring path(pathLen, L'\0');
HMODULE moduleHandle = nullptr;

GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
reinterpret_cast<LPCWSTR>(&delay_load_hook), &moduleHandle);

DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) {
int ret = GetLastError();
if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) {
pathLen *= 2;
path.resize(pathLen);
getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
} else {
// Failed to get the path to onnxruntime.dll. In this case, we will just return NULL and let the system
// search for the DLL in the default search order.
return NULL;
}
}

path.resize(path.rfind(L'\\') + 1);
path.append(known_dlls[i].wstr);

return FARPROC(LoadLibraryExW(path.c_str(), NULL, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR));
}
}
}
return NULL;
}

extern "C" const PfnDliHook __pfnDliNotifyHook2 = delay_load_hook;

#endif
2 changes: 1 addition & 1 deletion onnxruntime/core/dll/dllmain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#pragma GCC diagnostic pop
#endif

// dllmain.cpp : Defines the entry point for the DLL application.
// dllmain.cc : Defines the entry point for the DLL application.
BOOL APIENTRY DllMain(HMODULE /*hModule*/,
DWORD ul_reason_for_call,
LPVOID /*lpReserved*/
Expand Down
87 changes: 87 additions & 0 deletions onnxruntime/core/providers/webgpu/dll_delay_load_helper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/dll_delay_load_helper.h"

#if defined(_WIN32) && defined(_MSC_VER) && !defined(__EMSCRIPTEN__)

#include <Windows.h>
#include <delayimp.h>
#include <stdlib.h>
#include <string>
#include <mutex>

namespace onnxruntime {
namespace webgpu {

namespace {

// Get the directory of the current DLL (usually it's onnxruntime.dll).
std::wstring GetCurrentDllDir() {
DWORD pathLen = MAX_PATH;
std::wstring path(pathLen, L'\0');
HMODULE moduleHandle = nullptr;

GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
reinterpret_cast<LPCWSTR>(&GetCurrentDllDir), &moduleHandle);

DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) {
int ret = GetLastError();
if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) {
pathLen *= 2;
path.resize(pathLen);
getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
} else {
// Failed to get the path to onnxruntime.dll. Returns an empty string.
return std::wstring{};
}
}
path.resize(path.rfind(L'\\') + 1);
return path;
}

std::once_flag run_once_before_load_deps_mutex;
std::once_flag run_once_after_load_deps_mutex;
bool dll_dir_set = false;

} // namespace

DllDelayLoadHelper::DllDelayLoadHelper() {
// Setup DLL search directory
std::call_once(run_once_before_load_deps_mutex, []() {
std::wstring path = GetCurrentDllDir();
if (!path.empty()) {
SetDllDirectoryW(path.c_str());
dll_dir_set = true;
}
});
}

DllDelayLoadHelper::~DllDelayLoadHelper() {
// Restore DLL search directory
std::call_once(run_once_after_load_deps_mutex, []() {
if (dll_dir_set) {
SetDllDirectoryW(NULL);
}
});
}

} // namespace webgpu
} // namespace onnxruntime

#else // defined(_WIN32) && defined(_MSC_VER) && !defined(__EMSCRIPTEN__)

namespace onnxruntime {
namespace webgpu {

// In this configuration the helper has nothing to set up or tear down:
// construction and destruction are both no-ops.
DllDelayLoadHelper::DllDelayLoadHelper() = default;

DllDelayLoadHelper::~DllDelayLoadHelper() = default;

}  // namespace webgpu
}  // namespace onnxruntime

#endif
20 changes: 20 additions & 0 deletions onnxruntime/core/providers/webgpu/dll_delay_load_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace onnxruntime {
namespace webgpu {

// The DLL delay load helper is a RAII style guard to ensure DLL loading is done correctly.
//
// - On Windows (MSVC, non-Emscripten builds), the helper sets the DLL search path to the
//   directory of the current DLL. The change is applied by the first instance constructed in
//   the process and undone by the first instance destroyed; later instances have no effect.
// - On other platforms, the helper does nothing.
//
// Usage: declare an instance on the stack around the call that triggers the delay load.
//
struct DllDelayLoadHelper final {
// Applies the DLL search path change (Windows only).
DllDelayLoadHelper();
// Undoes the DLL search path change (Windows only).
~DllDelayLoadHelper();
};

} // namespace webgpu
} // namespace onnxruntime
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "core/providers/webgpu/program_cache_key.h"
#include "core/providers/webgpu/program_manager.h"
#include "core/providers/webgpu/string_macros.h"
#include "core/providers/webgpu/dll_delay_load_helper.h"

namespace onnxruntime {
namespace webgpu {
Expand Down Expand Up @@ -50,6 +51,10 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info

// Initialization.Step.2 - Create wgpu::Adapter
if (adapter_ == nullptr) {
// DLL delay loading happens inside wgpuRequestAdapter().
// Use this helper as RAII to ensure the DLL search path is set correctly.
DllDelayLoadHelper helper{};

wgpu::RequestAdapterOptions req_adapter_options = {};
wgpu::DawnTogglesDescriptor adapter_toggles_desc = {};
req_adapter_options.nextInChain = &adapter_toggles_desc;
Expand Down
Loading

0 comments on commit 7a84bf1

Please sign in to comment.