Skip to content

Commit

Permalink
Optimize shared memory mapping and copy for input and output buffers
Browse files Browse the repository at this point in the history
1. Avoid mapping shared memory for every compute
2. Copy input/output and upload/readback buffers in one big chunk
  • Loading branch information
huningxin committed May 22, 2023
1 parent 7b87a04 commit dca067d
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 41 deletions.
21 changes: 8 additions & 13 deletions content/browser/ml/webnn/dml/readback_resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,30 +54,25 @@ HRESULT ReadbackResource::ReadResourceFromGpu(NamedResourcesPtr& named_outputs,
execution_context_->WaitForSignal();
execution_context_->ReleaseCompletedResources();

D3D12_RANGE tensorBufferRange{0, outputs_resource_size_};
int8_t* readBackBuffer;
D3D12_RANGE read_range{0, outputs_resource_size_};
int8_t* readback_buffer;
HRESULT hr = readback_resource_->Map(
0, &tensorBufferRange, reinterpret_cast<void**>(&readBackBuffer));
0, &read_range, reinterpret_cast<void**>(&readback_buffer));
if (FAILED(hr)) {
return hr;
}
uint8_t* address = outputs_shm_region_.mapping.GetMemoryAs<uint8_t>();
memcpy(address, readback_buffer, outputs_resource_size_);
readback_resource_->Unmap(0, nullptr);

for (auto& [name, memory_info] : outputs_info_map_) {
auto mojo_memory_info = ml::webnn::mojom::MemoryInfo::New();
size_t byte_offset = memory_info.byte_offset;
size_t byte_length = memory_info.byte_length;
mojo_memory_info->byte_offset = byte_offset;
mojo_memory_info->byte_length = byte_length;
mojo_memory_info->byte_offset = memory_info.byte_offset;
mojo_memory_info->byte_length = memory_info.byte_length;
named_outputs->resources[name] = std::move(mojo_memory_info);

std::vector<uint8_t> output_buffer(byte_length);
uint8_t* address =
outputs_shm_region_.mapping.GetMemoryAs<uint8_t>() + byte_offset;
memcpy(address, readBackBuffer + byte_offset, byte_length);
}
named_outputs->shared_memory = outputs_shm_region_.region.Duplicate();

readback_resource_->Unmap(0, nullptr);
return S_OK;
}

