Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize shared memory mapping and copy for input and output buffers #10

Merged
merged 3 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions content/browser/ml/webnn/dml/adapter_dml.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

#include "content/browser/ml/webnn/dml/adapter_dml.h"

#pragma optimize("", off) // TODO:::DELETE

namespace content::webnn {

AdapterDML::AdapterDML(ComPtr<IDXGIAdapter3> hardware_adapter)
Expand Down
4 changes: 4 additions & 0 deletions content/browser/ml/webnn/dml/command_recorder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "content/browser/ml/webnn/dml/command_recorder.h"

#include "base/trace_event/trace_event.h"
#include "base/trace_event/typed_macros.h"
#include "content/browser/ml/webnn/dml/adapter_dml.h"
#include "content/browser/ml/webnn/dml/execution_resources.h"

Expand Down Expand Up @@ -65,6 +67,7 @@ HRESULT CommandRecorder::InitializeGraph(
GraphDMLImpl* graph,
IDMLCompiledOperator* compiled_operator,
const DML_BINDING_DESC& input_array_binding) {
TRACE_EVENT0("gpu", "CommandRecorder::InitializeGraph");
// Reset the initializer to reference the compiled operator.
IDMLCompiledOperator* ops[] = {compiled_operator};
HRESULT hr = operator_initializer_->Reset(ARRAYSIZE(ops), ops);
Expand Down Expand Up @@ -164,6 +167,7 @@ HRESULT CommandRecorder::ExecuteGraph(
IDMLCompiledOperator* compiled_operator,
const std::vector<DML_BINDING_DESC>& input_bindings,
const std::vector<DML_BINDING_DESC>& output_bindings) {
TRACE_EVENT0("gpu", "CommandRecorder::ExecuteGraph");
DCHECK(mBindingTable != nullptr);
// Bind and execute the operator on the GPU.
// Reset the binding table to bind for the operator we want to execute (it
Expand Down
2 changes: 0 additions & 2 deletions content/browser/ml/webnn/dml/graph_desc_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
#include "base/logging.h"
#include "content/browser/ml/webnn/dml/graph_desc_builder.h"

#pragma optimize("", off) // TODO:::DELETE

namespace content::webnn {

GraphDescBuilder::GraphDescBuilder(ComPtr<IDMLDevice> device)
Expand Down
8 changes: 4 additions & 4 deletions content/browser/ml/webnn/dml/graph_dml_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@

#include "content/browser/ml/webnn/dml/graph_dml_impl.h"

#include "base/containers/span.h"
#include "base/logging.h"
#include "base/memory/ptr_util.h"
#include "base/containers/span.h"
#include "base/trace_event/trace_event.h"
#include "base/trace_event/typed_macros.h"
#include "content/browser/ml/webnn/dml/execution_context.h"
#include "content/browser/ml/webnn/dml/execution_resources.h"
#include "content/browser/ml/webnn/dml/graph_dml_impl.h"
#include "content/browser/ml/webnn/dml/upload_resource.h"
#include "mojo/public/c/system/types.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"

// TODO:::DELETE
#pragma optimize("", off) // TODO:::DELETE

namespace content::webnn {

namespace {
Expand Down Expand Up @@ -1858,6 +1857,7 @@ bool GraphDMLImpl::Build(ModelInfoPtr model_info, BuildResult* out_result) {

void GraphDMLImpl::Compute(NamedResourcesPtr named_inputs,
ComputeCallback callback) {
TRACE_EVENT0("gpu", "GraphDMLImpl::Compute");
ExecutionResources* execution_resources =
execution_context_->GetExecutionResources();
ID3D12Resource* inputs_resource =
Expand Down
2 changes: 0 additions & 2 deletions content/browser/ml/webnn/dml/graph_tensor_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#include "base/numerics/checked_math.h"
#include "base/containers/span.h"

#pragma optimize("", off) // TODO:::DELETE

namespace content::webnn {

namespace {
Expand Down
26 changes: 11 additions & 15 deletions content/browser/ml/webnn/dml/readback_resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

#include <memory>

#include "base/trace_event/trace_event.h"
#include "base/trace_event/typed_macros.h"
#include "content/browser/ml/webnn/dml/execution_context.h"

#pragma optimize("", off) // TODO:::DELETE

namespace content::webnn {

ReadbackResource::ReadbackResource(ExecutionContext* execution_context)
Expand Down Expand Up @@ -44,6 +44,7 @@ HRESULT ReadbackResource::InitializeResource(
// Readback inference result from GPU that is stored in named_outputs.
HRESULT ReadbackResource::ReadResourceFromGpu(NamedResourcesPtr& named_outputs,
ID3D12Resource* src_resource) {
TRACE_EVENT0("gpu", "ReadbackResource::ReadResourceFromGpu");
// Copy buffer from GPU resource to CPU data.
execution_context_->CopyBufferRegion(readback_resource_->GetResource(),
src_resource, outputs_resource_size_,
Expand All @@ -53,30 +54,25 @@ HRESULT ReadbackResource::ReadResourceFromGpu(NamedResourcesPtr& named_outputs,
execution_context_->WaitForSignal();
execution_context_->ReleaseCompletedResources();

D3D12_RANGE tensorBufferRange{0, outputs_resource_size_};
int8_t* readBackBuffer;
D3D12_RANGE read_range{0, outputs_resource_size_};
int8_t* readback_buffer;
HRESULT hr = readback_resource_->Map(
0, &tensorBufferRange, reinterpret_cast<void**>(&readBackBuffer));
0, &read_range, reinterpret_cast<void**>(&readback_buffer));
if (FAILED(hr)) {
return hr;
}
uint8_t* address = outputs_shm_region_.mapping.GetMemoryAs<uint8_t>();
memcpy(address, readback_buffer, outputs_resource_size_);
readback_resource_->Unmap(0, nullptr);

for (auto& [name, memory_info] : outputs_info_map_) {
auto mojo_memory_info = ml::webnn::mojom::MemoryInfo::New();
size_t byte_offset = memory_info.byte_offset;
size_t byte_length = memory_info.byte_length;
mojo_memory_info->byte_offset = byte_offset;
mojo_memory_info->byte_length = byte_length;
mojo_memory_info->byte_offset = memory_info.byte_offset;
mojo_memory_info->byte_length = memory_info.byte_length;
named_outputs->resources[name] = std::move(mojo_memory_info);

std::vector<uint8_t> output_buffer(byte_length);
uint8_t* address =
outputs_shm_region_.mapping.GetMemoryAs<uint8_t>() + byte_offset;
memcpy(address, readBackBuffer + byte_offset, byte_length);
}
named_outputs->shared_memory = outputs_shm_region_.region.Duplicate();

readback_resource_->Unmap(0, nullptr);
return S_OK;
}

Expand Down
45 changes: 23 additions & 22 deletions content/browser/ml/webnn/dml/upload_resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <memory>

#include "base/trace_event/trace_event.h"
#include "base/trace_event/typed_macros.h"
#include "content/browser/ml/webnn/dml/execution_context.h"

namespace content::webnn {
Expand All @@ -14,36 +16,25 @@ namespace {

using ml::webnn::mojom::MemoryInfoPtr;

template <typename T>
HRESULT UploadResourceToGpu(
ExecutionContext* execution_context,
ID3D12Resource* dst_resource,
ID3D12Resource* src_resource,
base::ReadOnlySharedMemoryRegion& shared_memory_region,
T& named_inputs) {
base::ReadOnlySharedMemoryMapping& shared_memory_mapping,
size_t byte_length) {
// Map the upload heap and copy the source data into it. A null pointer
// indicates the entire subresource might be read by the CPU.
void* upload_data = nullptr;
HRESULT hr = src_resource->Map(0, nullptr, &upload_data);
if (FAILED(hr)) {
return hr;
}

for (auto& [_, memory_info] : named_inputs) {
uint64_t byte_length = memory_info->byte_length;
uint64_t byte_offset = memory_info->byte_offset;
DCHECK(byte_offset % DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT == 0);
DCHECK(shared_memory_region.IsValid());
base::ReadOnlySharedMemoryMapping shared_memory_mapping =
shared_memory_region.MapAt(memory_info->byte_offset, byte_length);
memcpy(static_cast<byte*>(upload_data) + memory_info->byte_offset,
shared_memory_mapping.GetMemoryAs<uint8_t>(), byte_length);
}
memcpy(static_cast<byte*>(upload_data),
shared_memory_mapping.GetMemoryAs<uint8_t>(), byte_length);
src_resource->Unmap(0, nullptr);

// Copy from the upload heap into the destination resource
execution_context->CopyBufferRegion(dst_resource, src_resource,
shared_memory_region.GetSize(),
execution_context->CopyBufferRegion(dst_resource, src_resource, byte_length,
D3D12_RESOURCE_STATE_COPY_DEST);

return S_OK;
Expand All @@ -60,9 +51,14 @@ UploadResource::~UploadResource() = default;
// need to transition.
HRESULT UploadResource::UploadConstants(ID3D12Resource* dst_resource,
ConstantsInfoPtr& constants_info) {
TRACE_EVENT0("gpu", "UploadResource::UploadConstants");
base::ReadOnlySharedMemoryRegion& shared_memory_region =
constants_info->shared_memory;
size_t constants_byte_length = shared_memory_region.GetSize();
if (!shm_mapping_.IsValid()) {
shm_mapping_ = shared_memory_region.Map();
DCHECK(shm_mapping_.IsValid());
}

HRESULT hr = S_OK;
if (upload_resource_ == nullptr) {
Expand All @@ -73,16 +69,21 @@ HRESULT UploadResource::UploadConstants(ID3D12Resource* dst_resource,
}
DCHECK(upload_resource_ != nullptr);

return UploadResourceToGpu<base::flat_map<UINT64, MemoryInfoPtr>>(
execution_context_, dst_resource, upload_resource_->GetResource(),
shared_memory_region, constants_info->memory_info);
return UploadResourceToGpu(execution_context_, dst_resource,
upload_resource_->GetResource(), shm_mapping_,
constants_byte_length);
}

HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource,
NamedResourcesPtr& named_inputs) {
TRACE_EVENT0("gpu", "UploadResource::UploadInputs");
base::ReadOnlySharedMemoryRegion& shared_memory_region =
named_inputs->shared_memory;
size_t inputs_byte_length = shared_memory_region.GetSize();
if (!shm_mapping_.IsValid()) {
shm_mapping_ = shared_memory_region.Map();
DCHECK(shm_mapping_.IsValid());
}

HRESULT hr = S_OK;
if (upload_resource_ == nullptr) {
Expand All @@ -93,9 +94,9 @@ HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource,
}
DCHECK(upload_resource_ != nullptr);

return UploadResourceToGpu<base::flat_map<std::string, MemoryInfoPtr>>(
execution_context_, dst_resource, upload_resource_->GetResource(),
shared_memory_region, named_inputs->resources);
return UploadResourceToGpu(execution_context_, dst_resource,
upload_resource_->GetResource(), shm_mapping_,
inputs_byte_length);
}

// Create entire memory for uploading resource that will be uploaded piece by
Expand Down
1 change: 1 addition & 0 deletions content/browser/ml/webnn/dml/upload_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class UploadResource final {

ExecutionContext* execution_context_;
ComPtr<gpgmm::d3d12::ResourceAllocation> upload_resource_;
base::ReadOnlySharedMemoryMapping shm_mapping_;
};

} // namespace content::webnn
Expand Down
2 changes: 0 additions & 2 deletions third_party/blink/renderer/modules/ml/webnn/ml_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
#include "third_party/blink/renderer/platform/heap/collection_support/heap_deque.h"
#include "third_party/blink/renderer/platform/heap/collection_support/heap_hash_set.h"

#pragma optimize("", off) // TODO:::DELETE

namespace blink {

namespace {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.h"
#endif

#pragma optimize("", off) // TODO:::DELETE

namespace blink {

namespace {
Expand Down
2 changes: 0 additions & 2 deletions third_party/blink/renderer/modules/ml/webnn/ml_operand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operator.h"

#pragma optimize("", off) // TODO:::DELETE

namespace blink {

namespace {
Expand Down
Loading