From 69a36eb231fcead986b22da952d4d7b5f5c86af9 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Tue, 12 Nov 2024 17:45:59 -0500 Subject: [PATCH] Revert Implement DML copy for Lora Adapters (#22814) Revert https://github.com/microsoft/onnxruntime/pull/22396 --- .../src/BucketizedBufferAllocator.cpp | 2 +- .../src/BucketizedBufferAllocator.h | 6 -- .../DmlExecutionProvider/src/CommandQueue.cpp | 10 +-- .../DmlExecutionProvider/src/CommandQueue.h | 4 +- .../src/ExecutionContext.cpp | 37 +++++++++- .../src/ExecutionContext.h | 12 +++- .../src/ExecutionProvider.cpp | 21 +----- .../providers/dml/dml_provider_factory.cc | 4 +- onnxruntime/core/session/lora_adapters.cc | 71 +++++-------------- .../python/onnxruntime_pybind_mlvalue.cc | 2 +- onnxruntime/test/lora/lora_test.cc | 66 ++--------------- 11 files changed, 78 insertions(+), 157 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index 68b9b3fe5706f..334a40b979bda 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -182,7 +182,7 @@ namespace Dml } else { - if (!m_closed) + if (!m_context->IsClosed()) { // Free the underlying allocation once queued work has completed. #ifdef _GAMING_XBOX diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h index 65bc9b7f69316..16283d5b19c9c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h @@ -46,11 +46,6 @@ namespace Dml void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode); - void Close() - { - m_closed = true; - } - public: // onnxruntime::IAllocator void* Alloc(size_t size, AllocatorRoundingMode roundingMode); void* Alloc(size_t size) final; @@ -88,7 +83,6 @@ namespace Dml std::vector m_pool; size_t m_currentAllocationId = 0; uint64_t m_currentResourceId = 0; - bool m_closed = false; // Unless specifically requested, allocation sizes are not rounded to enable pooling // until SetDefaultRoundingMode is called. This should be done at completion of session diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp index 67faf333d21e1..988324bab1174 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp @@ -55,7 +55,7 @@ namespace Dml // for example, an allocation from BucketizedBufferAllocator attempts to queue a reference // to its underlying D3D resource when freed. Furthermore, these references are unnecessary // since Close() already blocks for scheduled GPU work before clearing m_queuedReferences. - if (!m_clearingQueue) + if (!m_closing) { QueuedReference queuedReference = {GetLastFenceValue(), object}; @@ -70,15 +70,15 @@ namespace Dml } } - void CommandQueue::WaitForSignalAndClearQueue() + void CommandQueue::Close() { // Wait for flushed work: - assert(!m_clearingQueue); - m_clearingQueue = true; + assert(!m_closing); + m_closing = true; GpuEvent event = GetCurrentCompletionEvent(); event.WaitForSignal(m_cpuSyncSpinningEnabled); m_queuedReferences.clear(); - m_clearingQueue = false; + m_closing = false; } void CommandQueue::ReleaseCompletedReferences() diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h index 9a4728d5845d4..71d5eb173cfec 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h @@ -44,7 +44,7 @@ namespace Dml } #endif - void WaitForSignalAndClearQueue(); + void Close(); void ReleaseCompletedReferences(); private: @@ -61,7 +61,7 @@ namespace Dml ComPtr m_fence; uint64_t m_lastFenceValue = 0; - bool m_clearingQueue = false; + bool m_closing = false; bool m_cpuSyncSpinningEnabled = false; }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp index ececf13fc8cdf..5dc1213bd76f0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp @@ -11,10 +11,13 @@ namespace Dml ID3D12Device* d3d12Device, IDMLDevice* dmlDevice, ID3D12CommandQueue* queue, - bool cpuSyncSpinningEnabled) + bool cpuSyncSpinningEnabled, + bool keepOpen + ) : m_queue(std::make_shared(queue, cpuSyncSpinningEnabled)) , m_dmlRecorder(d3d12Device, dmlDevice, m_queue) , m_cpuSyncSpinningEnabled(cpuSyncSpinningEnabled) + , m_keepOpen(keepOpen) { ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf()))); } @@ -33,6 +36,8 @@ namespace Dml D3D12_RESOURCE_STATES srcState, uint64_t byteCount) { + assert(!m_closed); + SetCommandRecorder(&m_dmlRecorder); std::vector barriers; @@ -79,6 +84,8 @@ namespace Dml _Out_ uint64_t* completionValue ) { + assert(!m_closed); + SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue); } @@ -88,6 +95,7 @@ namespace Dml const DML_BINDING_DESC& persistentResourceBinding, const DML_BINDING_DESC& inputArrayBinding) { + assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding); @@ -99,6 +107,7 @@ namespace Dml gsl::span inputBindings, gsl::span outputBindings) { + assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings); @@ -106,6 +115,7 @@ namespace Dml void ExecutionContext::AddUAVBarrier() { + assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.AddUAVBarrier(); @@ -113,6 +123,7 @@ namespace Dml void ExecutionContext::ResourceBarrier(gsl::span barriers) { + assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ResourceBarrier(barriers); @@ -120,6 +131,7 @@ namespace Dml void ExecutionContext::GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** commandList) { + assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); // Ensure the descriptor heap is reset to D3D as something external may change it before recording @@ -130,6 +142,8 @@ namespace Dml void ExecutionContext::SetCommandRecorder(ICommandRecorder* newRecorder) { + assert(!m_closed); + // If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct // ordering of operations on the command queue. if (m_currentRecorder != newRecorder) @@ -146,6 +160,8 @@ namespace Dml void ExecutionContext::Flush() { + assert(!m_closed); + if (!m_currentRecorder || !m_currentRecorder->HasUnsubmittedWork()) { // Nothing to flush @@ -164,21 +180,34 @@ namespace Dml void ExecutionContext::QueueReference(IUnknown* object) { + assert(!m_closed); // If something has been recorded into a command list but not submitted yet, it means that the *next* fence // value is the one to signal completion. bool waitForUnsubmittedWork = (m_currentRecorder != nullptr); m_queue->QueueReference(object, waitForUnsubmittedWork); } - void ExecutionContext::WaitForSignalAndClearQueue() + void ExecutionContext::Close() { + assert(!m_closed); + // Discard unflushed work and clear queued references. This prevents the circular reference: // Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel - m_queue->WaitForSignalAndClearQueue(); + m_queue->Close(); + + // Keep the execution context open when requested, e.g. when used through the python API where there's a single context + // and single command queue + if (!m_keepOpen) + { + m_currentRecorder = nullptr; + m_closed = true; + } } GpuEvent ExecutionContext::GetCurrentCompletionEvent() { + assert(!m_closed); + GpuEvent event = m_queue->GetCurrentCompletionEvent(); // If something has been recorded into a command list but not submitted yet, it means that the *next* fence @@ -194,11 +223,13 @@ namespace Dml void ExecutionContext::ReleaseCompletedReferences() { + assert(!m_closed); m_queue->ReleaseCompletedReferences(); } D3D12_COMMAND_LIST_TYPE ExecutionContext::GetCommandListTypeForQueue() const { + assert(!m_closed); return m_queue->GetType(); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h index 71aa26f4a0148..e7a6fa3d07296 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h @@ -23,13 +23,14 @@ namespace Dml ID3D12Device* d3d12Device, IDMLDevice* dmlDevice, ID3D12CommandQueue* queue, - bool cpuSyncSpinningEnabled); + bool cpuSyncSpinningEnabled, + bool keepOpen); void SetAllocator(std::weak_ptr allocator); // Waits for flushed work, discards unflushed work, and discards associated references to - // prevent circular references. - void WaitForSignalAndClearQueue(); + // prevent circular references. Must be the last call on the object before destruction. + void Close(); // Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. Transition // barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and @@ -86,6 +87,7 @@ namespace Dml D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const; bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; } + bool IsClosed() const { return m_closed; } private: Microsoft::WRL::ComPtr m_d3dDevice; @@ -101,6 +103,10 @@ namespace Dml bool m_closed = false; bool m_cpuSyncSpinningEnabled = false; + + // The python API has a global state used for I/O binding where the execution context is shared between session, + // so we don't want to close the context when one of the sessions is destroyed + bool m_keepOpen = false; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index ae9be4ea91c28..228dfeb123175 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -106,26 +106,7 @@ namespace Dml // Release the cached command list references before closing the context m_capturedGraphs.clear(); - // Close the allocator before clearing the command queue to stop it from - // appending resources to it in an attempt to keep them alive. - if (m_allocator) - { - m_allocator->Close(); - } - - // Destroy the allocators. We are closing the execution provider, so from now on the - // only thing it will be used for is doing copies via the DataTransfer, which doesn't - // require allocating any memory. - // TODO: Move the copy functions over to ExecutionContext so that we are able to cleanly - // destroy ExecutionProviderImpl, and instead have the DataTransfer keep the context alive. - m_allocator = nullptr; - m_cpuInputAllocator = nullptr; - - // Wait for all pending commands to be done executing and empty the command queue. This will - // Force all kernels and resources in flight to get destroyed and, from this point forward, - // ExecutionProviderImpl will only be used to execute transfer between resources that are - // already existing via the DataTransfer; - m_context->WaitForSignalAndClearQueue(); + m_context->Close(); } void ExecutionProviderImpl::WaitForOutstandingWork() diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 89decfef6fef6..e8fe235fc1d46 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -86,11 +86,11 @@ std::unique_ptr DMLProviderFactory::CreateProvider() { // First, check if an I/O binding API that was used before this session or another session has already created a queue if (FAILED(d3d12_device->GetPrivateData(dml_execution_context_guid, &execution_context_ptr_size, execution_context.GetAddressOf()))) { - execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true); + execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true, true); ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, execution_context.Get())); } } else { - execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_); + execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false); } auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_); diff --git a/onnxruntime/core/session/lora_adapters.cc b/onnxruntime/core/session/lora_adapters.cc index 599c41f79a537..466edce187a56 100644 --- a/onnxruntime/core/session/lora_adapters.cc +++ b/onnxruntime/core/session/lora_adapters.cc @@ -4,9 +4,10 @@ #include "core/session/lora_adapters.h" #include "lora/adapter_format_utils.h" +#include + #include "core/framework/data_transfer.h" #include "core/framework/error_code_helper.h" -#include "core/framework/execution_provider.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/allocator_adapters.h" #include "core/session/ort_apis.h" @@ -15,15 +16,6 @@ #include "core/providers/cuda/cuda_provider_factory.h" #endif -#ifdef USE_DML -#include "core/session/abi_session_options_impl.h" -#include "core/providers/dml/dml_provider_factory_creator.h" -#include "core/providers/dml/dml_provider_factory.h" -#endif - -#include -#include - namespace onnxruntime { #ifdef USE_CUDA @@ -58,58 +50,28 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) { InitializeParamsValues(); } -namespace { -struct DataTransfer { - std::unique_ptr ep; +static std::unique_ptr GetDataTransfer(const OrtMemoryInfo& mem_info) { std::unique_ptr data_transfer; - bool is_dml = false; - Status CopyTensor(const Tensor& src, Tensor& dst) const { - return data_transfer->CopyTensor(src, dst); - } - Status Sync() const { - if (is_dml) { - return ep->Sync(); - } else { - return Status::OK(); - } - } -}; -} // namespace -static Status GetDataTransfer(const OrtMemoryInfo& mem_info, [[maybe_unused]] DataTransfer& dt) { - ORT_RETURN_IF(strcmp(mem_info.name, onnxruntime::CPU) == 0, "Expecting on device allocator for LoraAdapter"); + if (strcmp(mem_info.name, onnxruntime::CPU) == 0) { + return data_transfer; + } - Status status; if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) { #ifdef USE_CUDA auto* cuda_provider_info = TryGetProviderInfo_CUDA(); if (cuda_provider_info != nullptr) { - dt.data_transfer = cuda_provider_info->CreateGPUDataTransfer(); - } else { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider could not be loaded"); + data_transfer = cuda_provider_info->CreateGPUDataTransfer(); } -#else - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider is not enabled in this build"); -#endif - } else if (strcmp(mem_info.name, onnxruntime::DML) == 0) { -#ifdef USE_DML - auto ep_factory = onnxruntime::DMLProviderFactoryCreator::Create(ConfigOptions{}, 0, false, false, false); - dt.ep = ep_factory->CreateProvider(); - dt.is_dml = true; - dt.data_transfer = dt.ep->GetDataTransfer(); -#else - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DML provider is not enabled in this build"); #endif - } else { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported device allocator"); } - return status; + return data_transfer; } static Status CreateOrtValueOnDevice(const OrtValue& ort_value_mapped, const AllocatorPtr& device_allocator, - const DataTransfer& data_transfer, + const IDataTransfer& data_transfer, OrtValue& out) { OrtValue result; const auto& src = ort_value_mapped.Get(); @@ -125,9 +87,12 @@ void LoraAdapter::InitializeParamsValues() { ORT_THROW("Adapter is not loaded yet."); } - DataTransfer data_transfer; + std::unique_ptr data_transfer; if (device_allocator_) { - ORT_THROW_IF_ERROR(GetDataTransfer(device_allocator_->Info(), data_transfer)); + data_transfer = GetDataTransfer(device_allocator_->Info()); + if (data_transfer == nullptr) { + ORT_THROW("Data transfer is not available for the specified device allocator, it also must not be a CPU allocator"); + } } const auto* params = adapter_->parameters(); @@ -135,12 +100,12 @@ void LoraAdapter::InitializeParamsValues() { std::unordered_map params_values; params_values.reserve(params->size()); // Re-work in two separate loops due to compiler issues - if (device_allocator_) { + if (data_transfer) { for (const auto* param : *params) { auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param); OrtValue ort_value_ondevice; ORT_THROW_IF_ERROR(CreateOrtValueOnDevice(ort_value, device_allocator_, - data_transfer, ort_value_ondevice)); + *data_transfer, ort_value_ondevice)); Param lora_param(std::move(ort_value), std::move(ort_value_ondevice)); params_values.emplace(std::move(name), std::move(lora_param)); } @@ -152,10 +117,6 @@ void LoraAdapter::InitializeParamsValues() { } } - if (device_allocator_) { - ORT_THROW_IF_ERROR(data_transfer.Sync()); - } - params_values_.swap(params_values); } diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 74bd20461efea..92396bb09bd4c 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -226,7 +226,7 @@ AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) { auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device.Get()); ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_device_guid, dml_device.Get())); - context = wil::MakeOrThrow(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true); + context = wil::MakeOrThrow(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true, true); ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, context.Get())); } diff --git a/onnxruntime/test/lora/lora_test.cc b/onnxruntime/test/lora/lora_test.cc index 8338c7d547a09..e8291a36447ca 100644 --- a/onnxruntime/test/lora/lora_test.cc +++ b/onnxruntime/test/lora/lora_test.cc @@ -200,19 +200,13 @@ TEST(LoraAdapterTest, Load) { } #ifdef USE_CUDA -TEST(LoraAdapterTest, VerifyCudaDeviceCopy) { - if (DefaultCudaExecutionProvider() == nullptr) { - GTEST_SKIP() << "Skip This Test Due to this EP is null"; - } -#ifdef USE_DML - if (DefaultDmlExecutionProvider() != nullptr) { - GTEST_FAIL() << "It should not run with DML EP"; - } -#endif +TEST(LoraAdapterTest, VerifyDeviceCopy) { auto cpu_ep = DefaultCpuExecutionProvider(); auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0]; - auto cuda_allocator = DefaultCudaExecutionProvider()->CreatePreferredAllocators()[0]; - auto cuda_transfer = DefaultCudaExecutionProvider()->GetDataTransfer(); + auto cuda_ep = DefaultCudaExecutionProvider(); + auto cuda_allocator = cuda_ep->CreatePreferredAllocators()[0]; + + auto gpu_transfer = cuda_ep->GetDataTransfer(); auto test_params = GenerateTestParameters()(); lora::LoraAdapter adapter(std::move(cuda_allocator)); @@ -228,54 +222,9 @@ TEST(LoraAdapterTest, VerifyCudaDeviceCopy) { ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size()); Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator); - ASSERT_TRUE(cuda_transfer->CanCopy(tensor_device.Location().device, - copy.Location().device)); - ASSERT_STATUS_OK(cuda_transfer->CopyTensor(tensor_device, copy)); - - auto expected_span = tensor_cpu.DataAsSpan(); - auto copy_span = copy.DataAsSpan(); - - ASSERT_EQ(expected_span, copy_span); - } -} -#endif - -#ifdef USE_DML -TEST(LoraAdapterTest, VerifyDmlDeviceCopy) { - // NO_DML_TEST is set, DML test is skipped - if (DefaultDmlExecutionProvider() == nullptr) { - GTEST_SKIP() << "Skip This Test Due to this EP is null"; - } - -#ifdef USE_CUDA - if (DefaultCudaExecutionProvider() != nullptr) { - GTEST_FAIL() << "It should not run with CUDA EP"; - } -#endif - - auto cpu_ep = DefaultCpuExecutionProvider(); - auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0]; - - auto dml_allocator = DefaultDmlExecutionProvider()->CreatePreferredAllocators()[0]; - auto dml_transfer = DefaultDmlExecutionProvider()->GetDataTransfer(); - - auto test_params = GenerateTestParameters()(); - lora::LoraAdapter adapter(std::move(dml_allocator)); - adapter.Load(std::move(test_params)); - - auto [begin, end] = adapter.GetParamIterators(); - for (; begin != end; ++begin) { - const auto& [_, param] = *begin; - const auto& tensor_device = param.GetDeviceOrMapped().Get(); - ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::DML)); - - const auto& tensor_cpu = param.GetMapped().Get(); - ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size()); - - Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator); - ASSERT_TRUE(dml_transfer->CanCopy(tensor_device.Location().device, + ASSERT_TRUE(gpu_transfer->CanCopy(tensor_device.Location().device, copy.Location().device)); - ASSERT_STATUS_OK(dml_transfer->CopyTensor(tensor_device, copy)); + ASSERT_STATUS_OK(gpu_transfer->CopyTensor(tensor_device, copy)); auto expected_span = tensor_cpu.DataAsSpan(); auto copy_span = copy.DataAsSpan(); @@ -284,6 +233,5 @@ TEST(LoraAdapterTest, VerifyDmlDeviceCopy) { } } #endif - } // namespace test } // namespace onnxruntime