Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement DML copy for Lora Adapters #22396

Merged
merged 8 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ namespace Dml
}
else
{
if (!m_context->IsClosed())
if (!m_closed)
{
// Free the underlying allocation once queued work has completed.
#ifdef _GAMING_XBOX
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ namespace Dml

void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);

// Marks this allocator as closed. After this call the free path checks
// m_closed and, when set, no longer queues a reference to the execution
// context to keep the resource alive until GPU work completes — closing
// stops the allocator from appending resources to the command queue
// during execution-provider shutdown. Note: this only flips the flag;
// it does not release pooled resources itself.
void Close()
{
    m_closed = true;
}

public: // onnxruntime::IAllocator
void* Alloc(size_t size, AllocatorRoundingMode roundingMode);
void* Alloc(size_t size) final;
Expand Down Expand Up @@ -83,6 +88,7 @@ namespace Dml
std::vector<Bucket> m_pool;
size_t m_currentAllocationId = 0;
uint64_t m_currentResourceId = 0;
bool m_closed = false;

// Unless specifically requested, allocation sizes are not rounded to enable pooling
// until SetDefaultRoundingMode is called. This should be done at completion of session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace Dml
// for example, an allocation from BucketizedBufferAllocator attempts to queue a reference
// to its underlying D3D resource when freed. Furthermore, these references are unnecessary
// since Close() already blocks for scheduled GPU work before clearing m_queuedReferences.
if (!m_closing)
if (!m_clearingQueue)
{
QueuedReference queuedReference = {GetLastFenceValue(), object};

Expand All @@ -70,15 +70,15 @@ namespace Dml
}
}

void CommandQueue::Close()
void CommandQueue::WaitForSignalAndClearQueue()
{
// Wait for flushed work:
assert(!m_closing);
m_closing = true;
assert(!m_clearingQueue);
m_clearingQueue = true;
GpuEvent event = GetCurrentCompletionEvent();
event.WaitForSignal(m_cpuSyncSpinningEnabled);
m_queuedReferences.clear();
m_closing = false;
m_clearingQueue = false;
}

void CommandQueue::ReleaseCompletedReferences()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace Dml
}
#endif

void Close();
void WaitForSignalAndClearQueue();
void ReleaseCompletedReferences();

private:
Expand All @@ -61,7 +61,7 @@ namespace Dml

ComPtr<ID3D12Fence> m_fence;
uint64_t m_lastFenceValue = 0;
bool m_closing = false;
bool m_clearingQueue = false;
bool m_cpuSyncSpinningEnabled = false;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@ namespace Dml
ID3D12Device* d3d12Device,
IDMLDevice* dmlDevice,
ID3D12CommandQueue* queue,
bool cpuSyncSpinningEnabled,
bool keepOpen
)
bool cpuSyncSpinningEnabled)
: m_queue(std::make_shared<CommandQueue>(queue, cpuSyncSpinningEnabled))
, m_dmlRecorder(d3d12Device, dmlDevice, m_queue)
, m_cpuSyncSpinningEnabled(cpuSyncSpinningEnabled)
, m_keepOpen(keepOpen)
{
ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf())));
}
Expand All @@ -36,8 +33,6 @@ namespace Dml
D3D12_RESOURCE_STATES srcState,
uint64_t byteCount)
{
assert(!m_closed);

SetCommandRecorder(&m_dmlRecorder);

std::vector<D3D12_RESOURCE_BARRIER> barriers;
Expand Down Expand Up @@ -84,8 +79,6 @@ namespace Dml
_Out_ uint64_t* completionValue
)
{
assert(!m_closed);

SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue);
}
Expand All @@ -95,7 +88,6 @@ namespace Dml
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding);
Expand All @@ -107,31 +99,27 @@ namespace Dml
gsl::span<const DML_BINDING_DESC> inputBindings,
gsl::span<const DML_BINDING_DESC> outputBindings)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings);
}

void ExecutionContext::AddUAVBarrier()
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

m_dmlRecorder.AddUAVBarrier();
}

void ExecutionContext::ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

m_dmlRecorder.ResourceBarrier(barriers);
}

void ExecutionContext::GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** commandList)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

// Ensure the descriptor heap is reset to D3D as something external may change it before recording
Expand All @@ -142,8 +130,6 @@ namespace Dml