Expand Down
41 changes: 19 additions & 22 deletions content/browser/ml/webnn/dml/upload_resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,25 @@ namespace {

using ml::webnn::mojom::MemoryInfoPtr;

template <typename T>
HRESULT UploadResourceToGpu(
ExecutionContext* execution_context,
ID3D12Resource* dst_resource,
ID3D12Resource* src_resource,
base::ReadOnlySharedMemoryRegion& shared_memory_region,
T& named_inputs) {
base::ReadOnlySharedMemoryMapping& shared_memory_mapping,
size_t byte_length) {
// Map the upload heap and copy the source data into it. A null pointer
// indicates the entire subresource might be read by the CPU.
void* upload_data = nullptr;
HRESULT hr = src_resource->Map(0, nullptr, &upload_data);
if (FAILED(hr)) {
return hr;
}

for (auto& [_, memory_info] : named_inputs) {
uint64_t byte_length = memory_info->byte_length;
uint64_t byte_offset = memory_info->byte_offset;
DCHECK(byte_offset % DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT == 0);
DCHECK(shared_memory_region.IsValid());
base::ReadOnlySharedMemoryMapping shared_memory_mapping =
shared_memory_region.MapAt(memory_info->byte_offset, byte_length);
memcpy(static_cast<byte*>(upload_data) + memory_info->byte_offset,
shared_memory_mapping.GetMemoryAs<uint8_t>(), byte_length);
}
memcpy(static_cast<byte*>(upload_data),
shared_memory_mapping.GetMemoryAs<uint8_t>(), byte_length);
src_resource->Unmap(0, nullptr);

// Copy from the upload heap into the destination resource
execution_context->CopyBufferRegion(dst_resource, src_resource,
shared_memory_region.GetSize(),
execution_context->CopyBufferRegion(dst_resource, src_resource, byte_length,
D3D12_RESOURCE_STATE_COPY_DEST);

return S_OK;
Expand All @@ -66,6 +55,10 @@ HRESULT UploadResource::UploadConstants(ID3D12Resource* dst_resource,
base::ReadOnlySharedMemoryRegion& shared_memory_region =
constants_info->shared_memory;
size_t constants_byte_length = shared_memory_region.GetSize();
if (!shm_mapping_.IsValid()) {
shm_mapping_ = shared_memory_region.Map();
DCHECK(shm_mapping_.IsValid());
}

HRESULT hr = S_OK;
if (upload_resource_ == nullptr) {
Expand All @@ -76,9 +69,9 @@ HRESULT UploadResource::UploadConstants(ID3D12Resource* dst_resource,
}
DCHECK(upload_resource_ != nullptr);

return UploadResourceToGpu<base::flat_map<UINT64, MemoryInfoPtr>>(
execution_context_, dst_resource, upload_resource_->GetResource(),
shared_memory_region, constants_info->memory_info);
return UploadResourceToGpu(execution_context_, dst_resource,
upload_resource_->GetResource(), shm_mapping_,
constants_byte_length);
}

HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource,
Expand All @@ -87,6 +80,10 @@ HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource,
base::ReadOnlySharedMemoryRegion& shared_memory_region =
named_inputs->shared_memory;
size_t inputs_byte_length = shared_memory_region.GetSize();
if (!shm_mapping_.IsValid()) {
shm_mapping_ = shared_memory_region.Map();
DCHECK(shm_mapping_.IsValid());
}

HRESULT hr = S_OK;
if (upload_resource_ == nullptr) {
Expand All @@ -97,9 +94,9 @@ HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource,
}
DCHECK(upload_resource_ != nullptr);

return UploadResourceToGpu<base::flat_map<std::string, MemoryInfoPtr>>(
execution_context_, dst_resource, upload_resource_->GetResource(),
shared_memory_region, named_inputs->resources);
return UploadResourceToGpu(execution_context_, dst_resource,
upload_resource_->GetResource(), shm_mapping_,
inputs_byte_length);
}

// Create entire memory for uploading resource that will be uploaded piece by
Expand Down
1 change: 1 addition & 0 deletions content/browser/ml/webnn/dml/upload_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class UploadResource final {

ExecutionContext* execution_context_;
ComPtr<gpgmm::d3d12::ResourceAllocation> upload_resource_;
base::ReadOnlySharedMemoryMapping shm_mapping_;
};

} // namespace content::webnn
Expand Down
12 changes: 6 additions & 6 deletions third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ void MojoGraph::ComputeSyncImpl(const MLNamedArrayBufferViews& inputs,
};
{
TRACE_EVENT0("blink", "MojoGraph::ComputeSyncImpl::CopyOutputs");
if (!outputs_shm_mapping_.IsValid()) {
outputs_shm_mapping_ = named_outputs->shared_memory.Map();
}
for (const auto& output : outputs) {
String error_message;
void* output_buffer_address = output.second->BaseAddressMaybeShared();
Expand All @@ -334,14 +337,11 @@ void MojoGraph::ComputeSyncImpl(const MLNamedArrayBufferViews& inputs,
return;
}
MemoryInfoPtr memory_info = std::move(iter->value);
base::ReadOnlySharedMemoryRegion& shared_memory_region =
named_outputs->shared_memory;
DCHECK(shared_memory_region.IsValid());
size_t byte_offset = base::checked_cast<size_t>(memory_info->byte_offset);
size_t byte_length = base::checked_cast<size_t>(memory_info->byte_length);
base::ReadOnlySharedMemoryMapping shared_memory_mapping =
shared_memory_region.MapAt(memory_info->byte_offset, byte_length);
memcpy(output_buffer_address,
shared_memory_mapping.GetMemoryAs<uint8_t>(), byte_length);
outputs_shm_mapping_.GetMemoryAs<uint8_t>() + byte_offset,
byte_length);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions third_party/blink/renderer/modules/ml/webnn/mojo_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class MojoGraph : public MLGraph {
// The map of input name and input data offset.
HashMap<String, size_t> inputs_byte_offset_;
base::MappedReadOnlyRegion inputs_shm_region_;
base::ReadOnlySharedMemoryMapping outputs_shm_mapping_;

HeapMojoRemote<ml::webnn::mojom::blink::Graph> remote_graph_;
};
Expand Down

0 comments on commit dca067d

Please sign in to comment.