Skip to content

Commit

Permalink
Merge pull request #6 from huningxin/fix_output_index
Browse files Browse the repository at this point in the history
Fix using incorrect index for output buffer bindings
  • Loading branch information
fdwr authored May 17, 2023
2 parents 1e09d98 + 1d6a224 commit 440e80b
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 31 deletions.
7 changes: 4 additions & 3 deletions content/browser/ml/webnn/dml/graph_desc_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ void GraphDescBuilder::AddOutputEdge(NodeOutput* node_output,
output_edge.GraphOutputIndex = output_index;
graph_desc_.output_edges.push_back(output_edge);

named_outputs_[name] =
node_output->GetTensorDesc().GetTotalTensorSizeInBytes();
named_outputs_[name] = {
.index = output_index,
.byte_length = node_output->GetTensorDesc().GetTotalTensorSizeInBytes()};
}

ComPtr<IDMLCompiledOperator> GraphDescBuilder::Compile(
Expand Down Expand Up @@ -171,7 +172,7 @@ std::vector<InputNode>& GraphDescBuilder::GetInputNodes() {
return input_nodes_;
}

std::map<std::string, size_t>& GraphDescBuilder::GetNamedOutputs() {
std::map<std::string, OutputInfo>& GraphDescBuilder::GetNamedOutputs() {
return named_outputs_;
}

Expand Down
12 changes: 9 additions & 3 deletions content/browser/ml/webnn/dml/graph_desc_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@

namespace content::webnn {

// Describes one named graph output.
struct OutputInfo {
  // Value assigned to the output edge's GraphOutputIndex when the output is
  // added; used to select the slot in the output binding array.
  size_t index = 0;
  // Total size of the output tensor in bytes (unaligned).
  size_t byte_length = 0;
};

class GraphDescBuilder final {
public:
explicit GraphDescBuilder(ComPtr<IDMLDevice> device);
Expand All @@ -31,7 +36,7 @@ class GraphDescBuilder final {
ComPtr<IDMLCompiledOperator> Compile(DML_EXECUTION_FLAGS flags);

std::vector<InputNode>& GetInputNodes();
std::map<std::string, size_t>& GetNamedOutputs();
std::map<std::string, OutputInfo>& GetNamedOutputs();

private:
struct GraphDesc {
Expand All @@ -46,12 +51,13 @@ class GraphDescBuilder final {

// The input nodes include both the inputs for execution and the constants
// for initialization, because both of them are inputs to the DirectML graph.
// The input node index is the same as its offset in this vector.
std::vector<InputNode> input_nodes_;
// The operator nodes hold a reference to the IDMLOperator that is used for
// GraphDesc.nodes.
std::vector<OperatorNode> operator_nodes_;
// The output name and byte length mapping.
std::map<std::string, size_t> named_outputs_;
// The output name to output index and byte length mapping.
std::map<std::string, OutputInfo> named_outputs_;
GraphDesc graph_desc_;
ComPtr<IDMLDevice> device_;
};
Expand Down
38 changes: 18 additions & 20 deletions content/browser/ml/webnn/dml/graph_dml_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1887,26 +1887,25 @@ void GraphDMLImpl::Compute(NamedResourcesPtr named_inputs,
outputs_resource = execution_resources->Allocate(
ResourceType::kOutput, outputs_resource_size, this);
}
auto& output_length_map = graph_desc_builder_->GetNamedOutputs();
std::vector<DML_BINDING_DESC> output_binding_desc(output_length_map.size());
auto& output_info_map = graph_desc_builder_->GetNamedOutputs();
std::vector<DML_BINDING_DESC> output_binding_desc(output_info_map.size());
// The order of the outputs from Graph Compute is different from the order of
// the outputs from Graph Build, so the correct offset for each output must be
// found by name in order to read it back from the GPU buffer.
base::flat_map<std::string, DML_BUFFER_BINDING> output_buffer_binding;
// Reserve the map capacity to avoid reallocation.
output_buffer_binding.reserve(output_length_map.size());
output_buffer_binding.reserve(output_info_map.size());
uint64_t aligned_offset = 0;
size_t i = 0;
for (auto& [name, byte_length] : output_length_map) {
for (auto& [name, output_info] : output_info_map) {
DML_BUFFER_BINDING buffer_binding;
buffer_binding.Buffer = outputs_resource;
buffer_binding.Offset = aligned_offset;
buffer_binding.SizeInBytes = byte_length;
buffer_binding.SizeInBytes = output_info.byte_length;
output_buffer_binding[name] = buffer_binding;
output_binding_desc[i] = {DML_BINDING_TYPE_BUFFER,
&output_buffer_binding[name]};
aligned_offset += Align(byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);
++i;
output_binding_desc[output_info.index] = {DML_BINDING_TYPE_BUFFER,
&output_buffer_binding[name]};
aligned_offset +=
Align(output_info.byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);
}

execution_context_->ExecuteGraph(this, mCompiledOperator.Get(),
Expand Down Expand Up @@ -1963,26 +1962,25 @@ bool GraphDMLImpl::Compute(NamedResourcesPtr named_inputs,
outputs_resource = execution_resources->Allocate(
ResourceType::kOutput, outputs_resource_size, this);
}
auto& output_length_map = graph_desc_builder_->GetNamedOutputs();
std::vector<DML_BINDING_DESC> output_binding_desc(output_length_map.size());
auto& output_info_map = graph_desc_builder_->GetNamedOutputs();
std::vector<DML_BINDING_DESC> output_binding_desc(output_info_map.size());
// The order of the outputs from Graph Compute is different from the order of
// the outputs from Graph Build, so the correct offset for each output must be
// found by name in order to read it back from the GPU buffer.
base::flat_map<std::string, DML_BUFFER_BINDING> output_buffer_binding;
// Reserve the map capacity to avoid reallocation.
output_buffer_binding.reserve(output_length_map.size());
output_buffer_binding.reserve(output_info_map.size());
uint64_t aligned_offset = 0;
size_t i = 0;
for (auto& [name, byte_length] : output_length_map) {
for (auto& [name, output_info] : output_info_map) {
DML_BUFFER_BINDING buffer_binding;
buffer_binding.Buffer = outputs_resource;
buffer_binding.Offset = aligned_offset;
buffer_binding.SizeInBytes = byte_length;
buffer_binding.SizeInBytes = output_info.byte_length;
output_buffer_binding[name] = buffer_binding;
output_binding_desc[i] = {DML_BINDING_TYPE_BUFFER,
&output_buffer_binding[name]};
aligned_offset += Align(byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);
++i;
output_binding_desc[output_info.index] = {DML_BINDING_TYPE_BUFFER,
&output_buffer_binding[name]};
aligned_offset +=
Align(output_info.byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);
}

execution_context_->ExecuteGraph(this, mCompiledOperator.Get(),
Expand Down
9 changes: 5 additions & 4 deletions content/browser/ml/webnn/dml/readback_resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ ReadbackResource::ReadbackResource(ExecutionContext* execution_context)
ReadbackResource::~ReadbackResource() = default;

HRESULT ReadbackResource::InitializeResource(
std::map<std::string, size_t>& named_outputs) {
std::map<std::string, OutputInfo>& named_outputs) {
uint64_t aligned_offset = 0;
for (auto& [name, byte_length] : named_outputs) {
for (auto& [name, output_info] : named_outputs) {
MemoryInfo memory_info = {};
memory_info.byte_offset = aligned_offset;
memory_info.byte_length = byte_length;
memory_info.byte_length = output_info.byte_length;
outputs_info_map_[name] = memory_info;

// Only the offset needs to be aligned; the byte length keeps its original value.
aligned_offset += Align(byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);
aligned_offset +=
Align(output_info.byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);
}
outputs_resource_size_ = aligned_offset;
outputs_shm_region_ =
Expand Down
3 changes: 2 additions & 1 deletion content/browser/ml/webnn/dml/readback_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "DirectML.h"
#include "components/ml/mojom/webnn_graph.mojom.h"
#include "content/browser/ml/webnn/dml/gpgmm_d3d12.h"
#include "content/browser/ml/webnn/dml/graph_desc_builder.h"
#include "content/browser/ml/webnn/dml/utils_dml.h"

namespace content::webnn {
Expand All @@ -25,7 +26,7 @@ class ReadbackResource final {
explicit ReadbackResource(ExecutionContext* execution_context);
~ReadbackResource();

HRESULT InitializeResource(std::map<std::string, size_t>& named_outputs);
HRESULT InitializeResource(std::map<std::string, OutputInfo>& named_outputs);
HRESULT ReadResourceFromGpu(NamedResourcesPtr& named_outputs,
ID3D12Resource* src_resource);
size_t GetOutputsResourceSize() const;
Expand Down

0 comments on commit 440e80b

Please sign in to comment.