void ExecutionContext::SetCommandRecorder(ICommandRecorder* newRecorder)
{
assert(!m_closed);

// If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct
// ordering of operations on the command queue.
if (m_currentRecorder != newRecorder)
Expand All @@ -160,8 +146,6 @@ namespace Dml

void ExecutionContext::Flush()
{
assert(!m_closed);

if (!m_currentRecorder || !m_currentRecorder->HasUnsubmittedWork())
{
// Nothing to flush
Expand All @@ -180,34 +164,21 @@ namespace Dml

void ExecutionContext::QueueReference(IUnknown* object)
{
assert(!m_closed);
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
// value is the one to signal completion.
bool waitForUnsubmittedWork = (m_currentRecorder != nullptr);
m_queue->QueueReference(object, waitForUnsubmittedWork);
}

void ExecutionContext::Close()
void ExecutionContext::WaitForSignalAndClearQueue()
{
assert(!m_closed);

// Discard unflushed work and clear queued references. This prevents the circular reference:
// Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel
m_queue->Close();

// Keep the execution context open when requested, e.g. when used through the python API where there's a single context
// and single command queue
if (!m_keepOpen)
{
m_currentRecorder = nullptr;
m_closed = true;
}
m_queue->WaitForSignalAndClearQueue();
}

GpuEvent ExecutionContext::GetCurrentCompletionEvent()
{
assert(!m_closed);

GpuEvent event = m_queue->GetCurrentCompletionEvent();

// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
Expand All @@ -223,13 +194,11 @@ namespace Dml

void ExecutionContext::ReleaseCompletedReferences()
{
assert(!m_closed);
m_queue->ReleaseCompletedReferences();
}

D3D12_COMMAND_LIST_TYPE ExecutionContext::GetCommandListTypeForQueue() const
{
assert(!m_closed);
return m_queue->GetType();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ namespace Dml
ID3D12Device* d3d12Device,
IDMLDevice* dmlDevice,
ID3D12CommandQueue* queue,
bool cpuSyncSpinningEnabled,
bool keepOpen);
bool cpuSyncSpinningEnabled);

void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);

// Waits for flushed work, discards unflushed work, and discards associated references to
// prevent circular references. Must be the last call on the object before destruction.
void Close();
// prevent circular references.
void WaitForSignalAndClearQueue();

// Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. Transition
// barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and
Expand Down Expand Up @@ -87,7 +86,6 @@ namespace Dml

D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const;
bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; }
bool IsClosed() const { return m_closed; }

private:
Microsoft::WRL::ComPtr<ID3D12Device> m_d3dDevice;
Expand All @@ -103,10 +101,6 @@ namespace Dml

bool m_closed = false;
bool m_cpuSyncSpinningEnabled = false;

// The python API has a global state used for I/O binding where the execution context is shared between session,
// so we don't want to close the context when one of the sessions is destroyed
bool m_keepOpen = false;
};

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,26 @@
// Release the cached command list references before closing the context
m_capturedGraphs.clear();

m_context->Close();
// Close the allocator before clearing the command queue to stop the allocator
// from appending resources to the queue in an attempt to keep them alive.
if (m_allocator)
{
m_allocator->Close();
}

// Destroy the allocators. We are closing the execution provider, so from now on the
// only thing it will be used for is doing copies via the DataTransfer, which doesn't
// require allocating any memory.
// TODO: Move the copy functions over to ExecutionContext so that we are able to cleanly

Check warning on line 119 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp:119: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// destroy ExecutionProviderImpl, and instead have the DataTransfer keep the context alive.
m_allocator = nullptr;
m_cpuInputAllocator = nullptr;

// Wait for all pending commands to be done executing and empty the command queue. This will
// force all kernels and resources in flight to get destroyed and, from this point forward,
// ExecutionProviderImpl will only be used to execute transfers between resources that
// already exist via the DataTransfer.
m_context->WaitForSignalAndClearQueue();
}

void ExecutionProviderImpl::WaitForOutstandingWork()
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {

// First, check if an I/O binding API that was used before this session or another session has already created a queue
if (FAILED(d3d12_device->GetPrivateData(dml_execution_context_guid, &execution_context_ptr_size, execution_context.GetAddressOf()))) {
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true, true);
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true);
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, execution_context.Get()));
}
} else {
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false);
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_);
}

auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_);
Expand Down
Loading
Loading