Multi-Lora support #22046

Merged · 94 commits · Sep 30, 2024
629071b
Add Lora Parameters schema and script
yuslepukhin Aug 28, 2024
7837eea
Add onnxruntime_lora static lib
yuslepukhin Aug 30, 2024
00fb337
Define and expose C API stubs
yuslepukhin Aug 30, 2024
8d070c9
Add loading
yuslepukhin Aug 30, 2024
2c29f64
Implement LoraAdapter and public APIs
yuslepukhin Sep 3, 2024
aec2345
Move Release to create
yuslepukhin Sep 3, 2024
7bf148d
Implement unit test
yuslepukhin Sep 4, 2024
9a8f458
Add test data creation code
yuslepukhin Sep 4, 2024
d0d71c4
Use inlined vector
yuslepukhin Sep 4, 2024
63a5109
Add vector forced alignment
yuslepukhin Sep 4, 2024
58a8c3e
Add Load
yuslepukhin Sep 5, 2024
3e969b3
Make test in memory
yuslepukhin Sep 5, 2024
2c8b5e2
Fix name moving
yuslepukhin Sep 5, 2024
a249784
Add OrtAllocator parameter
yuslepukhin Sep 5, 2024
c825042
Make Run() calls Lora aware
yuslepukhin Sep 5, 2024
fd02453
Add format builder
yuslepukhin Sep 5, 2024
e06bfb9
Start Python impl
yuslepukhin Sep 9, 2024
a662452
Add Lora Parameters schema and script
yuslepukhin Aug 28, 2024
889d7ca
Add onnxruntime_lora static lib
yuslepukhin Aug 30, 2024
ee3fcbc
Define and expose C API stubs
yuslepukhin Aug 30, 2024
4410f85
Add loading
yuslepukhin Aug 30, 2024
300c982
Implement LoraAdapter and public APIs
yuslepukhin Sep 3, 2024
8cc1d9d
Move Release to create
yuslepukhin Sep 3, 2024
79e62bc
Implement unit test
yuslepukhin Sep 4, 2024
bc86fab
Add test data creation code
yuslepukhin Sep 4, 2024
aab98b0
Use inlined vector
yuslepukhin Sep 4, 2024
138ab0d
Add vector forced alignment
yuslepukhin Sep 4, 2024
e86bd0d
Add Load
yuslepukhin Sep 5, 2024
1e47b50
Make test in memory
yuslepukhin Sep 5, 2024
a7c0ddb
Fix name moving
yuslepukhin Sep 5, 2024
321a92f
Add OrtAllocator parameter
yuslepukhin Sep 5, 2024
436d61c
Make Run() calls Lora aware
yuslepukhin Sep 5, 2024
1584a1c
Add format builder
yuslepukhin Sep 5, 2024
6ed315d
Start Python impl
yuslepukhin Sep 9, 2024
96fddbe
Add Python layer
yuslepukhin Sep 10, 2024
d58c450
Rename namespace to onnxruntime.adapters
yuslepukhin Sep 10, 2024
59805ac
Expose python level Adapter class
yuslepukhin Sep 10, 2024
b5f3633
Lint
yuslepukhin Sep 10, 2024
5c2e3b4
Add rudimentary export/read test
yuslepukhin Sep 10, 2024
135e52c
Update format signature
yuslepukhin Sep 10, 2024
3f85bdb
Add and test ort_value_from_bytes
yuslepukhin Sep 11, 2024
afbe6fa
AdapterFormat tests now pass
yuslepukhin Sep 12, 2024
1c3841f
Implement py::LoraAdapter, RunOptions and adjust run()
yuslepukhin Sep 12, 2024
8871672
Address build issues
yuslepukhin Sep 12, 2024
7c9bdba
Add conversion script. Add test model and adapter
yuslepukhin Sep 12, 2024
6447f09
Rm sample
yuslepukhin Sep 12, 2024
fae529a
Rm sample adapter
yuslepukhin Sep 12, 2024
7d337a9
Add API test and fix a bug
yuslepukhin Sep 13, 2024
6024acc
Switch to memory map for CreateLoraAdapter
yuslepukhin Sep 13, 2024
c4c916b
Remove redundant import
yuslepukhin Sep 13, 2024
294a895
Merge branch 'yuslepukhin/multi_lora' of https://github.com/microsoft…
yuslepukhin Sep 13, 2024
d5dedf1
Add session dep to lora, remove debug code
yuslepukhin Sep 13, 2024
14ba862
Applied lint
yuslepukhin Sep 13, 2024
71b3bbe
Fix stray windows specific declarations
yuslepukhin Sep 13, 2024
9cf6649
Remove old format files
yuslepukhin Sep 13, 2024
115a231
Merge branch 'main' into yuslepukhin/multi_lora
yuslepukhin Sep 13, 2024
30f3e63
Address build issues
yuslepukhin Sep 13, 2024
583f976
Adjust linkage, fix build
yuslepukhin Sep 13, 2024
b09fab8
Merge branch 'main' into yuslepukhin/multi_lora
yuslepukhin Sep 16, 2024
b670bb9
Implement CUDA device parameter copies
yuslepukhin Sep 17, 2024
3c1be76
Merge branch 'main' into yuslepukhin/multi_lora
yuslepukhin Sep 20, 2024
3d48a6d
Address review comments
yuslepukhin Sep 20, 2024
0d67e2f
Add adapter test
yuslepukhin Sep 20, 2024
47d4bf2
Merge branch 'main' into yuslepukhin/multi_lora
yuslepukhin Sep 21, 2024
827b381
Add base model tests
yuslepukhin Sep 21, 2024
8f51253
Merge branch 'yuslepukhin/multi_lora' of https://github.com/microsoft…
yuslepukhin Sep 21, 2024
6230f62
Lint and fix up test model path
yuslepukhin Sep 22, 2024
4b90c87
Fix CPU only bug
yuslepukhin Sep 22, 2024
38b3132
Re-work ifdefs to avoid unreachable code warning
yuslepukhin Sep 23, 2024
bbdc9ae
Add check for CPU destination
yuslepukhin Sep 23, 2024
d6594c7
Address a regression
yuslepukhin Sep 23, 2024
35c0a59
Fix build warning
yuslepukhin Sep 23, 2024
717fdd5
Merge branch 'main' into yuslepukhin/multi_lora
yuslepukhin Sep 23, 2024
c0bdb9f
Merge branch 'yuslepukhin/multi_lora' of https://github.com/microsoft…
yuslepukhin Sep 23, 2024
2721ab5
Merge branch 'main' into yuslepukhin/multi_lora
yuslepukhin Sep 24, 2024
34bb3a0
Remove temp CopyOnDevice
yuslepukhin Sep 24, 2024
0b569f6
Remove stray include
yuslepukhin Sep 24, 2024
2e8f7dd
Move lora_adapters files to core/session
yuslepukhin Sep 24, 2024
2daa850
Rework copy on device
yuslepukhin Sep 25, 2024
90b0197
Address review comments
yuslepukhin Sep 25, 2024
43bf431
Add pybind registration at training for Lora and remove session lora …
yuslepukhin Sep 25, 2024
d3ed0f5
Restore linkage to lora
yuslepukhin Sep 25, 2024
8076dba
Restore training linkage to lora
yuslepukhin Sep 25, 2024
9727cbe
Avoid ORT_RETURN_IF_ERROR, fails on DNNL
yuslepukhin Sep 26, 2024
f6a8404
Address review comments
yuslepukhin Sep 26, 2024
56c7e27
Add CreateLoraAdapterFromArray public API
yuslepukhin Sep 27, 2024
dd77eca
Do not ignore status
yuslepukhin Sep 27, 2024
03a27cf
Fix SAL annotation
yuslepukhin Sep 27, 2024
e122456
Address review comments, adjust tests, enforce absolute path for adap…
yuslepukhin Sep 28, 2024
d6e71b6
Fix test path to absolute
yuslepukhin Sep 28, 2024
cfd2714
Replace absolute() with current_path() + relative
yuslepukhin Sep 28, 2024
b34f75f
remove requirements for absolute path
yuslepukhin Sep 29, 2024
b1120e7
Remove proto duplicate
yuslepukhin Sep 30, 2024
6734ef3
Merge branch 'main' into yuslepukhin/multi_lora
yuslepukhin Sep 30, 2024
2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
@@ -1670,7 +1670,7 @@ endif()

#Now the 'onnxruntime_EXTERNAL_LIBRARIES' variable should be sealed. It will be used in onnxruntime.cmake which will be included in the next.
#The order of the following targets matters. Right depends on left. If target A appears before target B. Then A.cmake can not use variables defined in B.cmake.
set(ONNXRUNTIME_CMAKE_FILES onnxruntime_flatbuffers onnxruntime_common onnxruntime_mlas onnxruntime_graph onnxruntime_framework onnxruntime_util onnxruntime_providers onnxruntime_optimizer onnxruntime_session ${ONNXRUNTIME_EAGER_CMAKE_FILE_NAME})
set(ONNXRUNTIME_CMAKE_FILES onnxruntime_flatbuffers onnxruntime_common onnxruntime_mlas onnxruntime_graph onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_providers onnxruntime_optimizer onnxruntime_session ${ONNXRUNTIME_EAGER_CMAKE_FILE_NAME})

if (onnxruntime_USE_WINML)
# WINML uses and depends on the shared lib. Note: You can build WINML without DML and you will get a
1 change: 1 addition & 0 deletions cmake/onnxruntime.cmake
@@ -207,6 +207,7 @@ set(onnxruntime_INTERNAL_LIBRARIES
onnxruntime_optimizer
onnxruntime_providers
${onnxruntime_tvm_libs}
onnxruntime_lora
onnxruntime_framework
onnxruntime_graph
onnxruntime_util
30 changes: 30 additions & 0 deletions cmake/onnxruntime_lora.cmake
@@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

file(GLOB onnxruntime_lora_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/lora_format/*.h"
"${ONNXRUNTIME_ROOT}/lora/*.h"
"${ONNXRUNTIME_ROOT}/lora/*.cc"
)

source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_lora_srcs})

onnxruntime_add_static_library(onnxruntime_lora ${onnxruntime_lora_srcs})
onnxruntime_add_include_to_target(onnxruntime_lora onnx flatbuffers::flatbuffers Boost::mp11 ${GSL_TARGET})
target_link_libraries(onnxruntime_lora onnxruntime_framework)

if(onnxruntime_ENABLE_INSTRUMENT)
target_compile_definitions(onnxruntime_lora PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
endif()

target_include_directories(onnxruntime_lora PRIVATE ${ONNXRUNTIME_ROOT})
add_dependencies(onnxruntime_lora ${onnxruntime_EXTERNAL_DEPENDENCIES})
set_target_properties(onnxruntime_lora PROPERTIES FOLDER "ONNXRuntime")

if (NOT onnxruntime_BUILD_SHARED_LIB)
install(TARGETS onnxruntime_lora
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()
5 changes: 2 additions & 3 deletions cmake/onnxruntime_python.cmake
@@ -71,9 +71,7 @@ onnxruntime_add_shared_library_module(onnxruntime_pybind11_state ${onnxruntime_p

if(MSVC)
target_compile_options(onnxruntime_pybind11_state PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>" "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
if(onnxruntime_ENABLE_TRAINING)
target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj")
endif()
target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj")
endif()
if(HAS_CAST_FUNCTION_TYPE)
target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-cast-function-type")
@@ -186,6 +184,7 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE
onnxruntime_providers
onnxruntime_util
${onnxruntime_tvm_libs}
onnxruntime_lora
onnxruntime_framework
onnxruntime_util
onnxruntime_graph
3 changes: 2 additions & 1 deletion cmake/onnxruntime_session.cmake
@@ -30,7 +30,8 @@ endif()
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_session_srcs})

onnxruntime_add_static_library(onnxruntime_session ${onnxruntime_session_srcs})
onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface nlohmann_json::nlohmann_json)
onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxruntime_framework onnxruntime_lora onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface nlohmann_json::nlohmann_json)
target_link_libraries(onnxruntime_session PRIVATE onnxruntime_lora)
if(onnxruntime_ENABLE_INSTRUMENT)
target_compile_definitions(onnxruntime_session PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
endif()
2 changes: 1 addition & 1 deletion cmake/onnxruntime_training.cmake
@@ -139,7 +139,7 @@ if (onnxruntime_BUILD_UNIT_TESTS)
target_compile_options(onnxruntime_training_mnist PUBLIC "-Wno-maybe-uninitialized")
endif()
endif()
target_link_libraries(onnxruntime_training_mnist PRIVATE onnxruntime_training_runner onnxruntime_training ${ONNXRUNTIME_LIBS} ${onnxruntime_EXTERNAL_LIBRARIES})
target_link_libraries(onnxruntime_training_mnist PRIVATE onnxruntime_training_runner onnxruntime_lora onnxruntime_training ${ONNXRUNTIME_LIBS} ${onnxruntime_EXTERNAL_LIBRARIES})
set_target_properties(onnxruntime_training_mnist PROPERTIES FOLDER "ONNXRuntimeTest")

# squeezenet
13 changes: 10 additions & 3 deletions cmake/onnxruntime_unittests.cmake
@@ -263,6 +263,11 @@ file(GLOB onnxruntime_test_flatbuffers_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/flatbuffers/*.h"
)

file(GLOB onnxruntime_test_lora_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/lora/*.cc"
"${TEST_SRC_DIR}/lora/*.h"
)

if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)

file(GLOB onnxruntime_test_ir_src CONFIGURE_DEPENDS
@@ -612,6 +617,7 @@ set(ONNXRUNTIME_TEST_LIBS
onnxruntime_providers
onnxruntime_util
${onnxruntime_tvm_libs}
onnxruntime_lora
onnxruntime_framework
onnxruntime_util
onnxruntime_graph
@@ -782,7 +788,7 @@ endif()

set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src}
${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantization_src}
${onnxruntime_test_flatbuffers_src})
${onnxruntime_test_flatbuffers_src} ${onnxruntime_test_lora_src})

if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
file(GLOB onnxruntime_test_providers_cuda_ut_src CONFIGURE_DEPENDS
@@ -1514,6 +1520,7 @@ endif()
onnxruntime_optimizer
onnxruntime_providers
onnxruntime_util
onnxruntime_lora
onnxruntime_framework
onnxruntime_util
onnxruntime_graph
@@ -1634,7 +1641,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
list(APPEND onnxruntime_customopregistration_test_LIBS ${TENSORRT_LIBRARY_INFER})
endif()
if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 ${PROTOBUF_LIB} onnx onnx_proto nsync_cpp)
list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp)
endif()
AddTest(DYN
TARGET onnxruntime_customopregistration_test
@@ -1753,7 +1760,7 @@ if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten"

set(onnxruntime_logging_apis_test_LIBS onnxruntime_common onnxruntime_test_utils)
if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
list(APPEND onnxruntime_logging_apis_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_framework onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 ${PROTOBUF_LIB} onnx onnx_proto nsync_cpp)
list(APPEND onnxruntime_logging_apis_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_lora onnxruntime_framework onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite ${PROTOBUF_LIB} onnx onnx_proto nsync_cpp)
endif()

if(NOT WIN32)
2 changes: 2 additions & 0 deletions cmake/onnxruntime_webassembly.cmake
@@ -102,6 +102,7 @@ if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB)
onnx
onnx_proto
onnxruntime_common
onnxruntime_lora
onnxruntime_flatbuffers
onnxruntime_framework
onnxruntime_graph
@@ -179,6 +180,7 @@ else()
onnx
onnx_proto
onnxruntime_common
onnxruntime_lora
onnxruntime_flatbuffers
onnxruntime_framework
onnxruntime_graph
2 changes: 1 addition & 1 deletion cmake/winml_unittests.cmake
@@ -166,7 +166,7 @@ function (get_winml_test_model_src
"${winml_test_src_path}/model/*.cpp")
set(${output_winml_test_model_src} ${winml_test_model_src} PARENT_SCOPE)
set(${winml_test_model_libs} onnx_test_data_proto onnx_test_runner_common onnxruntime_common onnxruntime_mlas
onnxruntime_graph onnxruntime_test_utils onnxruntime_framework onnxruntime_util onnxruntime_flatbuffers PARENT_SCOPE)
onnxruntime_graph onnxruntime_test_utils onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_flatbuffers PARENT_SCOPE)
endfunction()

file(GLOB winml_test_common_src CONFIGURE_DEPENDS
10 changes: 10 additions & 0 deletions include/onnxruntime/core/framework/run_options.h
@@ -5,9 +5,17 @@

#include <string>
#include <atomic>

#include "core/common/inlined_containers_fwd.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/framework/config_options.h"

namespace onnxruntime {
namespace lora {
class LoraAdapter;
}
} // namespace onnxruntime

/**
* Configuration information for a Run call.
*/
@@ -40,6 +48,8 @@ struct OrtRunOptions {
// /include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
onnxruntime::ConfigOptions config_options;

onnxruntime::InlinedVector<const onnxruntime::lora::LoraAdapter*> active_adapters;

OrtRunOptions() = default;
~OrtRunOptions() = default;
};
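The new `active_adapters` member is a non-owning list: run options only reference adapters that the caller keeps alive. A minimal self-contained sketch of that ownership pattern (the type names here are illustrative stand-ins, not the actual ORT declarations):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace sketch {

// Stand-in for onnxruntime::lora::LoraAdapter; the real type owns the
// flatbuffer-backed parameter data.
struct LoraAdapter {
  int id;
};

// Mirrors the new OrtRunOptions member: a vector of const pointers,
// appended to by the Add call, never owned or freed here.
struct RunOptions {
  std::vector<const LoraAdapter*> active_adapters;

  void AddActiveLoraAdapter(const LoraAdapter& adapter) {
    active_adapters.push_back(&adapter);
  }
};

inline std::size_t ActiveCount(const RunOptions& opts) {
  return opts.active_adapters.size();
}

}  // namespace sketch
```

Because the pointers are non-owning, the adapters must outlive every `Run()` call that uses these run options.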
52 changes: 52 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -304,6 +304,7 @@ ORT_RUNTIME_CLASS(Op);
ORT_RUNTIME_CLASS(OpAttr);
ORT_RUNTIME_CLASS(Logger);
ORT_RUNTIME_CLASS(ShapeInferContext);
ORT_RUNTIME_CLASS(LoraAdapter);

#ifdef _WIN32
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
@@ -4670,6 +4671,57 @@ struct OrtApi {
_In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array,
_In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths,
size_t num_external_initializer_files);

/** \brief Create an OrtLoraAdapter
*
* The function attempts to locate the file specified by adapter_file_path, read it, and create an OrtLoraAdapter
* instance. The adapter_file_path should be a valid absolute path to a file that contains a valid Lora Adapter
* format. The function attempts to validate the format at load time. The file will always be memory mapped, unless
* the platform does not support memory mapping, in which case the file will be read into memory.
*
* \param[in] adapter_file_path adapter file path.
* \param[in] allocator optional pointer to a device allocator. If specified
* data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU.
* The data would still be copied to device if required by the model at inference time.
* \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with
* OrtApi::ReleaseLoraAdapter.
*/
ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator,
_Outptr_ OrtLoraAdapter** out);

/** \brief Create an OrtLoraAdapter
*
* The function copies the bytes from the array and creates an OrtLoraAdapter instance.
*
* \param[in] bytes pointer to a valid Lora Adapter format buffer.
* \param[in] num_bytes length of bytes buffer.
* \param[in] allocator optional pointer to a device allocator. If specified
* data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU.
* The data would still be copied to device if required by the model at inference time.
* \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with
* OrtApi::ReleaseLoraAdapter.
*/
ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator,
_Outptr_ OrtLoraAdapter** out);

/** \brief Release an ::OrtLoraAdapter obtained from OrtApi::CreateLoraAdapter
*/
ORT_CLASS_RELEASE(LoraAdapter);

/** \brief Add the Lora Adapter to the list of active adapters.
*
* The function adds the Lora Adapter to the list of active adapters. The Lora Adapter must be created with
* OrtApi::CreateLoraAdapter or OrtApi::CreateLoraAdapterFromArray. The Lora Adapter will be used by the session to run the model.
* The instance of the OrtRunOptions can then be used to customize the Run() calls.
* More than one OrtLoraAdapter can be active at the same time. Lora Parameters that belong to different
* Lora adapters that will be active at the same time must not overlap.
* This setting does not affect RunWithBinding.
*
* \param[in] options OrtRunOptions instance
* \param[in] adapter OrtLoraAdapter instance
*/
ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter);
};
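The documentation above requires that parameters belonging to different adapters active in the same `Run()` call must not overlap. That invariant can be checked with a simple name-set scan; the following is a hypothetical sketch of such a validation, not ORT's actual implementation:

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

namespace sketch {

// Hypothetical view of one adapter as its list of Lora parameter names.
using ParamNames = std::vector<std::string>;

// Returns true when no parameter name appears in more than one of the
// adapters that would be active for the same Run() call.
inline bool NoOverlap(const std::vector<ParamNames>& adapters) {
  std::set<std::string> seen;
  for (const auto& params : adapters) {
    for (const auto& name : params) {
      if (!seen.insert(name).second) {
        return false;  // name already claimed by another active adapter
      }
    }
  }
  return true;
}

}  // namespace sketch
```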

/*
35 changes: 35 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -508,6 +508,7 @@
ORT_DEFINE_RELEASE(ThreadingOptions);
ORT_DEFINE_RELEASE(Env);
ORT_DEFINE_RELEASE(RunOptions);
ORT_DEFINE_RELEASE(LoraAdapter);
ORT_DEFINE_RELEASE(Session);
ORT_DEFINE_RELEASE(SessionOptions);
ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
@@ -736,6 +737,32 @@
void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add
};

/// \brief LoraAdapter holds a set of Lora Parameters loaded from a single file
struct LoraAdapter : detail::Base<OrtLoraAdapter> {
using Base = detail::Base<OrtLoraAdapter>;
using Base::Base;

explicit LoraAdapter(std::nullptr_t) {} ///< Create an empty LoraAdapter object, must be assigned a valid one to be used

/// \brief Wraps OrtApi::CreateLoraAdapter
///
/// The function attempts to load the adapter from the specified file
/// \param adapter_path The path to the Lora adapter
/// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still
/// be copied to device if required by the model at inference time.
static LoraAdapter CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
OrtAllocator* allocator);

/// \brief Wraps OrtApi::CreateLoraAdapterFromArray
///
/// The function attempts to load the adapter from the specified byte array.
/// \param bytes The byte array containing the LoraAdapter file format data
/// \param num_bytes The number of bytes in the byte array
/// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still
/// be copied to device if required by the model at inference time.
static LoraAdapter CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
OrtAllocator* allocator);
};
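The new `Ort::LoraAdapter` follows the header's usual RAII idiom: `detail::Base<T>` owns the C handle, and the matching `Release*` function (registered via `ORT_DEFINE_RELEASE`) frees it on destruction. A reduced, self-contained model of that pattern, simplified from the real header:

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

namespace sketch {

int live_handles = 0;  // counts outstanding "C API" objects for the demo

struct OrtLoraAdapter {};  // opaque C-style handle type

inline OrtLoraAdapter* CreateLoraAdapter() {
  ++live_handles;
  return new OrtLoraAdapter;
}
inline void ReleaseLoraAdapter(OrtLoraAdapter* p) {
  --live_handles;
  delete p;
}

// Minimal stand-in for Ort::detail::Base<T>: a move-only owner of a C handle.
template <typename T>
struct Base {
  T* p_ = nullptr;
  Base() = default;
  explicit Base(T* p) : p_(p) {}
  Base(const Base&) = delete;
  Base(Base&& o) noexcept : p_(std::exchange(o.p_, nullptr)) {}
  ~Base() {
    if (p_) ReleaseLoraAdapter(p_);  // release on scope exit
  }
};

struct LoraAdapter : Base<OrtLoraAdapter> {
  using Base<OrtLoraAdapter>::Base;
  explicit LoraAdapter(std::nullptr_t) {}  // empty wrapper, as in the header
  static LoraAdapter Create() { return LoraAdapter{CreateLoraAdapter()}; }
};

}  // namespace sketch
```

The design choice mirrored here is that the C++ layer adds no state of its own; it is a zero-overhead owner around the C handle.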

/** \brief RunOptions
*
*/
@@ -766,6 +793,14 @@
* Wraps OrtApi::RunOptionsUnsetTerminate
*/
RunOptions& UnsetTerminate();

/** \brief Add the LoraAdapter to the list of active adapters.
* The setting does not affect RunWithBinding() calls.
*
* Wraps OrtApi::RunOptionsAddActiveLoraAdapter
* \param adapter The LoraAdapter to be used as the active adapter
*/
RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter);
};

namespace detail {
19 changes: 19 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -557,6 +557,20 @@ inline void CustomOpDomain::Add(const OrtCustomOp* op) {
ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
}

inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
OrtAllocator* allocator) {
OrtLoraAdapter* p;
ThrowOnError(GetApi().CreateLoraAdapter(adapter_path.c_str(), allocator, &p));
return LoraAdapter{p};
}

inline LoraAdapter LoraAdapter::CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
OrtAllocator* allocator) {
OrtLoraAdapter* p;
ThrowOnError(GetApi().CreateLoraAdapterFromArray(bytes, num_bytes, allocator, &p));
return LoraAdapter{p};
}

inline RunOptions::RunOptions() {
ThrowOnError(GetApi().CreateRunOptions(&p_));
}
@@ -609,6 +623,11 @@ inline RunOptions& RunOptions::UnsetTerminate() {
return *this;
}

inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter) {
ThrowOnError(GetApi().RunOptionsAddActiveLoraAdapter(p_, adapter));
return *this;
}
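`ThrowOnError` is the glue between the status-returning C API and the throwing C++ API seen in the factories above. Sketched in miniature, with hypothetical names in place of the real `OrtStatus` machinery:

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

namespace sketch {

// Stand-in for OrtStatus*: null means success, non-null carries an error.
struct Status {
  std::string msg;
};

inline void ThrowOnError(Status* st) {
  if (st) {
    std::string msg = st->msg;
    delete st;  // the real wrapper also releases the status before throwing
    throw std::runtime_error(msg);
  }
}

// A factory in the style of LoraAdapter::CreateLoraAdapter: invoke the
// C-style call, convert failure into an exception, return the handle.
inline int CreateOrThrow(bool fail) {
  Status* st = fail ? new Status{"load failed"} : nullptr;
  ThrowOnError(st);
  return 42;  // stand-in for the wrapped handle value
}

}  // namespace sketch
```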

namespace detail {

template <typename T>
2 changes: 2 additions & 0 deletions onnxruntime/__init__.py
@@ -23,6 +23,7 @@
from onnxruntime.capi._pybind_state import ExecutionMode # noqa: F401
from onnxruntime.capi._pybind_state import ExecutionOrder # noqa: F401
from onnxruntime.capi._pybind_state import GraphOptimizationLevel # noqa: F401
from onnxruntime.capi._pybind_state import LoraAdapter # noqa: F401
from onnxruntime.capi._pybind_state import ModelMetadata # noqa: F401
from onnxruntime.capi._pybind_state import NodeArg # noqa: F401
from onnxruntime.capi._pybind_state import OrtAllocatorType # noqa: F401
@@ -56,6 +57,7 @@
if import_capi_exception:
raise import_capi_exception

from onnxruntime.capi.onnxruntime_inference_collection import AdapterFormat # noqa: F401
from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession # noqa: F401
from onnxruntime.capi.onnxruntime_inference_collection import IOBinding # noqa: F401
from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice # noqa: F401
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/config_options.h
@@ -19,7 +19,7 @@ struct ConfigOptions {

// Gets the config string associated with the given config_key.
// If not found, an empty optional is returned.
optional<std::string> GetConfigEntry(const std::string& config_key) const noexcept;
std::optional<std::string> GetConfigEntry(const std::string& config_key) const noexcept;

// Check if this instance of ConfigOptions has a config using the given config_key.
// Returns true if found and copies the value into config_value.
11 changes: 11 additions & 0 deletions onnxruntime/core/framework/run_options.cc
@@ -5,9 +5,11 @@
#include "core/session/onnxruntime_c_api.h"
#include "core/session/ort_apis.h"
#include "core/framework/error_code_helper.h"

#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(disable : 26409)
#endif

ORT_API_STATUS_IMPL(OrtApis::CreateRunOptions, _Outptr_ OrtRunOptions** out) {
API_IMPL_BEGIN
*out = new OrtRunOptions();
@@ -60,3 +62,12 @@ ORT_API_STATUS_IMPL(OrtApis::AddRunConfigEntry, _Inout_ OrtRunOptions* options,
_In_z_ const char* config_key, _In_z_ const char* config_value) {
return onnxruntime::ToOrtStatus(options->config_options.AddConfigEntry(config_key, config_value));
}

ORT_API_STATUS_IMPL(OrtApis::RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options,
const _In_ OrtLoraAdapter* adapter) {
API_IMPL_BEGIN
auto* lora_adapter = reinterpret_cast<const onnxruntime::lora::LoraAdapter*>(adapter);
options->active_adapters.push_back(lora_adapter);
return nullptr;
API_IMPL_END
}
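The `reinterpret_cast` in the implementation above is the standard opaque-handle idiom: the public `OrtLoraAdapter*` is a type-erased pointer to the internal `lora::LoraAdapter`, and the boundary casts recover the original type. In miniature (illustrative names only):

```cpp
#include <cassert>

namespace sketch {

namespace internal {
// Stand-in for the internal onnxruntime::lora::LoraAdapter.
struct LoraAdapter {
  int version = 7;
};
}  // namespace internal

// Public opaque type: declared but never defined for API clients.
struct OrtLoraAdapter;

inline const OrtLoraAdapter* ToHandle(const internal::LoraAdapter* p) {
  return reinterpret_cast<const OrtLoraAdapter*>(p);
}

inline const internal::LoraAdapter* FromHandle(const OrtLoraAdapter* h) {
  // A round-trip reinterpret_cast back to the original type is well-defined.
  return reinterpret_cast<const internal::LoraAdapter*>(h);
}

}  // namespace sketch
```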