From 1ca4871cfef56ba4e10648baf67859afafa52336 Mon Sep 17 00:00:00 2001 From: Ningxin Hu Date: Mon, 22 May 2023 08:43:57 +0800 Subject: [PATCH 1/3] Delete #pragma optimize("", off) --- content/browser/ml/webnn/dml/adapter_dml.cc | 2 -- content/browser/ml/webnn/dml/graph_desc_builder.cc | 2 -- content/browser/ml/webnn/dml/graph_dml_impl.cc | 3 --- content/browser/ml/webnn/dml/graph_tensor_desc.cc | 2 -- content/browser/ml/webnn/dml/readback_resource.cc | 2 -- third_party/blink/renderer/modules/ml/webnn/ml_graph.cc | 2 -- .../blink/renderer/modules/ml/webnn/ml_graph_builder.cc | 2 -- third_party/blink/renderer/modules/ml/webnn/ml_operand.cc | 2 -- third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc | 2 -- third_party/blink/renderer/modules/ml/webnn/mojo_model_info.cc | 2 -- 10 files changed, 21 deletions(-) diff --git a/content/browser/ml/webnn/dml/adapter_dml.cc b/content/browser/ml/webnn/dml/adapter_dml.cc index 9bfcf9845ddce4..60a8fb13b7572d 100644 --- a/content/browser/ml/webnn/dml/adapter_dml.cc +++ b/content/browser/ml/webnn/dml/adapter_dml.cc @@ -4,8 +4,6 @@ #include "content/browser/ml/webnn/dml/adapter_dml.h" -#pragma optimize("", off) // TODO:::DELETE - namespace content::webnn { AdapterDML::AdapterDML(ComPtr hardware_adapter) diff --git a/content/browser/ml/webnn/dml/graph_desc_builder.cc b/content/browser/ml/webnn/dml/graph_desc_builder.cc index b78d1298b9a273..8f863b07408913 100644 --- a/content/browser/ml/webnn/dml/graph_desc_builder.cc +++ b/content/browser/ml/webnn/dml/graph_desc_builder.cc @@ -5,8 +5,6 @@ #include "base/logging.h" #include "content/browser/ml/webnn/dml/graph_desc_builder.h" -#pragma optimize("", off) // TODO:::DELETE - namespace content::webnn { GraphDescBuilder::GraphDescBuilder(ComPtr device) diff --git a/content/browser/ml/webnn/dml/graph_dml_impl.cc b/content/browser/ml/webnn/dml/graph_dml_impl.cc index 1cd3298f3e0dff..0c4a450424c8b8 100644 --- a/content/browser/ml/webnn/dml/graph_dml_impl.cc +++ b/content/browser/ml/webnn/dml/graph_dml_impl.cc @@ -14,9 +14,6 @@ #include "mojo/public/c/system/types.h" #include "mojo/public/cpp/bindings/self_owned_receiver.h" -// TODO:::DELETE -#pragma optimize("", off) // TODO:::DELETE - namespace content::webnn { namespace { diff --git a/content/browser/ml/webnn/dml/graph_tensor_desc.cc b/content/browser/ml/webnn/dml/graph_tensor_desc.cc index 2253e684fe3a82..959584f6191f76 100644 --- a/content/browser/ml/webnn/dml/graph_tensor_desc.cc +++ b/content/browser/ml/webnn/dml/graph_tensor_desc.cc @@ -8,8 +8,6 @@ #include "base/numerics/checked_math.h" #include "base/containers/span.h" -#pragma optimize("", off) // TODO:::DELETE - namespace content::webnn { namespace { diff --git a/content/browser/ml/webnn/dml/readback_resource.cc b/content/browser/ml/webnn/dml/readback_resource.cc index af20235432e521..9760bdcff3151f 100644 --- a/content/browser/ml/webnn/dml/readback_resource.cc +++ b/content/browser/ml/webnn/dml/readback_resource.cc @@ -8,8 +8,6 @@ #include "content/browser/ml/webnn/dml/execution_context.h" -#pragma optimize("", off) // TODO:::DELETE - namespace content::webnn { ReadbackResource::ReadbackResource(ExecutionContext* execution_context) diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc b/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc index 2357ceea84c3ba..b4f2108739cde0 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc +++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc @@ -14,8 +14,6 @@ #include 
"third_party/blink/renderer/platform/heap/collection_support/heap_deque.h" #include "third_party/blink/renderer/platform/heap/collection_support/heap_hash_set.h" -#pragma optimize("", off) // TODO:::DELETE - namespace blink { namespace { diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc b/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc index d532398f3e428a..88d817709878c9 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc +++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc @@ -47,8 +47,6 @@ #include "third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.h" #endif -#pragma optimize("", off) // TODO:::DELETE - namespace blink { namespace { diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_operand.cc b/third_party/blink/renderer/modules/ml/webnn/ml_operand.cc index 0e7ca288ecc84f..849b9e2186c95b 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_operand.cc +++ b/third_party/blink/renderer/modules/ml/webnn/ml_operand.cc @@ -8,8 +8,6 @@ #include "third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h" #include "third_party/blink/renderer/modules/ml/webnn/ml_operator.h" -#pragma optimize("", off) // TODO:::DELETE - namespace blink { namespace { diff --git a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc index ce461f036cc0b3..35d2f0c86ae3e5 100644 --- a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc +++ b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc @@ -21,8 +21,6 @@ #include -#pragma optimize("", off) // TODO:::DELETE - namespace blink { namespace { diff --git a/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.cc b/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.cc index 2883e09a531d65..bdc71854267e0f 100644 --- a/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.cc +++ b/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.cc @@ -29,8 +29,6 @@ #include "third_party/blink/renderer/platform/bindings/exception_code.h" #include "third_party/blink/renderer/platform/bindings/exception_state.h" -#pragma optimize("", off) // TODO:::DELETE - namespace blink { #define DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(/*MLOperator* */ ml_operator, \ From 7b87a04b6725bcbcce544b4338bfa1987934ac12 Mon Sep 17 00:00:00 2001 From: Ningxin Hu Date: Mon, 22 May 2023 17:53:20 +0800 Subject: [PATCH 2/3] Add some trace events --- .../browser/ml/webnn/dml/command_recorder.cc | 4 + .../browser/ml/webnn/dml/graph_dml_impl.cc | 5 +- .../browser/ml/webnn/dml/readback_resource.cc | 3 + .../browser/ml/webnn/dml/upload_resource.cc | 4 + .../renderer/modules/ml/webnn/mojo_graph.cc | 87 +++++++++++-------- 5 files changed, 64 insertions(+), 39 deletions(-) diff --git a/content/browser/ml/webnn/dml/command_recorder.cc b/content/browser/ml/webnn/dml/command_recorder.cc index 63e66f2e0eedc5..a54179317b1a18 100644 --- a/content/browser/ml/webnn/dml/command_recorder.cc +++ b/content/browser/ml/webnn/dml/command_recorder.cc @@ -4,6 +4,8 @@ #include "content/browser/ml/webnn/dml/command_recorder.h" +#include "base/trace_event/trace_event.h" +#include "base/trace_event/typed_macros.h" #include "content/browser/ml/webnn/dml/adapter_dml.h" #include "content/browser/ml/webnn/dml/execution_resources.h" @@ -65,6 +67,7 @@ HRESULT CommandRecorder::InitializeGraph( GraphDMLImpl* graph, IDMLCompiledOperator* compiled_operator, const DML_BINDING_DESC& input_array_binding) { + TRACE_EVENT0("gpu", 
"CommandRecorder::InitializeGraph"); // Reset the initializer to reference the compiled operator. IDMLCompiledOperator* ops[] = {compiled_operator}; HRESULT hr = operator_initializer_->Reset(ARRAYSIZE(ops), ops); @@ -164,6 +167,7 @@ HRESULT CommandRecorder::ExecuteGraph( IDMLCompiledOperator* compiled_operator, const std::vector& input_bindings, const std::vector& output_bindings) { + TRACE_EVENT0("gpu", "CommandRecorder::ExecuteGraph"); DCHECK(mBindingTable != nullptr); // Bind and execute the operator on the GPU. // Reset the binding table to bind for the operator we want to execute (it diff --git a/content/browser/ml/webnn/dml/graph_dml_impl.cc b/content/browser/ml/webnn/dml/graph_dml_impl.cc index 0c4a450424c8b8..d597471f59fcfc 100644 --- a/content/browser/ml/webnn/dml/graph_dml_impl.cc +++ b/content/browser/ml/webnn/dml/graph_dml_impl.cc @@ -4,9 +4,11 @@ #include "content/browser/ml/webnn/dml/graph_dml_impl.h" +#include "base/containers/span.h" #include "base/logging.h" #include "base/memory/ptr_util.h" -#include "base/containers/span.h" +#include "base/trace_event/trace_event.h" +#include "base/trace_event/typed_macros.h" #include "content/browser/ml/webnn/dml/execution_context.h" #include "content/browser/ml/webnn/dml/execution_resources.h" #include "content/browser/ml/webnn/dml/graph_dml_impl.h" @@ -1855,6 +1857,7 @@ bool GraphDMLImpl::Build(ModelInfoPtr model_info, BuildResult* out_result) { void GraphDMLImpl::Compute(NamedResourcesPtr named_inputs, ComputeCallback callback) { + TRACE_EVENT0("gpu", "GraphDMLImpl::Compute"); ExecutionResources* execution_resources = execution_context_->GetExecutionResources(); ID3D12Resource* inputs_resource = diff --git a/content/browser/ml/webnn/dml/readback_resource.cc b/content/browser/ml/webnn/dml/readback_resource.cc index 9760bdcff3151f..0d5aa1cf42ac3f 100644 --- a/content/browser/ml/webnn/dml/readback_resource.cc +++ b/content/browser/ml/webnn/dml/readback_resource.cc @@ -6,6 +6,8 @@ #include +#include "base/trace_event/trace_event.h" +#include "base/trace_event/typed_macros.h" #include "content/browser/ml/webnn/dml/execution_context.h" namespace content::webnn { @@ -42,6 +44,7 @@ HRESULT ReadbackResource::InitializeResource( // Readback inference result from GPU that is stored in named_outputs. HRESULT ReadbackResource::ReadResourceFromGpu(NamedResourcesPtr& named_outputs, ID3D12Resource* src_resource) { + TRACE_EVENT0("gpu", "ReadbackResource::ReadResourceFromGpu"); // Copy buffer from GPU resource to CPU data. execution_context_->CopyBufferRegion(readback_resource_->GetResource(), src_resource, outputs_resource_size_, diff --git a/content/browser/ml/webnn/dml/upload_resource.cc b/content/browser/ml/webnn/dml/upload_resource.cc index 00bd035bf28580..45d5c8f474c2f7 100644 --- a/content/browser/ml/webnn/dml/upload_resource.cc +++ b/content/browser/ml/webnn/dml/upload_resource.cc @@ -6,6 +6,8 @@ #include +#include "base/trace_event/trace_event.h" +#include "base/trace_event/typed_macros.h" #include "content/browser/ml/webnn/dml/execution_context.h" namespace content::webnn { @@ -60,6 +62,7 @@ UploadResource::~UploadResource() = default; // need to transition. 
HRESULT UploadResource::UploadConstants(ID3D12Resource* dst_resource, ConstantsInfoPtr& constants_info) { + TRACE_EVENT0("gpu", "UploadResource::UploadConstants"); base::ReadOnlySharedMemoryRegion& shared_memory_region = constants_info->shared_memory; size_t constants_byte_length = shared_memory_region.GetSize(); @@ -80,6 +83,7 @@ HRESULT UploadResource::UploadConstants(ID3D12Resource* dst_resource, HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource, NamedResourcesPtr& named_inputs) { + TRACE_EVENT0("gpu", "UploadResource::UploadInputs"); base::ReadOnlySharedMemoryRegion& shared_memory_region = named_inputs->shared_memory; size_t inputs_byte_length = shared_memory_region.GetSize(); diff --git a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc index 35d2f0c86ae3e5..20ffef7a682afd 100644 --- a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc +++ b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc @@ -4,6 +4,8 @@ #include "third_party/blink/renderer/modules/ml/webnn/mojo_graph.h" +#include "base/trace_event/trace_event.h" +#include "base/trace_event/typed_macros.h" #include "mojo/public/cpp/bindings/pending_remote.h" #include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_tensor.h" @@ -277,6 +279,7 @@ void MojoGraph::ComputeAsyncImpl(const MLNamedArrayBufferViews& inputs, void MojoGraph::ComputeSyncImpl(const MLNamedArrayBufferViews& inputs, const MLNamedArrayBufferViews& outputs, ExceptionState& exception_state) { + TRACE_EVENT0("blink", "MojoGraph::ComputeSyncImpl"); if (inputs.size() != input_resources_info_.size()) { exception_state.ThrowDOMException(DOMExceptionCode::kDataError, "The number of inputs is invalid."); @@ -284,24 +287,28 @@ void MojoGraph::ComputeSyncImpl(const MLNamedArrayBufferViews& inputs, } auto named_inputs = ml::webnn::mojom::blink::NamedResources::New(), named_outputs = ml::webnn::mojom::blink::NamedResources::New(); - for (const auto& input : inputs) { - String error_message; - auto* input_array_buffer_view = input.second.Get(); - if (input_array_buffer_view == nullptr) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - error_message); + { + TRACE_EVENT0("blink", "MojoGraph::ComputeSyncImpl::CopyInputs"); + for (const auto& input : inputs) { + String error_message; + auto* input_array_buffer_view = input.second.Get(); + if (input_array_buffer_view == nullptr) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + } + const String& input_name = input.first; + auto memory_info = ml::webnn::mojom::blink::MemoryInfo::New(); + memory_info->byte_offset = inputs_byte_offset_.at(input_name); + memory_info->byte_length = + input_resources_info_.at(input_name).byte_length; + uint8_t* address = inputs_shm_region_.mapping.GetMemoryAs() + + memory_info->byte_offset; + memcpy(address, input_array_buffer_view->BaseAddressMaybeShared(), + input_array_buffer_view->byteLength()); + named_inputs->resources.insert(input_name, std::move(memory_info)); } - const String& input_name = input.first; - auto memory_info = ml::webnn::mojom::blink::MemoryInfo::New(); - memory_info->byte_offset = inputs_byte_offset_.at(input_name); - memory_info->byte_length = input_resources_info_.at(input_name).byte_length; - uint8_t* address = inputs_shm_region_.mapping.GetMemoryAs() + - memory_info->byte_offset; - memcpy(address, 
input_array_buffer_view->BaseAddressMaybeShared(), - input_array_buffer_view->byteLength()); - named_inputs->resources.insert(input_name, std::move(memory_info)); + named_inputs->shared_memory = inputs_shm_region_.region.Duplicate(); } - named_inputs->shared_memory = inputs_shm_region_.region.Duplicate(); ComputeResult result; if (!remote_graph_->Compute(std::move(named_inputs), &result, &named_outputs)) { @@ -309,29 +316,33 @@ void MojoGraph::ComputeSyncImpl(const MLNamedArrayBufferViews& inputs, "Failed to compute the graph."); return; }; - for (const auto& output : outputs) { - String error_message; - void* output_buffer_address = output.second->BaseAddressMaybeShared(); - if (output_buffer_address == nullptr) { - exception_state.ThrowDOMException(DOMExceptionCode::kOperationError, - error_message); - return; - } - auto iter = named_outputs->resources.find(output.first); - if (iter == named_outputs->resources.end()) { - exception_state.ThrowDOMException(DOMExceptionCode::kOperationError, - "Failed to get result for the output."); - return; + { + TRACE_EVENT0("blink", "MojoGraph::ComputeSyncImpl::CopyOutputs"); + for (const auto& output : outputs) { + String error_message; + void* output_buffer_address = output.second->BaseAddressMaybeShared(); + if (output_buffer_address == nullptr) { + exception_state.ThrowDOMException(DOMExceptionCode::kOperationError, + error_message); + return; + } + auto iter = named_outputs->resources.find(output.first); + if (iter == named_outputs->resources.end()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kOperationError, + "Failed to get result for the output."); + return; + } + MemoryInfoPtr memory_info = std::move(iter->value); + base::ReadOnlySharedMemoryRegion& shared_memory_region = + named_outputs->shared_memory; + DCHECK(shared_memory_region.IsValid()); + size_t byte_length = base::checked_cast(memory_info->byte_length); + base::ReadOnlySharedMemoryMapping shared_memory_mapping = + shared_memory_region.MapAt(memory_info->byte_offset, byte_length); + memcpy(output_buffer_address, + shared_memory_mapping.GetMemoryAs(), byte_length); } - MemoryInfoPtr memory_info = std::move(iter->value); - base::ReadOnlySharedMemoryRegion& shared_memory_region = - named_outputs->shared_memory; - DCHECK(shared_memory_region.IsValid()); - size_t byte_length = base::checked_cast(memory_info->byte_length); - base::ReadOnlySharedMemoryMapping shared_memory_mapping = - shared_memory_region.MapAt(memory_info->byte_offset, byte_length); - memcpy(output_buffer_address, shared_memory_mapping.GetMemoryAs(), - byte_length); } } From dca067de7be17c79f36a01463506a91b76287bfc Mon Sep 17 00:00:00 2001 From: Ningxin Hu Date: Sun, 21 May 2023 23:01:26 +0800 Subject: [PATCH 3/3] Optimize shared memory mapping and copy for input and output buffers 1. Avoid mapping shared memory for every compute 2. 
Copy input/output and uploading/readback buffers in one big chunk --- .../browser/ml/webnn/dml/readback_resource.cc | 21 ++++------ .../browser/ml/webnn/dml/upload_resource.cc | 41 +++++++++---------- .../browser/ml/webnn/dml/upload_resource.h | 1 + .../renderer/modules/ml/webnn/mojo_graph.cc | 12 +++--- .../renderer/modules/ml/webnn/mojo_graph.h | 1 + 5 files changed, 35 insertions(+), 41 deletions(-) diff --git a/content/browser/ml/webnn/dml/readback_resource.cc b/content/browser/ml/webnn/dml/readback_resource.cc index 0d5aa1cf42ac3f..afa4360be72d7a 100644 --- a/content/browser/ml/webnn/dml/readback_resource.cc +++ b/content/browser/ml/webnn/dml/readback_resource.cc @@ -54,30 +54,25 @@ HRESULT ReadbackResource::ReadResourceFromGpu(NamedResourcesPtr& named_outputs, execution_context_->WaitForSignal(); execution_context_->ReleaseCompletedResources(); - D3D12_RANGE tensorBufferRange{0, outputs_resource_size_}; - int8_t* readBackBuffer; + D3D12_RANGE read_range{0, outputs_resource_size_}; + int8_t* readback_buffer; HRESULT hr = readback_resource_->Map( - 0, &tensorBufferRange, reinterpret_cast(&readBackBuffer)); + 0, &read_range, reinterpret_cast(&readback_buffer)); if (FAILED(hr)) { return hr; } + uint8_t* address = outputs_shm_region_.mapping.GetMemoryAs(); + memcpy(address, readback_buffer, outputs_resource_size_); + readback_resource_->Unmap(0, nullptr); for (auto& [name, memory_info] : outputs_info_map_) { auto mojo_memory_info = ml::webnn::mojom::MemoryInfo::New(); - size_t byte_offset = memory_info.byte_offset; - size_t byte_length = memory_info.byte_length; - mojo_memory_info->byte_offset = byte_offset; - mojo_memory_info->byte_length = byte_length; + mojo_memory_info->byte_offset = memory_info.byte_offset; + mojo_memory_info->byte_length = memory_info.byte_length; named_outputs->resources[name] = std::move(mojo_memory_info); - - std::vector output_buffer(byte_length); - uint8_t* address = - outputs_shm_region_.mapping.GetMemoryAs() + byte_offset; - memcpy(address, readBackBuffer + byte_offset, byte_length); } named_outputs->shared_memory = outputs_shm_region_.region.Duplicate(); - readback_resource_->Unmap(0, nullptr); return S_OK; } diff --git a/content/browser/ml/webnn/dml/upload_resource.cc b/content/browser/ml/webnn/dml/upload_resource.cc index 45d5c8f474c2f7..cc6160711073d4 100644 --- a/content/browser/ml/webnn/dml/upload_resource.cc +++ b/content/browser/ml/webnn/dml/upload_resource.cc @@ -16,13 +16,12 @@ namespace { using ml::webnn::mojom::MemoryInfoPtr; -template HRESULT UploadResourceToGpu( ExecutionContext* execution_context, ID3D12Resource* dst_resource, ID3D12Resource* src_resource, - base::ReadOnlySharedMemoryRegion& shared_memory_region, - T& named_inputs) { + base::ReadOnlySharedMemoryMapping& shared_memory_mapping, + size_t byte_length) { // Map the upload heap and copy the source data into it. A null pointer // indicates the entire subresource might be read by the CPU. 
void* upload_data = nullptr; @@ -30,22 +29,12 @@ HRESULT UploadResourceToGpu( if (FAILED(hr)) { return hr; } - - for (auto& [_, memory_info] : named_inputs) { - uint64_t byte_length = memory_info->byte_length; - uint64_t byte_offset = memory_info->byte_offset; - DCHECK(byte_offset % DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT == 0); - DCHECK(shared_memory_region.IsValid()); - base::ReadOnlySharedMemoryMapping shared_memory_mapping = - shared_memory_region.MapAt(memory_info->byte_offset, byte_length); - memcpy(static_cast(upload_data) + memory_info->byte_offset, - shared_memory_mapping.GetMemoryAs(), byte_length); - } + memcpy(static_cast(upload_data), + shared_memory_mapping.GetMemoryAs(), byte_length); src_resource->Unmap(0, nullptr); // Copy from the upload heap into the destination resource - execution_context->CopyBufferRegion(dst_resource, src_resource, - shared_memory_region.GetSize(), + execution_context->CopyBufferRegion(dst_resource, src_resource, byte_length, D3D12_RESOURCE_STATE_COPY_DEST); return S_OK; @@ -66,6 +55,10 @@ HRESULT UploadResource::UploadConstants(ID3D12Resource* dst_resource, base::ReadOnlySharedMemoryRegion& shared_memory_region = constants_info->shared_memory; size_t constants_byte_length = shared_memory_region.GetSize(); + if (!shm_mapping_.IsValid()) { + shm_mapping_ = shared_memory_region.Map(); + DCHECK(shm_mapping_.IsValid()); + } HRESULT hr = S_OK; if (upload_resource_ == nullptr) { @@ -76,9 +69,9 @@ HRESULT UploadResource::UploadConstants(ID3D12Resource* dst_resource, } DCHECK(upload_resource_ != nullptr); - return UploadResourceToGpu>( - execution_context_, dst_resource, upload_resource_->GetResource(), - shared_memory_region, constants_info->memory_info); + return UploadResourceToGpu(execution_context_, dst_resource, + upload_resource_->GetResource(), shm_mapping_, + constants_byte_length); } HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource, @@ -87,6 +80,10 @@ HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource, base::ReadOnlySharedMemoryRegion& shared_memory_region = named_inputs->shared_memory; size_t inputs_byte_length = shared_memory_region.GetSize(); + if (!shm_mapping_.IsValid()) { + shm_mapping_ = shared_memory_region.Map(); + DCHECK(shm_mapping_.IsValid()); + } HRESULT hr = S_OK; if (upload_resource_ == nullptr) { @@ -97,9 +94,9 @@ HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource, } DCHECK(upload_resource_ != nullptr); - return UploadResourceToGpu>( - execution_context_, dst_resource, upload_resource_->GetResource(), - shared_memory_region, named_inputs->resources); + return UploadResourceToGpu(execution_context_, dst_resource, + upload_resource_->GetResource(), shm_mapping_, + inputs_byte_length); } // Create entire memory for uploading resource that will be uploaded piece by diff --git a/content/browser/ml/webnn/dml/upload_resource.h b/content/browser/ml/webnn/dml/upload_resource.h index 1342df8ca9b3c8..64ae8a2d951873 100644 --- a/content/browser/ml/webnn/dml/upload_resource.h +++ b/content/browser/ml/webnn/dml/upload_resource.h @@ -35,6 +35,7 @@ class UploadResource final { ExecutionContext* execution_context_; ComPtr upload_resource_; + base::ReadOnlySharedMemoryMapping shm_mapping_; }; } // namespace content::webnn diff --git a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc index 20ffef7a682afd..d90ecbfc739df7 100644 --- a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc +++ 
b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc
@@ -318,6 +318,9 @@ void MojoGraph::ComputeSyncImpl(const MLNamedArrayBufferViews& inputs,
   };
   {
     TRACE_EVENT0("blink", "MojoGraph::ComputeSyncImpl::CopyOutputs");
+    if (!outputs_shm_mapping_.IsValid()) {
+      outputs_shm_mapping_ = named_outputs->shared_memory.Map();
+    }
     for (const auto& output : outputs) {
       String error_message;
       void* output_buffer_address = output.second->BaseAddressMaybeShared();
@@ -334,14 +337,11 @@ void MojoGraph::ComputeSyncImpl(const MLNamedArrayBufferViews& inputs,
         return;
       }
       MemoryInfoPtr memory_info = std::move(iter->value);
-      base::ReadOnlySharedMemoryRegion& shared_memory_region =
-          named_outputs->shared_memory;
-      DCHECK(shared_memory_region.IsValid());
+      size_t byte_offset = base::checked_cast(memory_info->byte_offset);
       size_t byte_length = base::checked_cast(memory_info->byte_length);
-      base::ReadOnlySharedMemoryMapping shared_memory_mapping =
-          shared_memory_region.MapAt(memory_info->byte_offset, byte_length);
       memcpy(output_buffer_address,
-             shared_memory_mapping.GetMemoryAs(), byte_length);
+             outputs_shm_mapping_.GetMemoryAs() + byte_offset,
+             byte_length);
     }
   }
 }
diff --git a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.h b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.h
index a0ffdda736e3b1..ebe132df1338ea 100644
--- a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.h
+++ b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.h
@@ -66,6 +66,7 @@ class MojoGraph : public MLGraph {
   // The map of input name and input data offset.
   HashMap inputs_byte_offset_;
   base::MappedReadOnlyRegion inputs_shm_region_;
+  base::ReadOnlySharedMemoryMapping outputs_shm_mapping_;
 
   HeapMojoRemote remote_graph_;
 };
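
Patch 2 instruments the compute path with Chromium's scoped trace-event macro. The sketch below shows the pattern on its own, assuming only a target that can include //base's tracing header; the function name, category strings, and the elided work are illustrative rather than code from the series. TRACE_EVENT0 records a begin event where the statement runs and an end event when the enclosing scope exits, which is why the patch wraps the input and output copies in extra braces: each brace pair becomes its own timed slice under the outer event.

  // Sketch of the scoped-tracing pattern added in patch 2 (Chromium //base
  // assumed; ComputeSketch and the elided work are illustrative only).
  #include "base/trace_event/trace_event.h"

  void ComputeSketch() {
    TRACE_EVENT0("gpu", "ComputeSketch");  // spans the whole function
    {
      TRACE_EVENT0("gpu", "ComputeSketch::CopyInputs");
      // ... copy input buffers into the upload resource ...
    }
    {
      TRACE_EVENT0("gpu", "ComputeSketch::CopyOutputs");
      // ... copy results back out of the readback resource ...
    }
  }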
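
The renderer-side half of patch 3 caches a base::ReadOnlySharedMemoryMapping in the MojoGraph instead of calling MapAt() for every output tensor on every compute. A minimal sketch of that caching idea, with illustrative class and method names and the error handling trimmed, so it is not the patch's exact code:

  // Sketch: map the read-only region once, then serve per-tensor copies from
  // the cached mapping (names here are illustrative).
  #include <cstring>

  #include "base/check.h"
  #include "base/memory/read_only_shared_memory_region.h"
  #include "base/memory/shared_memory_mapping.h"

  class OutputReader {
   public:
    // Copies |byte_length| bytes starting at |byte_offset| into |dst|.
    void CopyOutput(base::ReadOnlySharedMemoryRegion& region,
                    size_t byte_offset,
                    size_t byte_length,
                    void* dst) {
      if (!mapping_.IsValid()) {
        mapping_ = region.Map();  // map the whole region once, then reuse it
        DCHECK(mapping_.IsValid());
      }
      memcpy(dst, mapping_.GetMemoryAs<uint8_t>() + byte_offset, byte_length);
    }

   private:
    base::ReadOnlySharedMemoryMapping mapping_;
  };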
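
On the browser (DirectML) side, patch 3 collapses the per-tensor copies into one memcpy into the D3D12 upload heap followed by a single buffer-to-buffer copy. The sketch below shows only that D3D12 idiom under those assumptions; the series routes the copy through its ExecutionContext helper, while this sketch uses the raw ID3D12GraphicsCommandList call and omits offsets, fencing, and resource-state transitions.

  // Sketch of the one-chunk upload idiom (Windows/D3D12; plumbing elided).
  #include <cstring>
  #include <d3d12.h>

  HRESULT UploadInOneChunk(ID3D12GraphicsCommandList* command_list,
                           ID3D12Resource* upload_resource,  // upload heap
                           ID3D12Resource* dst_resource,     // default heap
                           const void* src,
                           size_t byte_length) {
    // A null read range tells D3D12 the CPU may read the whole subresource;
    // for a write-only upload this is merely conservative.
    void* upload_data = nullptr;
    HRESULT hr = upload_resource->Map(0, nullptr, &upload_data);
    if (FAILED(hr))
      return hr;
    memcpy(upload_data, src, byte_length);  // one copy for all tensors
    upload_resource->Unmap(0, nullptr);

    // Record the GPU-side copy from the upload heap into the destination.
    command_list->CopyBufferRegion(dst_resource, /*DstOffset=*/0,
                                   upload_resource, /*SrcOffset=*/0,
                                   byte_length);
    return S_OK;
  }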