Prototype additional WebNN operator implementations via DirectML #1

Open

wants to merge 72 commits into base: dml_base

Commits (72)
b53dc39
Prototype for Stable Diffusion #1. Need 5 more ops and cleanup.
fdwr Apr 5, 2023
413f29b
Prototype for stable diffusion #2.
fdwr Apr 10, 2023
0534ea3
Prototype for Stable Diffusion #3.
fdwr Apr 11, 2023
d38cec6
Prototype for Stable Diffusion #4.
fdwr Apr 12, 2023
4aa083a
Prototype for Stable Diffusion #5.
fdwr Apr 13, 2023
19aac03
Comment out pad and triangularMatrix which are not implemented
fdwr Apr 13, 2023
396cd7d
Fix argMin/max diagnostic name when validating axis
fdwr Apr 13, 2023
4d56c32
Enable DirectML debug layer accordingly when the D3D debug layer is e…
fdwr Apr 13, 2023
a88e205
Fix MatMul broadcasting
fdwr Apr 14, 2023
aaa569f
Fix copy pasta comments in ml graph builder
fdwr Apr 14, 2023
cdd91f7
Delete unused broadcasting code
fdwr Apr 14, 2023
d4b711c
Delete stale comments in DML graph and desc builder
fdwr Apr 15, 2023
03cfc24
Remove pragma optimize hacks temporarily inserted for easier debugging
fdwr Apr 15, 2023
01f4268
Add MeanVarianceNormalization, Reciprocal, LogicalNot, and stubs for …
fdwr May 3, 2023
4a456ca
Minor IDL comments
fdwr May 3, 2023
06ae130
Add Split. Fix empty dimensions issues. Fix Resample bug. Restore -1 …
fdwr May 9, 2023
c489489
Retry DMLCreateDevice again without the optional DML debug layer if i…
fdwr May 11, 2023
a055ac0
Conv2dTranspose temporary scaffolding
fdwr May 11, 2023
7404899
Add int64
fdwr May 11, 2023
47a78ad
Complete GetBytesOfDataType switch statement with all data types for …
fdwr May 11, 2023
56097e5
Add reshape todo comment linking to spec issue
fdwr May 11, 2023
88799be
Add ConvTranspose
fdwr May 12, 2023
68d6a94
Reallocate DescriptorHeap if the count increases when executing an op
huningxin May 15, 2023
1ccfe2a
Merge pull request #2 from huningxin/fix_desc_heap
fdwr May 15, 2023
df4df46
Fix adding duplicated graph inputs
huningxin May 16, 2023
0f7b9e3
Merge pull request #3 from huningxin/fix_dup_inputs
fdwr May 16, 2023
a56316f
Fix crash issue when graph has multiple outputs
huningxin May 16, 2023
de5b0a3
Merge pull request #4 from huningxin/fix_multi_outputs
fdwr May 16, 2023
79c6cd2
Fix Conv2d filter issue when NHWC
fdwr May 17, 2023
2cd8f90
Conv enum cleanup. Delete commented code, and update static asserts.
fdwr May 17, 2023
53f3a9c
Fix XNNPack Conv2d
fdwr May 17, 2023
44549f2
Shrink the Gather's output shape
mingmingtasd May 17, 2023
1d6a224
Fix using incorrect index for output buffer bindings
huningxin May 17, 2023
1e09d98
Merge pull request #5 from mingmingtasd/dml_sd
fdwr May 17, 2023
440e80b
Merge pull request #6 from huningxin/fix_output_index
fdwr May 17, 2023
52fe11f
Tiny comment update for logging in GraphDescBuilder::Compile
fdwr May 17, 2023
d06abe3
Merge branch 'dml_sd' of https://github.com/fdwr/chromium-src-webnn-d…
fdwr May 17, 2023
0a83694
Fix issue caused by argmin/argmax output shape
mingmingtasd May 18, 2023
ade7fa2
Merge pull request #7 from mingmingtasd/dml_sd
fdwr May 18, 2023
1df0a58
Ensure Connect tensor outputs use the original dimensions
fdwr May 18, 2023
00e510a
In Unsqueeze, check for duplicate axes
fdwr May 18, 2023
5691356
Fix the issue of calculating Resize's scales
mingmingtasd May 19, 2023
9527c1b
Fix flattenTo2d
fdwr May 19, 2023
02e9f7b
flattenTo2d reformat with clang format
fdwr May 19, 2023
cd1c8e5
Merge pull request #8 from mingmingtasd/dml_sd
fdwr May 19, 2023
06bba00
Fix resample2d scales calculation and mode setting issue
huningxin May 19, 2023
644e4e8
Merge pull request #9 from huningxin/fix_resample2d
fdwr May 19, 2023
1ca4871
Delete #pragma optimize("", off)
huningxin May 22, 2023
7b87a04
Add some trace events
huningxin May 22, 2023
dca067d
Optimize shared memory mapping and copy for input and output buffers
huningxin May 21, 2023
1b387fe
Merge pull request #10 from huningxin/perf_optimize
fdwr May 23, 2023
ad4446d
Count operands index from 1
mingmingtasd Jun 9, 2023
120ff12
Merge pull request #11 from mingmingtasd/dml_sd
fdwr Jun 11, 2023
d56181d
Fix read AV issue with more than 2GB memory usage by checking failure…
fdwr Jun 11, 2023
9da456d
Merge branch 'dml_sd' of https://github.com/fdwr/chromium-src-webnn-d…
fdwr Jun 11, 2023
28c36b5
Apply CommandQueue::WaitAsync() to not block GPU main thread
mingmingtasd May 31, 2023
03d9dcc
Post IDMLDevice1::CompileGraph into other threads
mingmingtasd Jun 21, 2023
dafa9e5
Apply suggestions from code review
fdwr Jul 5, 2023
81c2e6f
Merge pull request #12 from mingmingtasd/waitAsync
fdwr Jul 5, 2023
7829be6
Merge pull request #13 from mingmingtasd/compile
fdwr Jul 5, 2023
0afc912
Add placeholder stubs for all remaining v1 ops. Implement abs, neg, l…
fdwr Jul 11, 2023
5035d66
Fix PoolL2 and PoolMax
fdwr Jul 13, 2023
351782f
Add activation two-parameter operators to mojo
fdwr Jul 15, 2023
7748c83
Implement all activation ops in DML except hard swish
fdwr Jul 25, 2023
f7b557b
Add Pad
fdwr Jul 26, 2023
8b8e593
Graph DML stub for hard swish and minor comment cleanup
fdwr Jul 26, 2023
441d579
Add meanVarianceNormalization optional precomputed tensors mean and v…
fdwr Aug 8, 2023
e24360d
Update graph builder IDL with quantized operators
fdwr Oct 4, 2023
fd1586f
Add NPU adapter support.
fdwr Oct 4, 2023
ec8d376
Add four quantized operators
fdwr Oct 4, 2023
a0fe3a5
DML graph_tensor_desc increase old 5D hard limit to 8D for SAM encode…
fdwr Oct 25, 2023
456852b
Merge pull request #14 from fdwr/npu-enabled
fdwr Nov 3, 2023
25 changes: 22 additions & 3 deletions content/browser/ml/webnn/dml/graph_dml_impl.cc
@@ -14,6 +14,7 @@
#include "content/browser/ml/webnn/dml/upload_resource.h"
#include "mojo/public/c/system/types.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "base/task/thread_pool.h"

namespace content::webnn {

@@ -1802,8 +1803,26 @@ void GraphDMLImpl::Build(ModelInfoPtr model_info, BuildCallback callback) {
AddOutput(std::move(output->name), output->index);
}

// Finish the graph build.
// Post the CompileGraph task to the thread pool rather than running it on the
// GPU main thread, to avoid blocking. The OnGraphCompiled reply then runs back
// on this (GPU main) thread.
base::ThreadPool::PostTaskAndReply(
FROM_HERE,
base::BindOnce(&GraphDMLImpl::CompileGraph, base::Unretained(this)),
base::BindOnce(&GraphDMLImpl::OnGraphCompiled, base::Unretained(this),
std::move(callback), std::move(constant_resource),
std::move(constants_info)));
}

// Since IDMLDevice1::CompileGraph, called in this method, can take a long time
// to compile shaders (if they were not cached earlier), this method may block
// the current thread.
void GraphDMLImpl::CompileGraph() {
mCompiledOperator = graph_desc_builder_->Compile(DML_EXECUTION_FLAG_NONE);
}

void GraphDMLImpl::OnGraphCompiled(
BuildCallback callback,
ComPtr<gpgmm::d3d12::ResourceAllocation> constant_resource,
ConstantsInfoPtr constants_info) {
if (!mCompiledOperator) {
std::move(callback).Run(BuildResult::kUnknownError);
return;
@@ -1832,8 +1851,8 @@ void GraphDMLImpl::Build(ModelInfoPtr model_info, BuildCallback callback) {
execution_context_->Flush();

execution_context_->WaitForSignal(
base::BindOnce(&GraphDMLImpl::OnWaitForBuildSignal, base::Unretained(this),
std::move(callback)));
base::BindOnce(&GraphDMLImpl::OnWaitForBuildSignal,
base::Unretained(this), std::move(callback)));
}

void GraphDMLImpl::OnWaitForBuildSignal(BuildCallback callback) {
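Outside of Chromium, the task-and-reply pattern used above can be sketched with standard C++: run the slow compile step on another thread, then handle the result once it completes. This is a minimal illustration only; `FakeGraph`, `CompileGraph`, and `BuildAndWait` are hypothetical stand-ins, not the actual Chromium or DirectML API.

```cpp
#include <future>
#include <string>

// Stand-in for the graph object. CompileGraph() models the expensive
// IDMLDevice1::CompileGraph call, which may block while shaders compile.
struct FakeGraph {
  bool compiled = false;
  void CompileGraph() { compiled = true; }
};

// Mirrors PostTaskAndReply: "post" the compile to a worker thread, then run
// the reply (checking the result) after the worker finishes. In Chromium the
// reply is marshalled back to the posting sequence instead of blocking.
std::string BuildAndReply(FakeGraph& graph) {
  auto done = std::async(std::launch::async,
                         [&graph] { graph.CompileGraph(); });
  done.wait();  // The real code returns immediately; this sketch waits.
  return graph.compiled ? "built" : "error";
}
```

Note the real change also avoids `base::Unretained` pitfalls only because the graph object outlives the posted task; a standalone version must guarantee the same lifetime, as the lambda above does by blocking before `graph` goes out of scope.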
5 changes: 5 additions & 0 deletions content/browser/ml/webnn/dml/graph_dml_impl.h
@@ -189,6 +189,11 @@ class GraphDMLImpl : public ml::webnn::mojom::Graph {
const std::vector<UINT>& nchwOutputDims);
void AddOutput(const std::string&, UINT64);

void CompileGraph();
void OnGraphCompiled(BuildCallback callback,
ComPtr<gpgmm::d3d12::ResourceAllocation> constant_resource,
ConstantsInfoPtr constants_info);

void OnWaitForBuildSignal(BuildCallback callback);
void OnWaitForComputeSignal(ComputeCallback callback);