diff --git a/components/ml/mojom/web_platform_model.mojom b/components/ml/mojom/web_platform_model.mojom index 102ff7a5393812..d6bf4575efc0ea 100644 --- a/components/ml/mojom/web_platform_model.mojom +++ b/components/ml/mojom/web_platform_model.mojom @@ -40,6 +40,8 @@ enum DevicePreference { // GPU. We do not directly return an error in this case to avoid directly // exposing user's GPU information. kGpu = 2, + // Prefers running the model inference on an NPU. + kNpu = 3, }; [Stable, Extensible] diff --git a/components/ml/mojom/webnn_graph.mojom b/components/ml/mojom/webnn_graph.mojom index 6f00c59e7120bd..9a2c88e745bf10 100644 --- a/components/ml/mojom/webnn_graph.mojom +++ b/components/ml/mojom/webnn_graph.mojom @@ -6,7 +6,129 @@ module ml.webnn.mojom; import "mojo/public/mojom/base/shared_memory.mojom"; -// Represents the `MLOperandType` that is the type of operand. +// Enumeration for every distinct operator, regardless of how they may be +// grouped together in the mojom messages. There should be a 1:1 mapping to +// each OperatorKind::MLOperator::OperatorKind in ml_operator.h, and every +// callable operator function in ml_graph_builder.idl should have an entry. +// Note the OperationInfo::Tag enumerations are distinct and are subject to +// message grouping, such as combining multiple elementwise binary messages +// into a single struct. +[Stable, Extensible] +enum OperatorType { + [Default] kUnknown = 0, // Used to throw exception for unsupported types in blink. + + // Dot product operations. + kConv2d, + kConvTranspose2d, + kGemm, + kMatmul, // Handled by GEMM. + + // ElementWise binary operations + kAdd, + kSub, + kMul, + kDiv, + kMax, + kMin, + kPow, + + // Elementwise logical binary operations + kEqual, + kGreater, + kLesser, + + // Fuseable elementwise activation operations. + kRelu, + kElu, + kPrelu, + kLeakyRelu, + kClamp, + kSigmoid, + kHardSigmoid, + kHardSwish, + kLinear, + kSoftplus, + kSoftsign, + kSoftmax, + kTanh, + + // Pooling operations + kAveragePool2d, + kL2Pool2d, + kMaxPool2d, + + // Elementwise unary operations + kIdentity, + kAbs, + kNeg, + kExp, + kLog, + kSqrt, + kSin, + kCos, + kTan, + kErf, + kFloor, + kCeil, + kReciprocal, + kLogicalNot, + + // Trinary elementwise operations + kElementWiseIf, + + // Shape reinterpreting operations + kReshape, + kSqueeze, + kUnsqueeze, + kFlattenTo2d, + + // Shape modification operations. + kConcat, + kSlice, + kSplit, + kTranspose, + kPad, + kExpand, + kGather, + kResample2d, + + // Reduction operations + kReduceL1, + kReduceL2, + kReduceLogSum, + kReduceLogSumExp, + kReduceMax, + kReduceMean, + kReduceMin, + kReduceProduct, + kReduceSum, + kReduceSumSquare, + + // Normalization operations + kInstanceNormalization, + kMeanVarianceNormalization, + + // Miscellaneous operations + kArgMax, + kArgMin, + kCast, + kFillSequence, + kTriangularMatrix, + + // Iterative GEMM operations + kGru, + kGruCell, + kLstm, + kLstmCell, + + // Quantized operators + kConv2dInteger, + kMatmulInteger, + kDequantizeLinear, + kDynamicQuantizeLinear, +}; + +// Represents the element data type of a tensor, or `MLOperandType`. 
[Stable, Extensible] enum OperandType { [Default] kFloat32, @@ -15,6 +137,8 @@ enum OperandType { kUint32, kInt8, kUint8, + kInt64, + kUint64, }; // Represents the `MLOperandDescriptor` which describes not only input and @@ -56,11 +180,22 @@ struct Clamp { [Stable, Extensible] enum InputOperandLayout { [Default] kNchw, kNhwc }; -// Represents the `MLConv2dFilterOperandLayout` that specifies the layout format -// of filter tensor. O is the output channel count, i is input count / groups, H -// is height and W is the width of the tensor. +// Represents the values from `MLConv2dFilterOperandLayout` or +// `MLConvTranspose2dFilterOperandLayout` that specifies the filter tensor's +// layout format. +// - O == filter output channel count +// - I == filter input channel count +// - H == height +// - W == width +// +// For forward convolution: +// - filter input channel count == input tensor channel count / groups +// - filter output channel count == output tensor channel count +// For backward convolution: +// - filter input channel count == input tensor channel count +// - filter output channel count == output tensor channel count / groups [Stable, Extensible] -enum Conv2dFilterOperandLayout { [Default] kOihw, kHwio, kOhwi, kIhwo }; +enum Conv2dFilterOperandLayout { [Default] kOihw, kIohw, kHwoi, kHwio, kOhwi, kIhwo }; // Represents the `MLAutoPad`. `Explicit` means that the values in the // options.padding array should be used for input padding, with `SameUpper` @@ -68,6 +203,15 @@ enum Conv2dFilterOperandLayout { [Default] kOihw, kHwio, kOhwi, kIhwo }; [Stable, Extensible] enum AutoPad { [Default] kExplicit, kSameUpper, kSameLower }; +[Stable, Extensible] +enum TriangularPart { [Default] kUpper, kLower }; + +[Stable, Extensible] +enum PaddingMode { [Default] kConstant, kEdge, kReflection, kSymmetric }; + +[Stable, Extensible] +enum InterpolationMode { [Default] kNearestNeighbor, kLinear }; + // Represents the `MLConv2dOptions`. [Stable] struct Conv2dOptions { @@ -105,39 +249,87 @@ struct Conv2dOptions { // the struct's output index. [Stable] struct Conv2d { + OperatorType operator_type@0; // The type of operation, conv2d or conv2dTranspose. // The index of input operand in the model's operand array. - uint64 input_index@0; + uint64 input_index@1; // The index of filter operand in the model's operand array. - uint64 filter_index@1; - // MLConv2dOptions is mapping to Conv2dOptions. - Conv2dOptions options@2; + uint64 filter_index@2; + // The index of input operand's zero point in the model's operand array. + // Optional, and only present for conv2dInteger. + uint64 input_zero_point_index@3; + // The index of filter operand's zero point in the model's operand array. + // Optional, and only present for conv2dInteger. + uint64 filter_zero_point_index@4; + // MLConv2dOptions/MLConvTranspose2dOptions/MLConv2dIntegerOptions map to Conv2dOptions. + Conv2dOptions options@5; // The index of output operand in the model's operand array. - uint64 output_index@3; + uint64 output_index@6; }; -// Represents element-wise binary operation, the `kUnknown` is used to throw -// exception for unsupported type in blink. -[Stable, Extensible] -enum ElementWiseBinaryType { - [Default] kUnknown = 0, - kAdd, - kSub, - kMul, - kDiv, - kMax, - kMin, - kPower, +// Element-wise unary operators: +// - identity: return same output as input. 
+// - abs +// - ceil +// - cos +// - exp +// - erf +// - floor +// - hardSwish +// - log +// - logicalNot +// - neg +// - reciprocal +// - sigmoid +// - sin +// - softsign +// - sqrt +// - tan +// - tanh +// +[Stable] +struct ElementWiseUnary { + OperatorType operator_type@0; // The type of unary operation. + uint64 input_index@1; // The index of a operand in the model's operand array. + uint64 output_index@2; // The index of output operand in the model's operand array. }; -// Add element-wise binary addition, subtraction, multiplication, division -// defined in `ElementWiseBinaryType`. Corresponds to `MLOperand add( -// MLOperand a, MLOperand b)` function in `MLGraphBuilder`. The a `MLOperand` is -// mapping to the struct's a index, The b `MLOperand` is mapping to the struct's -// b index, the return `MLOperand` is mapping to the struct's output index. +// Element-wise unary operators with up to two parameters: +// - elu MLEluOptions +// - leakyRelu MLLeakyReluOptions +// - linear MLLinearOptions +// - hardSigmoid MLHardSigmoidOptions +// - softplus MLSoftplusOptions [Stable] -struct ElementWiseBinary { +struct ElementWiseUnaryTwoParameter { + OperatorType operator_type@0; // The type of unary operation. + uint64 input_index@1; // The index of a operand in the model's operand array. + uint64 output_index@2; // The index of output operand in the model's operand array. + float first_parameter@3; + float second_parameter@4; +}; + +// Element-wise binary operators: +// - add: Add the values of the two input tensors, element-wise. +// - sub: Subtract the values of the second input tensor from the values of the first input tensor, element-wise. +// - mul: Multiply the values of the two input tensors, element-wise. +// - div: Divide the values of the first input tensor with the values of the second tensor, element-wise. +// - max: Select the greater values of the two input tensors, element-wise. +// - min: Select the lesser values of the two input tensors, element-wise. +// - pow: Compute the values of the values of the first input tensor to the power of the values of the second input tensor, element-wise. +// - equal: Return bool if a == b per element. +// - greater: Return bool if a > b per element. +// - lesser: Return bool if a < b per element. +// - prelu: if x >= 0 then x else a * x. +// +// The `OperatorType` specifies the specific one. e.g. Add corresponds to +// `MLOperand add(MLOperand a, MLOperand b)` function in `MLGraphBuilder`. +// The 'a' `MLOperand` is mapped to the struct's a index, and the 'b' +// `MLOperand` is mapped to the struct's b index, while the returned +// `MLOperand` is mapped to the struct's output index. +[Stable] +struct ElementWiseBinary { // The type of binary operation. - ElementWiseBinaryType type@0; + OperatorType operator_type@0; // The index of a operand in the model's operand array. uint64 a_index@1; // The index of b operand in the model's operand array. @@ -147,9 +339,10 @@ struct ElementWiseBinary { }; // Represents the `MLGemmOptions`. +// Also handles MatMul and MatMulInteger. [Stable] struct GemmOptions { - // The largest possible value for type uint64 identify the option index with + // The largest possible value for type uint64 identifies the empty optional index with // std::numeric_limits::max(). uint64 c_index@0; // The scalar multiplier for the first input. @@ -168,24 +361,13 @@ struct GemmOptions { // return `MLOperand` is mapping to the struct's output index. [Stable] struct Gemm { - // The index of a operand in the model's operand array. 
-  uint64 a_index@0;
-  // The index of b operand in the model's operand array.
-  uint64 b_index@1;
-  // MLGemmOptions is mapping to GemmOptions.
-  GemmOptions options@2;
-  // The index of output operand in the model's operand array.
-  uint64 output_index@3;
-};
-
-// Represents the `MLPool2dType`, the `kUnknown` is used to throw exception for
-// unsupported type in blink.
-[Stable, Extensible]
-enum Pool2dType {
-  [Default] kUnknown = 0,
-  kAveragePool2d,
-  kMaxPool2d,
-  kL2Pool2d
+  OperatorType operator_type@0;  // The type of operation: gemm, matmul, or matmulInteger.
+  uint64 a_index@1;  // The index of a operand in the model's operand array.
+  uint64 b_index@2;  // The index of b operand in the model's operand array.
+  uint64 a_zero_point_index@3;  // The index of the a operand's zero point; only used by matmulInteger.
+  uint64 b_zero_point_index@4;  // The index of the b operand's zero point; only used by matmulInteger.
+  GemmOptions options@5;  // MLGemmOptions is mapping to GemmOptions.
+  uint64 output_index@6;  // The index of output operand in the model's operand array.
 };
 
 [Stable, Extensible]
@@ -209,7 +391,7 @@ struct Pool2dOptions {
   InputOperandLayout layout@5 = InputOperandLayout.kNchw;
   // The rounding function used to compute the output shape.
   RoundingType rounding_type@6 = RoundingType.kFloor;
-  // The sizes of the two spacial dimensions of the output tensor.
+  // The sizes of the two spatial dimensions of the output tensor.
   array output_sizes@7;
 };
 
@@ -222,7 +404,7 @@
 [Stable]
 struct Pool2d {
   // The type of pool2d operation.
-  Pool2dType type@0;
+  OperatorType operator_type@0;
   // The index of input operand in the model's operand array.
   uint64 input_index@1;
   // MLPool2dOptions is mapping to Pool2dOptions.
@@ -243,6 +425,12 @@ struct Relu {
   uint64 output_index@1;
 };
 
+// Reshape operations:
+// - Reshape
+// - Squeeze
+// - Unsqueeze
+// - FlattenTo2d
+//
 // Corresponds to `MLOperand reshape(MLOperand input, sequence shape)`
 // function in `MLGraphBuilder`.The input `MLOperand` is mapping to the struct's
 // input index, the `shape` is set with the dimensions of output OperandDesc.
@@ -267,6 +455,181 @@ struct Softmax {
   uint64 output_index@1;
 };
 
+// Elementwise binary logical comparison operations
+
+// Shape modification operations.
+
+[Stable]
+struct Concat {
+  array input_indices@0;
+  uint32 axis@1;
+  uint64 output_index@2;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct Slice {
+  uint64 input_index@0;
+  array starts@1;
+  array sizes@2;
+  uint64 output_index@3;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct Split {
+  uint64 input_index@0;
+  uint32 axis@1;
+  array output_indices@2;
+};
+
+[Stable]
+struct Transpose {
+  uint64 input_index@0;
+  array permutation@1;
+  uint64 output_index@2;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct Pad {
+  uint64 input_index@0;
+  array beginningPadding@1;
+  array endingPadding@2;
+  PaddingMode mode@3 = PaddingMode.kConstant;
+  float value@4;
+  uint64 output_index@5;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct Expand {
+  uint64 input_index@0;
+  array new_shape@1;
+  uint64 output_index@2;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct Gather {
+  uint64 input_index@0;
+  uint64 indices_index@1;
+  uint32 axis@2;
+  uint64 output_index@3;  // Index of output operand in model's operand array.
+};
+
+// Reduction operations
+
+[Stable]
+struct Reduce {
+  OperatorType operator_type@0;  // The type of reduction operation.
+  uint64 input_index@1;
+  array axes@2;
+  bool keep_dimensions@3 = false;
+  uint64 output_index@4;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct ArgMinMax {
+  OperatorType operator_type@0;
+  uint64 input_index@1;
+  uint32 axis@2 = 0;
+  bool keep_dimensions@3 = false;
+  bool select_last_index@4 = false;
+  uint64 output_index@5;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct Cast {
+  uint64 input_index@0;
+  OperandType data_type@1 = OperandType.kFloat32;
+  uint64 output_index@2;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct InstanceNormalization {
+  uint64 input_index@0;
+  uint64 scale_index@1;
+  uint64 bias_index@2;
+  float epsilon@3;
+  InputOperandLayout layout@4 = InputOperandLayout.kNchw;
+  uint64 output_index@5;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct MeanVarianceNormalization {
+  uint64 input_index@0;
+  uint64 mean_index@1;
+  uint64 variance_index@2;
+  uint64 scale_index@3;
+  uint64 bias_index@4;
+  float epsilon@5;
+  array axes@6;
+  uint64 output_index@7;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct FillSequence {
+  uint64 input_index@0;
+  float start@1;
+  float delta@2;
+  uint64 output_index@3;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct TriangularMatrixOptions {
+  TriangularPart triangular_part@0 = TriangularPart.kUpper;
+  int32 diagonal_delta@1;
+};
+
+[Stable]
+struct TriangularMatrix {
+  uint64 input_index@0;
+  TriangularMatrixOptions options@1;
+  uint64 output_index@2;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct ElementWiseIf {
+  uint64 condition_index@0;
+  uint64 true_value_index@1;  // Index of operand in model's operand array.
+  uint64 false_value_index@2;  // Index of operand in model's operand array.
+  uint64 output_index@3;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct Resample2d {
+  uint64 input_index@0;
+  InterpolationMode interpolation_mode@1 = InterpolationMode.kNearestNeighbor;
+  array scales@2;
+  array axes@3;
+  // No need to store sizes because the output dimensions suffice.
+  uint64 output_index@4;  // Index of output operand in model's operand array.
+};
+
+[Stable]
+struct DequantizeLinear {
+  // The index of input operand in the model's operand array.
+  uint64 input_index@0;
+  uint64 scale_index@1;
+  uint64 zero_point_index@2;
+  // The index of output operand in the model's operand array.
+  uint64 output_index@3;
+};
+
+[Stable]
+struct DynamicQuantizeLinear {
+  // The index of input operand in the model's operand array.
+  uint64 input_index@0;
+  // The indices of the output operands in the model's operand array.
+  uint64 output_index@1;
+  uint64 output_scale_index@2;
+  uint64 output_zero_point_index@3;
+};
+
+// TODO: Implement the remaining iterative GEMM operators:
+/*
+gru
+gruCell
+lstm
+lstmCell
+*/
+
 // Describe the information of inputs / outputs data in shared buffer.
 [Stable]
 struct MemoryInfo {
@@ -290,12 +653,38 @@ union OperationInfo {
   // Keep the order as the same as build methods of `MLGraphBuilder`.
Clamp clamp; Conv2d conv2d; + ElementWiseUnary element_wise_unary; ElementWiseBinary element_wise_binary; + ElementWiseUnaryTwoParameter element_wise_unary_two_parameter; Gemm gemm; Pool2d pool2d; Relu relu; Reshape reshape; Softmax softmax; + Concat concat; + Slice slice; + Split split; + Transpose transpose; + Pad pad; + Expand expand; + Gather gather; + Reduce reduce; + Resample2d resample2d; + ArgMinMax arg_min_max; + Cast cast; + InstanceNormalization instance_normalization; + MeanVarianceNormalization mean_variance_normalization; + FillSequence fill_sequence; + TriangularMatrix triangular_matrix; + ElementWiseIf element_wise_if; + DequantizeLinear dequantize_linear; + DynamicQuantizeLinear dynamic_quantize_linear; + + // TODO:::Implement + // gru + // grulCell + // lstm + // lstmCell }; // Represents the input and output operands in the model. @@ -319,7 +708,7 @@ struct ConstantsInfo { [Stable] struct ModelInfo { // All operands in the model. - array operands@0; + map operands@0; // All inputs in the model. array inputs@1; // All outputs in the models. diff --git a/content/browser/ml/webnn/dml/adapter_dml.cc b/content/browser/ml/webnn/dml/adapter_dml.cc index f5c09bd51a2d47..dea25385b02335 100644 --- a/content/browser/ml/webnn/dml/adapter_dml.cc +++ b/content/browser/ml/webnn/dml/adapter_dml.cc @@ -4,43 +4,106 @@ #include "content/browser/ml/webnn/dml/adapter_dml.h" +// TODO::: +#pragma optimize("", off) + namespace content::webnn { AdapterDML::AdapterDML(ComPtr hardware_adapter) : hardware_adapter_(hardware_adapter) {} +AdapterDML::AdapterDML(ComPtr dxcore_hardware_adapter) + : dxcore_hardware_adapter_(dxcore_hardware_adapter) {} + AdapterDML::~AdapterDML() = default; HRESULT AdapterDML::Initialize() { + // Select either the DXCore adapter or DXGI one, depending on which is available. + IUnknown* dxcore_adapter = dxcore_hardware_adapter_.Get(); + IUnknown* dxgi_adapter = hardware_adapter_.Get(); + IUnknown* adapter = dxcore_adapter ? dxcore_adapter : dxgi_adapter; HRESULT hr = - D3D12CreateDevice(hardware_adapter_.Get(), D3D_FEATURE_LEVEL_11_0, + D3D12CreateDevice(adapter, D3D_FEATURE_LEVEL_1_0_CORE, IID_PPV_ARGS(&d3d12_device_)); if (FAILED(hr)) { return hr; } - hr = DMLCreateDevice(d3d12_device_.Get(), DML_CREATE_DEVICE_FLAG_NONE, + // If the D3D debug layer is enabled (e.g. via dxcpl.exe), then also enable + // the DirectML debug layer accordingly. + DML_CREATE_DEVICE_FLAGS dml_create_device_flags = DML_CREATE_DEVICE_FLAG_NONE; + + ComPtr debug_device; + d3d12_device_->QueryInterface(IID_PPV_ARGS(&debug_device)); // Ignore failure + bool is_d3d12_debug_layer_enabled = (debug_device != nullptr); + + if (is_d3d12_debug_layer_enabled) { + dml_create_device_flags |= DML_CREATE_DEVICE_FLAG_DEBUG; + } + + hr = DMLCreateDevice(d3d12_device_.Get(), dml_create_device_flags, IID_PPV_ARGS(&dml_device_)); + + // If the DirectML debug layer couldn't be loaded, create the device again + // without it. The debug layer is not essential. 
+ if (hr == DXGI_ERROR_SDK_COMPONENT_MISSING) { + dml_create_device_flags &= ~DML_CREATE_DEVICE_FLAG_DEBUG; + hr = DMLCreateDevice(d3d12_device_.Get(), dml_create_device_flags, + IID_PPV_ARGS(&dml_device_)); + } if (FAILED(hr)) { return hr; } - DXGI_ADAPTER_DESC1 adapter_desc; - hardware_adapter_->GetDesc1(&adapter_desc); - if (adapter_desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) { - adapter_type_ = AdapterType::kCPU; - } else { - D3D12_FEATURE_DATA_ARCHITECTURE arch = {}; - hr = d3d12_device_->CheckFeatureSupport(D3D12_FEATURE_ARCHITECTURE, &arch, - sizeof(arch)); - if (FAILED(hr)) { - return hr; + if (hardware_adapter_) { + DXGI_ADAPTER_DESC1 adapter_desc; + hardware_adapter_->GetDesc1(&adapter_desc); + if (adapter_desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) { + adapter_type_ = AdapterType::kCPU; + } else { + D3D12_FEATURE_DATA_ARCHITECTURE arch = {}; + hr = d3d12_device_->CheckFeatureSupport(D3D12_FEATURE_ARCHITECTURE, &arch, + sizeof(arch)); + if (FAILED(hr)) { + return hr; + } + adapter_type_ = + (arch.UMA) ? AdapterType::kIntegratedGPU : AdapterType::kDiscreteGPU; + } + DXGI_ADAPTER_DESC desc = {}; + if (SUCCEEDED(hardware_adapter_->GetDesc(&desc))) { + device_name_.assign(desc.Description); + } + } else if (dxcore_hardware_adapter_) { + BOOL property_value; + if (SUCCEEDED(dxcore_hardware_adapter_->GetProperty( + DXCoreAdapterProperty::IsHardware, &property_value)) && + !property_value) { + adapter_type_ = AdapterType::kCPU; // Could be software WARP. + } else if (!dxcore_hardware_adapter_->IsAttributeSupported( + DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)) { + adapter_type_ = AdapterType::kNPU; // Compute only device. + } else if (SUCCEEDED(dxcore_hardware_adapter_->GetProperty( + DXCoreAdapterProperty::IsIntegrated, &property_value)) && + property_value) { + adapter_type_ = AdapterType::kIntegratedGPU; + } else { + adapter_type_ = AdapterType::kDiscreteGPU; + } + + size_t property_buffer_size = 0; + if (SUCCEEDED(dxcore_hardware_adapter_->GetPropertySize( + DXCoreAdapterProperty::DriverDescription, &property_buffer_size))) { + std::string device_name(property_buffer_size, '\0'); + if (SUCCEEDED(dxcore_hardware_adapter_->GetProperty( + DXCoreAdapterProperty::DriverDescription, property_buffer_size, + device_name.data()))) { + device_name_.assign(device_name.begin(), device_name.end()); + } } - adapter_type_ = - (arch.UMA) ? 
AdapterType::kIntegratedGPU : AdapterType::kDiscreteGPU; } - command_queue_ = base::MakeRefCounted(); + command_queue_ = std::make_unique(); hr = command_queue_->Initialize(d3d12_device_.Get()); if (FAILED(hr)) { return hr; @@ -67,7 +130,7 @@ HRESULT AdapterDML::Initialize() { } AdapterType AdapterDML::GetAdapterType() { - DCHECK(adapter_type_ != AdapterType::kUnknow); + DCHECK(adapter_type_ != AdapterType::kUnknown); return adapter_type_; } @@ -81,9 +144,9 @@ ComPtr AdapterDML::GetDMLDevice() const { return dml_device_; } -scoped_refptr AdapterDML::GetCommandQueue() const { - DCHECK(command_queue_.get() != nullptr); - return command_queue_; +CommandQueue* AdapterDML::GetCommandQueue() const { + DCHECK(command_queue_ != nullptr); + return command_queue_.get(); } ComPtr AdapterDML::GetResourceAllocator() { diff --git a/content/browser/ml/webnn/dml/adapter_dml.h b/content/browser/ml/webnn/dml/adapter_dml.h index 55af6fb3513a1a..9a60f065c25a80 100644 --- a/content/browser/ml/webnn/dml/adapter_dml.h +++ b/content/browser/ml/webnn/dml/adapter_dml.h @@ -5,8 +5,10 @@ #ifndef CONTENT_BROWSER_ML_WEBNN_DML_ADAPTER_DML_H_ #define CONTENT_BROWSER_ML_WEBNN_DML_ADAPTER_DML_H_ +#include #include #include +#include #include #include "DirectML.h" @@ -23,12 +25,14 @@ enum class AdapterType { kDiscreteGPU = 0, kIntegratedGPU = 1, kCPU = 2, - kUnknow = 3, + kNPU = 3, + kUnknown = 4, }; class AdapterDML final : public base::RefCounted { public: explicit AdapterDML(ComPtr hardware_adapter); + explicit AdapterDML(ComPtr dxcore_hardware_adapter); AdapterDML(const AdapterDML&) = delete; AdapterDML& operator=(const AdapterDML&) = delete; @@ -37,17 +41,24 @@ class AdapterDML final : public base::RefCounted { AdapterType GetAdapterType(); ComPtr GetD3D12Device() const; ComPtr GetDMLDevice() const; - scoped_refptr GetCommandQueue() const; + CommandQueue* GetCommandQueue() const; ComPtr GetResourceAllocator(); private: friend class base::RefCounted; ~AdapterDML(); + // One of these adapters will be non-null. + // DXGI is older and more broadly supported, capable of enumerating GPU's, + // but NPU's are only enumerated by DXCore. ComPtr hardware_adapter_; - AdapterType adapter_type_ = AdapterType::kUnknow; + ComPtr dxcore_hardware_adapter_; + + AdapterType adapter_type_ = AdapterType::kUnknown; + std::wstring device_name_; ComPtr d3d12_device_; - scoped_refptr command_queue_; + std::unique_ptr command_queue_; + // Represents a DirectML device, which is used to create operators, binding // tables, command recorders. 
ComPtr dml_device_; diff --git a/content/browser/ml/webnn/dml/command_queue.cc b/content/browser/ml/webnn/dml/command_queue.cc index c713664f52d1c1..430ca1e04dea55 100644 --- a/content/browser/ml/webnn/dml/command_queue.cc +++ b/content/browser/ml/webnn/dml/command_queue.cc @@ -4,36 +4,40 @@ #include "content/browser/ml/webnn/dml/command_queue.h" -namespace content::webnn { +#include "base/logging.h" +#include "base/memory/ptr_util.h" -CommandQueue::~CommandQueue() = default; +namespace content::webnn { CommandQueue::CommandQueue() {} +CommandQueue::~CommandQueue() = default; + HRESULT CommandQueue::Initialize(ID3D12Device* d3d12_device) { D3D12_COMMAND_QUEUE_DESC command_queue_desc = {}; - command_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; + command_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE; command_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE; HRESULT hr = d3d12_device->CreateCommandQueue(&command_queue_desc, IID_PPV_ARGS(&command_queue_)); if (FAILED(hr)) { + DLOG(ERROR) << "Failed to create ID3D12CommandQueue: " + << logging::SystemErrorCodeToString(hr); return hr; } hr = d3d12_device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&fence_)); if (FAILED(hr)) { + DLOG(ERROR) << "Failed to create ID3D12Fence: " + << logging::SystemErrorCodeToString(hr); return hr; } - fence_event_ = CreateEvent(nullptr, FALSE, FALSE, nullptr); - DCHECK(fence_event_ != nullptr); - return S_OK; -} + fence_event_.Set(CreateEvent(nullptr, /*bManualReset=*/FALSE, + /*bInitialState=*/FALSE, nullptr)); + CHECK(fence_event_.is_valid()); -void CommandQueue::ReferenceUntilCompleted(ComPtr object) { - QueuedObject object_ref = {last_fence_value_, std::move(object)}; - queued_object_refs_.push_back(object_ref); + return S_OK; } HRESULT CommandQueue::ExecuteCommandLists( @@ -44,37 +48,57 @@ HRESULT CommandQueue::ExecuteCommandLists( return command_queue_->Signal(fence_.Get(), last_fence_value_); } -void CommandQueue::Wait() { - if (fence_->GetCompletedValue() >= last_fence_value_) { - return; +void CommandQueue::OnObjectSignaled(HANDLE object) { + CHECK_EQ(object, fence_event_.get()); + while (!queued_callbacks_.empty() && + queued_callbacks_.front().fence_value <= fence_->GetCompletedValue()) { + std::move(queued_callbacks_.front().callback).Run(); + queued_callbacks_.pop_front(); } - HRESULT hr = fence_->SetEventOnCompletion(last_fence_value_, fence_event_); - if (FAILED(hr)) { - return; - } - WaitForSingleObject(fence_event_, INFINITE); } -void CommandQueue::ReleaseCompletedResources() { - uint64_t completed_value = fence_->GetCompletedValue(); - while (!queued_object_refs_.empty() && - queued_object_refs_.front().fence_value <= completed_value) { - queued_object_refs_.pop_front(); +HRESULT CommandQueue::WaitAsync(base::OnceClosure callback) { + if (!object_watcher_.IsWatching()) { + CHECK(object_watcher_.StartWatchingMultipleTimes(fence_event_.get(), this)); } + + HRESULT hr = + fence_->SetEventOnCompletion(last_fence_value_, fence_event_.get()); + if (FAILED(hr)) { + DLOG(ERROR) << "Failed to set event on completion : " + << logging::SystemErrorCodeToString(hr); + return hr; + }; + queued_callbacks_.push_back({last_fence_value_, std::move(callback)}); + return S_OK; } -CommandQueue::QueuedObject::QueuedObject(uint64_t fence_value, - ComPtr object) { - this->fence_value = fence_value; - this->object = std::move(object); +void CommandQueue::ReferenceUntilCompleted(ComPtr object) { + queued_objects_.push_back({last_fence_value_, std::move(object)}); } 
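// Illustrative sketch (not part of the patch): how a caller of CommandQueue is
// expected to drive the new non-blocking fence wait. The helper name and the
// element type of the command-list vector are assumptions for this example,
// and "base/functional/bind.h" is assumed to be included.
void SubmitAndReleaseWhenDone(content::webnn::CommandQueue* command_queue,
                              std::vector<ID3D12CommandList*> command_lists) {
  // Submits the recorded work and signals the fence with the next value.
  if (FAILED(command_queue->ExecuteCommandLists(std::move(command_lists)))) {
    return;
  }
  // Instead of blocking on the fence (the removed Wait()), register a callback
  // that the ObjectWatcher runs once the fence event is signaled; here it just
  // drops the resources kept alive via ReferenceUntilCompleted().
  command_queue->WaitAsync(
      base::BindOnce(&content::webnn::CommandQueue::ReleaseCompletedResources,
                     base::Unretained(command_queue)));
}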
-CommandQueue::QueuedObject::QueuedObject(const QueuedObject& other) { - this->fence_value = other.fence_value; - this->object = std::move(other.object); +void CommandQueue::ReleaseCompletedResources() { + uint64_t completed_value = fence_->GetCompletedValue(); + while (!queued_objects_.empty() && + queued_objects_.front().fence_value <= completed_value) { + queued_objects_.pop_front(); + } } -CommandQueue::QueuedObject::QueuedObject() = default; +CommandQueue::QueuedObject::QueuedObject(uint64_t fence_value, + ComPtr object) + : fence_value(fence_value), object(std::move(object)) {} +CommandQueue::QueuedObject::QueuedObject(QueuedObject&& other) = default; +CommandQueue::QueuedObject& CommandQueue::QueuedObject::operator=( + QueuedObject&& other) = default; CommandQueue::QueuedObject::~QueuedObject() = default; +CommandQueue::QueuedCallback::QueuedCallback(uint64_t fence_value, + base::OnceClosure callback) + : fence_value(fence_value), callback(std::move(callback)) {} +CommandQueue::QueuedCallback::QueuedCallback(QueuedCallback&& other) = default; +CommandQueue::QueuedCallback& CommandQueue::QueuedCallback::operator=( + QueuedCallback&& other) = default; +CommandQueue::QueuedCallback::~QueuedCallback() = default; + } // namespace content::webnn diff --git a/content/browser/ml/webnn/dml/command_queue.h b/content/browser/ml/webnn/dml/command_queue.h index 14b743088c4d80..dd68cda99f5317 100644 --- a/content/browser/ml/webnn/dml/command_queue.h +++ b/content/browser/ml/webnn/dml/command_queue.h @@ -10,6 +10,9 @@ #include "DirectML.h" #include "base/memory/ref_counted.h" +#include "base/functional/callback_forward.h" +#include "base/win/object_watcher.h" +#include "base/win/scoped_handle.h" namespace content::webnn { @@ -18,44 +21,63 @@ using Microsoft::WRL::ComPtr; // There is only one D3D12 command queue wrapped an existing queue that will // be shared with WebNN context, and provides a waitable fence which is signaled // with a increasing value once the execution complete on the GPU. -class CommandQueue final : public base::RefCounted { +class CommandQueue : public base::win::ObjectWatcher::Delegate { public: - CommandQueue(); + CommandQueue(); CommandQueue(const CommandQueue&) = delete; CommandQueue& operator=(const CommandQueue&) = delete; + ~CommandQueue() override; HRESULT Initialize(ID3D12Device* d3d12_device); void ReferenceUntilCompleted(ComPtr object); - HRESULT ExecuteCommandLists(std::vector); - // Queues a wait to block the GPU until the fence is signaled with the last - // value. - void Wait(); void ReleaseCompletedResources(); + HRESULT ExecuteCommandLists(std::vector); + + // It's an asynchronous method for DirectML graph implementation, which will + // not block the CPU. 
+ HRESULT WaitAsync(base::OnceClosure callback); + private: - friend class base::RefCounted; - ~CommandQueue(); - struct QueuedObject { - QueuedObject(); - QueuedObject(const QueuedObject& other); +struct QueuedObject { + QueuedObject() = delete; QueuedObject(uint64_t fence_value, ComPtr object); + QueuedObject(QueuedObject&& other); + QueuedObject& operator=(QueuedObject&& other); ~QueuedObject(); - uint64_t fence_value; + uint64_t fence_value = 0; ComPtr object; }; - std::deque queued_object_refs_; + std::deque queued_objects_; + + struct QueuedCallback { + QueuedCallback() = delete; + QueuedCallback(uint64_t fence_value, base::OnceClosure callback); + QueuedCallback(QueuedCallback&& other); + QueuedCallback& operator=(QueuedCallback&& other); + ~QueuedCallback(); + + uint64_t fence_value = 0; + base::OnceClosure callback; + }; + std::deque queued_callbacks_; + + // Implements base::win::ObjectWatcher::Delegate. + void OnObjectSignaled(HANDLE object) override; ComPtr command_queue_; - // the fence value used to watch the progression of GPU execution on a queue - // that is incremented by one time. This way to know if something is done - // executing, we just need to compare its value with the currently completed - // value. + + // The increasing fence value is used to track the progress of GPU execution + // work. Comparing it with the fence's completed value can indicate whether + // the work has been completed. uint64_t last_fence_value_ = 0; ComPtr fence_; - HANDLE fence_event_ = nullptr; + + base::win::ScopedHandle fence_event_; + base::win::ObjectWatcher object_watcher_; }; } // namespace content::webnn diff --git a/content/browser/ml/webnn/dml/command_recorder.cc b/content/browser/ml/webnn/dml/command_recorder.cc index 47ffb22adfdf3d..8f9d8e2c768ba8 100644 --- a/content/browser/ml/webnn/dml/command_recorder.cc +++ b/content/browser/ml/webnn/dml/command_recorder.cc @@ -4,6 +4,8 @@ #include "content/browser/ml/webnn/dml/command_recorder.h" +#include "base/trace_event/trace_event.h" +#include "base/trace_event/typed_macros.h" #include "content/browser/ml/webnn/dml/adapter_dml.h" #include "content/browser/ml/webnn/dml/execution_resources.h" @@ -16,7 +18,7 @@ CommandRecorder::CommandRecorder(scoped_refptr adapter, : adapter_(std::move(adapter)), dml_device_(std::move(dml_device)) {} void CommandRecorder::ResourceBarrier( - std::vector barriers) { + const std::vector& barriers) { command_list_->ResourceBarrier(barriers.size(), barriers.data()); } @@ -35,13 +37,13 @@ HRESULT CommandRecorder::Initialize() { return hr; } - hr = d3d12_device_->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, + hr = d3d12_device_->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE, IID_PPV_ARGS(&command_allocator_)); if (FAILED(hr)) { return hr; } - hr = d3d12_device_->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, + hr = d3d12_device_->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_COMPUTE, command_allocator_.Get(), nullptr, IID_PPV_ARGS(&command_list_)); if (FAILED(hr)) { @@ -65,6 +67,7 @@ HRESULT CommandRecorder::InitializeGraph( GraphDMLImpl* graph, IDMLCompiledOperator* compiled_operator, const DML_BINDING_DESC& input_array_binding) { + TRACE_EVENT0("gpu", "CommandRecorder::InitializeGraph"); // Reset the initializer to reference the compiled operator. 
IDMLCompiledOperator* ops[] = {compiled_operator}; HRESULT hr = operator_initializer_->Reset(ARRAYSIZE(ops), ops); @@ -164,14 +167,40 @@ HRESULT CommandRecorder::ExecuteGraph( IDMLCompiledOperator* compiled_operator, const std::vector& input_bindings, const std::vector& output_bindings) { + TRACE_EVENT0("gpu", "CommandRecorder::ExecuteGraph"); + DCHECK(mBindingTable != nullptr); // Bind and execute the operator on the GPU. // Reset the binding table to bind for the operator we want to execute (it // was previously used to bind for the initializer). mBindingTableDesc.Dispatchable = compiled_operator; - mBindingTable->Reset(&mBindingTableDesc); DML_BINDING_PROPERTIES binding_properties = compiled_operator->GetBindingProperties(); + UINT descriptorCount = binding_properties.RequiredDescriptorCount; + if (descriptorCount > mBindingTableDesc.SizeInDescriptors) { + // Need to reallocate the descriptors heap. + mDescriptorHeap.Reset(); + D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc{}; + descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + descriptorHeapDesc.NumDescriptors = descriptorCount; + descriptorHeapDesc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + HRESULT hr = d3d12_device_->CreateDescriptorHeap( + &descriptorHeapDesc, IID_PPV_ARGS(&mDescriptorHeap)); + if (FAILED(hr)) { + return hr; + } + mBindingTableDesc.CPUDescriptorHandle = + mDescriptorHeap->GetCPUDescriptorHandleForHeapStart(); + mBindingTableDesc.GPUDescriptorHandle = + mDescriptorHeap->GetGPUDescriptorHandleForHeapStart(); + // The size of the binding table, in descriptors. This is the maximum number + // of descriptors that DirectML is permitted to write, from the start of + // both the supplied CPU and GPU descriptor handles. + mBindingTableDesc.SizeInDescriptors = descriptorCount; + } + + mBindingTable->Reset(&mBindingTableDesc); + UINT64 temporary_resource_size = binding_properties.TemporaryResourceSize; if (temporary_resource_size != 0) { ID3D12Resource* temporary_resource = @@ -205,7 +234,7 @@ HRESULT CommandRecorder::ExecuteGraph( } void CommandRecorder::CloseAndExecute() const { - const auto& command_queue = adapter_->GetCommandQueue(); + auto* command_queue = adapter_->GetCommandQueue(); HRESULT hr = command_list_->Close(); if (FAILED(hr)) { return; diff --git a/content/browser/ml/webnn/dml/command_recorder.h b/content/browser/ml/webnn/dml/command_recorder.h index 7baf81e682327c..ac46eb00017f58 100644 --- a/content/browser/ml/webnn/dml/command_recorder.h +++ b/content/browser/ml/webnn/dml/command_recorder.h @@ -32,7 +32,7 @@ class CommandRecorder final { HRESULT Initialize(); - void ResourceBarrier(std::vector barriers); + void ResourceBarrier(const std::vector& barriers); void CopyBufferRegion(ID3D12Resource* dst_buffer, uint64_t dst_offset, diff --git a/content/browser/ml/webnn/dml/execution_context.cc b/content/browser/ml/webnn/dml/execution_context.cc index e5ef432163c459..74994e03e73a38 100644 --- a/content/browser/ml/webnn/dml/execution_context.cc +++ b/content/browser/ml/webnn/dml/execution_context.cc @@ -49,15 +49,29 @@ void ExecutionContext::CopyBufferRegion(ID3D12Resource* dest_resource, D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES; transition_barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION; transition_barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE; + + // The resource barrier needs to be before CopyBufferRegion when reading back from the GPU. 
+  if (state == D3D12_RESOURCE_STATE_COPY_SOURCE)
+  {
+    command_recorder_.ResourceBarrier({transition_barrier});
+  }
+
   command_recorder_.CopyBufferRegion(dest_resource, 0, src_resource, 0,
                                      resource_size);
   D3D12_RESOURCE_BARRIER reset_barrier = transition_barrier;
   reset_barrier.Transition.StateBefore = state;
   reset_barrier.Transition.StateAfter = D3D12_RESOURCE_STATE_UNORDERED_ACCESS;
+  // TODO: Revisit the comment below; issuing CopyBufferRegion before the
+  // barrier reads back all zeros because the computation hasn't finished yet.
   // Both transitions to/from UAV can be combined into a single ResourceBarrier
   // call, and the transition barrier command can be enqueued after
   // CopyBufferRegion.
-  command_recorder_.ResourceBarrier({transition_barrier, reset_barrier});
+
+  // TODO: Verify that the combined barrier can be safely elided when
+  // state != D3D12_RESOURCE_STATE_COPY_SOURCE. It appears to work fine.
+  // command_recorder_.ResourceBarrier({transition_barrier, reset_barrier});
+
+  command_recorder_.ResourceBarrier({reset_barrier});
 }
 
 HRESULT ExecutionContext::Initialize() {
@@ -85,6 +99,7 @@ HRESULT ExecutionContext::ExecuteGraph(
     IDMLCompiledOperator* compiled_operator,
     const std::vector& input_bindings,
     const std::vector& output_bindings) {
+  DCHECK(compiled_operator != nullptr);
   return command_recorder_.ExecuteGraph(graph, compiled_operator,
                                         input_bindings, output_bindings);
 }
@@ -93,12 +108,22 @@ void ExecutionContext::Flush() const {
   command_recorder_.CloseAndExecute();
 }
 
-void ExecutionContext::WaitForSignal() const {
-  command_queue_->Wait();
+void ExecutionContext::WaitForSignal(base::OnceClosure callback) {
+  command_queue_->WaitAsync(base::BindOnce(
+      &ExecutionContext::OnWaitAsync, base::Unretained(this), std::move(callback)));
+}
+
+void ExecutionContext::OnWaitAsync(base::OnceClosure callback) {
   // Unlike ID3D12GraphicsCommandList::Reset, it is not recommended to call
   // Reset on the command allocator while a command list is still being
   // executed.
-  command_recorder_.ResetCommandList();
+  HRESULT hr = command_recorder_.ResetCommandList();
+  if (FAILED(hr)) {
+    DLOG(ERROR) << "Failed to reset command list: "
+                << logging::SystemErrorCodeToString(hr);
+    return;
+  }
+  std::move(callback).Run();
 }
 
 void ExecutionContext::ReferenceUntilCompleted(ComPtr object) {
diff --git a/content/browser/ml/webnn/dml/execution_context.h b/content/browser/ml/webnn/dml/execution_context.h
index c4e9d296696613..4c3e213d2a680a 100644
--- a/content/browser/ml/webnn/dml/execution_context.h
+++ b/content/browser/ml/webnn/dml/execution_context.h
@@ -15,6 +15,7 @@
 #include "content/browser/ml/webnn/dml/execution_resources.h"
 #include "content/browser/ml/webnn/dml/gpgmm_d3d12.h"
 #include "content/browser/ml/webnn/dml/graph_dml_impl.h"
+#include "base/functional/callback_forward.h"
 
 namespace content::webnn {
 
@@ -46,8 +47,11 @@ class ExecutionContext final : public base::RefCounted {
   // Forces all queued work to begin executing on the GPU.
   void Flush() const;
+
-  // Blocks until the current fence is signaled.
+  // Runs |callback| once the currently signaled fence value has completed.
-  void WaitForSignal() const;
+  void WaitForSignal(base::OnceClosure callback);
+  void OnWaitAsync(base::OnceClosure callback);
+
   void ReferenceUntilCompleted(ComPtr object);
   void ReleaseCompletedResources() const;
 
@@ -64,7 +68,7 @@ class ExecutionContext final : public base::RefCounted {
   ComPtr d3d12_device_;
   // There is one active command recorder at a time.
CommandRecorder command_recorder_; - scoped_refptr command_queue_; + CommandQueue* command_queue_ = nullptr; // ResourceAllocator is owned by adapter ComPtr resource_allocator_; diff --git a/content/browser/ml/webnn/dml/execution_resources.h b/content/browser/ml/webnn/dml/execution_resources.h index 3fe6668cb258f1..ff9d5ff3a33f79 100644 --- a/content/browser/ml/webnn/dml/execution_resources.h +++ b/content/browser/ml/webnn/dml/execution_resources.h @@ -24,7 +24,7 @@ enum class ResourceType { kOutput = 1, kTemporary = 2, kPersistent = 3, - kUnknow = 4, + kUnknown = 4, }; // A unordered resources represent input, output, temporary and persistent diff --git a/content/browser/ml/webnn/dml/gpgmm_d3d12.cpp b/content/browser/ml/webnn/dml/gpgmm_d3d12.cpp index 075beff61c74b3..2b56538b7642ff 100644 --- a/content/browser/ml/webnn/dml/gpgmm_d3d12.cpp +++ b/content/browser/ml/webnn/dml/gpgmm_d3d12.cpp @@ -6,6 +6,9 @@ #include +// TODO::: +#pragma optimize("", off) + #define ReturnIfFailed(expr) \ { \ HRESULT hr = expr; \ @@ -221,8 +224,10 @@ HRESULT ResourceAllocator::CreateAllocator( const ALLOCATOR_DESC& allocatorDescriptor, ResourceAllocator** ppResourceAllocatorOut, ResidencyManager** ppResidencyManagerOut) { + // The device is always required, but the adapter is only required when + // needing a residency manager. if (allocatorDescriptor.Device == nullptr || - allocatorDescriptor.Adapter == nullptr) { + (allocatorDescriptor.Adapter == nullptr && ppResidencyManagerOut != nullptr)) { return E_INVALIDARG; } diff --git a/content/browser/ml/webnn/dml/graph_desc_builder.cc b/content/browser/ml/webnn/dml/graph_desc_builder.cc index 2ea7e0fe439e68..8f863b07408913 100644 --- a/content/browser/ml/webnn/dml/graph_desc_builder.cc +++ b/content/browser/ml/webnn/dml/graph_desc_builder.cc @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include "base/logging.h" #include "content/browser/ml/webnn/dml/graph_desc_builder.h" namespace content::webnn { @@ -41,7 +42,7 @@ Node GraphDescBuilder::CreateOperatorNode(DML_OPERATOR_TYPE type, Microsoft::WRL::ComPtr op; HRESULT hr = device_->CreateOperator(&op_desc, IID_PPV_ARGS(&op)); if (FAILED(hr)) { - return {NodeType::kUnknow, 0}; + return {NodeType::kUnknown, 0}; } OperatorNode op_node = {}; @@ -99,8 +100,9 @@ void GraphDescBuilder::AddOutputEdge(NodeOutput* node_output, output_edge.GraphOutputIndex = output_index; graph_desc_.output_edges.push_back(output_edge); - named_outputs_[name] = - node_output->GetTensorDesc().GetTotalTensorSizeInBytes(); + named_outputs_[name] = { + .index = output_index, + .byte_length = node_output->GetTensorDesc().GetTotalTensorSizeInBytes()}; } ComPtr GraphDescBuilder::Compile( @@ -150,7 +152,15 @@ ComPtr GraphDescBuilder::Compile( ComPtr compiled_graph; hr = device1->CompileGraph(&graph_desc, flags, IID_PPV_ARGS(&compiled_graph)); + + // TODO:::DELETE This is just for easy verification that the DML backend was selected. 
+ // Pass: chrome.exe --enable-logging=stderr --log-level=0 + LOG(INFO) << "DML CompileGraph finished"; + // TODO:::END + if (FAILED(hr)) { + LOG(ERROR) << "CompileGraph failed: " + << logging::SystemErrorCodeToString(hr); return nullptr; } return compiled_graph; @@ -160,7 +170,7 @@ std::vector& GraphDescBuilder::GetInputNodes() { return input_nodes_; } -std::map& GraphDescBuilder::GetNamedOutputs() { +std::map& GraphDescBuilder::GetNamedOutputs() { return named_outputs_; } diff --git a/content/browser/ml/webnn/dml/graph_desc_builder.h b/content/browser/ml/webnn/dml/graph_desc_builder.h index a29f7ccf9567f9..2ef4aa0639c347 100644 --- a/content/browser/ml/webnn/dml/graph_desc_builder.h +++ b/content/browser/ml/webnn/dml/graph_desc_builder.h @@ -15,6 +15,11 @@ namespace content::webnn { +struct OutputInfo { + size_t index; + size_t byte_length; +}; + class GraphDescBuilder final { public: explicit GraphDescBuilder(ComPtr device); @@ -31,7 +36,7 @@ class GraphDescBuilder final { ComPtr Compile(DML_EXECUTION_FLAGS flags); std::vector& GetInputNodes(); - std::map& GetNamedOutputs(); + std::map& GetNamedOutputs(); private: struct GraphDesc { @@ -46,12 +51,13 @@ class GraphDescBuilder final { // The inputs node include inputs for execution and constant for // initialization because Both of them are inputs for DirectML Graph. + // The input node index is same as the offset in this vector. std::vector input_nodes_; // The operator nodes hold a reference of IDMLOperator to be used for // GraphDesc.nodes std::vector operator_nodes_; - // The output name and byte length mapping. - std::map named_outputs_; + // The output name to output index and byte length mapping. + std::map named_outputs_; GraphDesc graph_desc_; ComPtr device_; }; diff --git a/content/browser/ml/webnn/dml/graph_dml_impl.cc b/content/browser/ml/webnn/dml/graph_dml_impl.cc index 15d0c7347e7caf..2d924e820f6089 100644 --- a/content/browser/ml/webnn/dml/graph_dml_impl.cc +++ b/content/browser/ml/webnn/dml/graph_dml_impl.cc @@ -4,14 +4,32 @@ #include "content/browser/ml/webnn/dml/graph_dml_impl.h" +#include "base/containers/span.h" #include "base/logging.h" #include "base/memory/ptr_util.h" +#include "base/trace_event/trace_event.h" +#include "base/trace_event/typed_macros.h" #include "content/browser/ml/webnn/dml/execution_context.h" #include "content/browser/ml/webnn/dml/execution_resources.h" -#include "content/browser/ml/webnn/dml/graph_dml_impl.h" #include "content/browser/ml/webnn/dml/upload_resource.h" #include "mojo/public/c/system/types.h" #include "mojo/public/cpp/bindings/self_owned_receiver.h" +#include "base/task/thread_pool.h" + +#if DML_TARGET_VERSION < 0x4000 + +// The older Windows SDK did not include this operator. +constexpr DML_OPERATOR_TYPE DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR = static_cast(148); + +struct DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC +{ + const DML_TENSOR_DESC* InputTensor; + const DML_TENSOR_DESC* OutputTensor; + const DML_TENSOR_DESC* OutputScaleTensor; + const DML_TENSOR_DESC* OutputZeroPointTensor; +}; + +#endif // DML_TARGET_VERSION < 0x4000 namespace content::webnn { @@ -27,24 +45,92 @@ using ml::webnn::mojom::MemoryInfoPtr; using ml::webnn::mojom::ModelInfoPtr; using ml::webnn::mojom::OperandType; using ml::webnn::mojom::OperationInfo; +using ml::webnn::mojom::PaddingMode; enum TransposeType { NhwcToNchw, NchwToNhwc }; +std::array getLayoutPermutationToNchw(InputOperandLayout original_layout) +{ + // Return indices to gather from the source layout to the target (NCHW). 
+ static_assert(uint32_t(InputOperandLayout::kMaxValue) == 1); + // clang-format off + switch (original_layout) { + case InputOperandLayout::kNchw: return {0, 1, 2, 3}; + case InputOperandLayout::kNhwc: return {0, 3, 1, 2}; + default: + DCHECK(0); + break; + } + // clang-format on +} + +std::array getLayoutPermutationToOihw( + Conv2dFilterOperandLayout filter_layout) { + + static_assert(int32_t(Conv2dFilterOperandLayout::kMaxValue) == 5); + // clang-format off + switch (filter_layout) { + case Conv2dFilterOperandLayout::kOihw: return {0,1,2,3}; + case Conv2dFilterOperandLayout::kIohw: return {1,0,2,3}; + case Conv2dFilterOperandLayout::kHwoi: return {2,3,0,1}; + case Conv2dFilterOperandLayout::kHwio: return {3,2,0,1}; + case Conv2dFilterOperandLayout::kOhwi: return {0,3,1,2}; + case Conv2dFilterOperandLayout::kIhwo: return {3,0,1,2}; + default: + DCHECK(0); + break; + } + // clang-format on +} + +std::array getLayoutPermutationToIohw( + Conv2dFilterOperandLayout filter_layout) { + + static_assert(int32_t(Conv2dFilterOperandLayout::kMaxValue) == 5); + // clang-format off + switch (filter_layout) { + case Conv2dFilterOperandLayout::kOihw: return {1,0,2,3}; + case Conv2dFilterOperandLayout::kIohw: return {0,1,2,3}; + case Conv2dFilterOperandLayout::kHwoi: return {3,2,0,1}; + case Conv2dFilterOperandLayout::kHwio: return {2,3,0,1}; + case Conv2dFilterOperandLayout::kOhwi: return {3,0,1,2}; + case Conv2dFilterOperandLayout::kIhwo: return {0,3,1,2}; + default: + DCHECK(0); + break; + } + // clang-format on +} + std::vector transposeStrides(TransposeType transposeType, const std::vector& input_dims) { - UINT nStride = 0, cStride = 0, hStride = 0, wStride = 0; + /* not needed - filterDims[0]; */ + uint32_t dimension1 = input_dims[1]; + uint32_t dimension2 = input_dims[2]; + uint32_t dimension3 = input_dims[3]; + + uint32_t stride3 = 1; + uint32_t stride2 = dimension3; + uint32_t stride1 = dimension3 * dimension2; + uint32_t stride0 = dimension3 * dimension2 * dimension1; + + uint32_t nStride = 0; + uint32_t cStride = 0; + uint32_t hStride = 0; + uint32_t wStride = 0; + switch (transposeType) { case NhwcToNchw: - nStride = input_dims[1] * input_dims[2] * input_dims[3]; - hStride = input_dims[2] * input_dims[3]; - wStride = input_dims[3]; - cStride = 1; + nStride = stride0; + hStride = stride1; + wStride = stride2; + cStride = stride3; return {nStride, cStride, hStride, wStride}; case NchwToNhwc: - nStride = input_dims[1] * input_dims[2] * input_dims[3]; - cStride = input_dims[2] * input_dims[3]; - hStride = input_dims[3]; - wStride = 1; + nStride = stride0; + cStride = stride1; + hStride = stride2; + wStride = stride3; return {nStride, hStride, wStride, cStride}; default: DCHECK(0); @@ -52,6 +138,19 @@ std::vector transposeStrides(TransposeType transposeType, } } +std::vector transposeStrides(base::span original_strides, base::span permutation) +{ + auto dimension_count = original_strides.size(); + std::vector new_strides; + new_strides.reserve(dimension_count); + for (auto axis : permutation) + { + DCHECK(axis < dimension_count); // This should have already been validated. 
+ new_strides.push_back(original_strides[axis]); + } + return new_strides; +} + std::vector transposeStridesToNchw( const std::vector& input_dims, const DML_TENSOR_DESC* input_tensor_desc) { @@ -66,43 +165,9 @@ std::vector transposeStridesToNchw( } } -DML_OPERATOR_DESC* CreateFusedOperator( - const OperationInfo* activation, - DML_ACTIVATION_LINEAR_OPERATOR_DESC& dmlActicationOperatorDesc, - DML_OPERATOR_DESC& dmlFusedOperatorDesc) { - if (activation == nullptr) { - return nullptr; - } - - dmlActicationOperatorDesc.InputTensor = nullptr; - dmlActicationOperatorDesc.OutputTensor = nullptr; - dmlActicationOperatorDesc.Alpha = 0.0; - dmlActicationOperatorDesc.Beta = 0.0; - switch (activation->which()) { - case OperationInfo::Tag::kRelu: - dmlFusedOperatorDesc.Type = DML_OPERATOR_ACTIVATION_RELU; - break; - case OperationInfo::Tag::kClamp: - return nullptr; - default: - LOG(ERROR) << "This fusion type is not supported."; - DCHECK(0); - } - dmlFusedOperatorDesc.Desc = &dmlActicationOperatorDesc; - return &dmlFusedOperatorDesc; -} - -std::vector ExpandDimensions(const std::vector& dims, size_t rank) { - DCHECK(rank >= dims.size()); - std::vector newDims(rank, 1); - for (size_t i = 0; i < dims.size(); ++i) { - newDims[newDims.size() - i - 1] = dims[dims.size() - i - 1]; - } - return newDims; -} - std::vector transposeDimensions(TransposeType transposeType, const std::vector& input_dims) { + DCHECK(input_dims.size() == 4); std::vector newInputDims(4); switch (transposeType) { case NhwcToNchw: @@ -124,136 +189,154 @@ std::vector transposeDimensions(TransposeType transposeType, return newInputDims; } -std::vector transposeFilterDimensionsAsOihw( - Conv2dFilterOperandLayout filterLayout, - const std::vector& filterDims) { - std::vector newFilterDims(4); - switch (filterLayout) { - case Conv2dFilterOperandLayout::kOhwi: - newFilterDims.resize(4); - newFilterDims[0] = filterDims[0]; - newFilterDims[1] = filterDims[3]; - newFilterDims[2] = filterDims[1]; - newFilterDims[3] = filterDims[2]; - break; - case Conv2dFilterOperandLayout::kHwio: - newFilterDims[0] = filterDims[3]; - newFilterDims[1] = filterDims[2]; - newFilterDims[2] = filterDims[0]; - newFilterDims[3] = filterDims[1]; - break; - case Conv2dFilterOperandLayout::kIhwo: - newFilterDims[0] = filterDims[3]; - newFilterDims[1] = filterDims[0]; - newFilterDims[2] = filterDims[1]; - newFilterDims[3] = filterDims[2]; - break; - default: - DCHECK(0); - break; +DML_OPERATOR_DESC* CreateFusedOperator( + const OperationInfo* activation, + DML_ACTIVATION_LINEAR_OPERATOR_DESC& dmlActicationOperatorDesc, + DML_OPERATOR_DESC& dmlFusedOperatorDesc) { + if (activation == nullptr) { + return nullptr; } - return newFilterDims; -} -std::vector transposeFilterStridesAsOihw( - Conv2dFilterOperandLayout filterLayout, - const std::vector& filterDims) { - UINT hStride = 0, wStride = 0, iStride = 0, oStride = 0; - switch (filterLayout) { - case Conv2dFilterOperandLayout::kHwio: - hStride = filterDims[1] * filterDims[2] * filterDims[3]; - wStride = filterDims[2] * filterDims[3]; - iStride = filterDims[3]; - oStride = 1; - break; - case Conv2dFilterOperandLayout::kOhwi: - oStride = filterDims[1] * filterDims[2] * filterDims[3]; - hStride = filterDims[2] * filterDims[3]; - wStride = filterDims[3]; - iStride = 1; - break; - case Conv2dFilterOperandLayout::kIhwo: - iStride = filterDims[1] * filterDims[2] * filterDims[3]; - hStride = filterDims[2] * filterDims[3]; - wStride = filterDims[3]; - oStride = 1; + dmlActicationOperatorDesc.InputTensor = nullptr; + 
dmlActicationOperatorDesc.OutputTensor = nullptr;
+  dmlActicationOperatorDesc.Alpha = 0.0;
+  dmlActicationOperatorDesc.Beta = 0.0;
+  switch (activation->which()) {
+    case OperationInfo::Tag::kRelu:
+      dmlFusedOperatorDesc.Type = DML_OPERATOR_ACTIVATION_RELU;
       break;
+    case OperationInfo::Tag::kClamp:
+      return nullptr;
     default:
+      LOG(ERROR) << "This fusion type is not supported.";
       DCHECK(0);
-      break;
   }
-  return {oStride, iStride, hStride, wStride};
+  dmlFusedOperatorDesc.Desc = &dmlActicationOperatorDesc;
+  return &dmlFusedOperatorDesc;
+}
+
+uint16_t CastFloat32ToFloat16(float float32_value) noexcept {
+  static uint32_t constexpr float16_mantissa_count = 10;
+  static int32_t constexpr float32to16_mantissa_count_difference = 23 - 10;
+  static int32_t constexpr float32vs16_exponent_adjustment = 127 - 15;
+  static uint32_t constexpr float16_sign_mask = 0b1'00000'0000000000;
+  static uint32_t constexpr float16_mantissa_mask = 0b0'00000'1111111111;
+  static uint32_t constexpr float16_exponent_mask = 0b0'11111'0000000000;
+  static uint32_t constexpr float16_mantissa_and_exponentMask = 0b0'11111'1111111111;
+  static uint32_t constexpr float32_mantissa_and_exponent_mask = 0b01111111'10000000'00000000'00000000;
+
+  // Shift the mantissa, exponent, and sign from the 32-bit locations to 16-bit.
+  // Saturate the exponent if it is greater than float16 can represent.
+  // float32 denorms are flushed to zero.
+
+  uint32_t const float32_bit_value =
+      reinterpret_cast<uint32_t const&>(float32_value);
+  uint32_t const sign = (float32_bit_value >> 16) & float16_sign_mask;
+  int32_t const float32_mantissa_and_exponent =
+      float32_bit_value & float32_mantissa_and_exponent_mask;
+  int32_t const float16_mantissa_and_exponent =
+      (float32_mantissa_and_exponent >> float32to16_mantissa_count_difference) -
+      (float32vs16_exponent_adjustment << float16_mantissa_count);
+  uint32_t const float16_denorm_mask =
+      (float16_mantissa_and_exponent > int32_t(float16_mantissa_mask))
+          ? float16_mantissa_and_exponentMask
+          : 0;
+  uint32_t const float16_saturation_ask =
+      (float16_mantissa_and_exponent >= int32_t(float16_mantissa_and_exponentMask))
+          ? 
float16_exponent_mask + : 0; + uint32_t const float16_bit_value = + (float16_mantissa_and_exponent & float16_denorm_mask) | float16_saturation_ask | + sign; + return uint16_t(float16_bit_value); } DML_TENSOR_DATA_TYPE GetTensorDataType(OperandType type) { - DML_TENSOR_DATA_TYPE data_type; - if (type == OperandType::kFloat32) { - data_type = DML_TENSOR_DATA_TYPE_FLOAT32; - } else if (type == OperandType::kFloat16) { - data_type = DML_TENSOR_DATA_TYPE_FLOAT16; - } else if (type == OperandType::kInt32) { - data_type = DML_TENSOR_DATA_TYPE_INT32; - } else if (type == OperandType::kUint32) { - data_type = DML_TENSOR_DATA_TYPE_UINT32; - } else { + // clang-format off + switch (type) + { + case OperandType::kFloat32: return DML_TENSOR_DATA_TYPE_FLOAT32; + case OperandType::kFloat16: return DML_TENSOR_DATA_TYPE_FLOAT16; + case OperandType::kInt8: return DML_TENSOR_DATA_TYPE_INT8; + case OperandType::kUint8: return DML_TENSOR_DATA_TYPE_UINT8; + case OperandType::kInt32: return DML_TENSOR_DATA_TYPE_INT32; + case OperandType::kUint32: return DML_TENSOR_DATA_TYPE_UINT32; + case OperandType::kInt64: return DML_TENSOR_DATA_TYPE_INT64; + case OperandType::kUint64: return DML_TENSOR_DATA_TYPE_UINT64; + default: LOG(ERROR) << "This data type is not supported"; return DML_TENSOR_DATA_TYPE_UNKNOWN; } - - return data_type; + // clang-format on } -// Strides are used to express broadcasting (by specifying a stride of 0) as -// well as padding. If Strides is not specified, each dimension in the tensor is -// considered to be contiguously packed, with no additional padding. The -// calculated strides refer to -// https://docs.microsoft.com/en-us/windows/win32/direct3d12/dml-helper-functions#calculatestrides -std::vector CalculateStridesForBroadcast( - NodeOutput* node_output, - std::vector broadcasted_dims) { - auto& tensor_desc = node_output->GetTensorDesc(); - auto original_dims = tensor_desc.GetDimensions(); - auto original_rank = original_dims.size(), - broadcasted_rank = broadcasted_dims.size(); - std::vector broadcast_flags(broadcasted_rank, false); - auto rank_gap = broadcasted_rank - original_rank; - for (size_t i = 0; i < rank_gap; ++i) { - broadcast_flags[i] = true; - } - for (size_t i = 0; i < original_rank; ++i) { - if (original_dims[i] == 1 && broadcasted_dims[rank_gap + i] != 1) { - broadcast_flags[rank_gap + i] = true; - } +DML_SCALAR_UNION GetScalarUnion(DML_TENSOR_DATA_TYPE tensorDataType, float value) +{ + DML_SCALAR_UNION valueUnion = {}; + + // clang-format off + switch (tensorDataType) + { + case DML_TENSOR_DATA_TYPE_FLOAT32: valueUnion.Float32 = static_cast(value); break; + case DML_TENSOR_DATA_TYPE_FLOAT16: valueUnion.UInt16 = CastFloat32ToFloat16(value); break; + case DML_TENSOR_DATA_TYPE_UINT32: valueUnion.UInt32 = static_cast(value); break; + case DML_TENSOR_DATA_TYPE_UINT16: valueUnion.UInt16 = static_cast(value); break; + case DML_TENSOR_DATA_TYPE_UINT8: valueUnion.UInt8 = static_cast(value); break; + case DML_TENSOR_DATA_TYPE_INT32: valueUnion.Int32 = static_cast(value); break; + case DML_TENSOR_DATA_TYPE_INT16: valueUnion.Int16 = static_cast(value); break; + case DML_TENSOR_DATA_TYPE_INT8: valueUnion.Int8 = static_cast(value); break; + case DML_TENSOR_DATA_TYPE_FLOAT64: valueUnion.Float64 = static_cast(value); break; + case DML_TENSOR_DATA_TYPE_UINT64: valueUnion.UInt64 = static_cast(value); break; + case DML_TENSOR_DATA_TYPE_INT64: valueUnion.Int64 = static_cast(value); break; + case DML_TENSOR_DATA_TYPE_UNKNOWN: /* keep zeroed */ break; + default: /* keep zeroed */ break; } + // 
clang-format on - for (size_t i = 0; i < broadcasted_rank; ++i) { - if (broadcast_flags[i]) { - broadcasted_dims[i] = 1; - } + return valueUnion; +} + +DML_REDUCE_FUNCTION MapOperatorTypeToReductionFuntion(OperatorType operator_type) { + // clang-format off + switch (operator_type) + { + case OperatorType::kReduceL1: return DML_REDUCE_FUNCTION_L1; + case OperatorType::kReduceL2: return DML_REDUCE_FUNCTION_L2; + case OperatorType::kReduceLogSum: return DML_REDUCE_FUNCTION_LOG_SUM; + case OperatorType::kReduceLogSumExp: return DML_REDUCE_FUNCTION_LOG_SUM_EXP; + case OperatorType::kReduceMax: return DML_REDUCE_FUNCTION_MAX; + case OperatorType::kReduceMean: return DML_REDUCE_FUNCTION_AVERAGE; + case OperatorType::kReduceMin: return DML_REDUCE_FUNCTION_MIN; + case OperatorType::kReduceProduct: return DML_REDUCE_FUNCTION_MULTIPLY; + case OperatorType::kReduceSum: return DML_REDUCE_FUNCTION_SUM; + case OperatorType::kReduceSumSquare: return DML_REDUCE_FUNCTION_SUM_SQUARE; + default: + LOG(ERROR) << "This operator type is not supported for reduction"; + return DML_REDUCE_FUNCTION_MIN; } - std::vector strides(broadcasted_rank); - auto existed_strides = tensor_desc.GetStrides(); - if (existed_strides) { - auto indexBegin = broadcasted_rank - original_rank; - for (size_t i = 0, j = 0; i < broadcasted_rank; ++i) { - if (i < indexBegin) { - strides[i] = 0; - } else { - strides[i] = broadcast_flags[i] ? 0 : existed_strides.value()[j]; - ++j; - } - } - } else { - strides[broadcasted_rank - 1] = - broadcast_flags[broadcasted_rank - 1] ? 0 : 1; - size_t elements = 1; - for (size_t i = 1; i < broadcasted_rank; i++) { - size_t j = broadcasted_rank - i - 1; - elements *= broadcasted_dims[j + 1]; - strides[j] = broadcast_flags[j] ? 0 : elements; - } + // clang-format on +} + +DML_PADDING_MODE MapPaddingModeToDml(PaddingMode operator_type) { + // clang-format off + switch (operator_type) + { + case PaddingMode::kConstant: return DML_PADDING_MODE_CONSTANT; + case PaddingMode::kEdge: return DML_PADDING_MODE_EDGE; + case PaddingMode::kReflection: return DML_PADDING_MODE_REFLECTION; + case PaddingMode::kSymmetric: return DML_PADDING_MODE_SYMMETRIC; + default: + LOG(ERROR) << "This padding mode is not supported"; + return DML_PADDING_MODE_CONSTANT; } - return strides; + // clang-format on +} + +TensorDesc GetBroadcastedTensorDesc(NodeOutput* input_node, + base::span broadcasted_dims, + uint32_t ignorable_tail_count = 0) { + TensorDesc broadcasted_tensor(input_node->GetTensorDesc()); + broadcasted_tensor.BroadcastTo(broadcasted_dims, TensorDesc::Alignment::kTrailing, ignorable_tail_count); + return broadcasted_tensor; } } // namespace @@ -267,31 +350,22 @@ std::vector CalculateStridesForBroadcast( } while (0) #define CREATE_BINARY_OPERATOR(type, a_tensor_desc, b_tensor_desc, \ - output_tensor, node) \ - DML_ELEMENT_WISE_##type##_OPERATOR_DESC operator_desc{}; \ + output_tensor_desc, node) \ + DML_##type##_OPERATOR_DESC operator_desc{}; \ operator_desc.ATensor = a_tensor_desc; \ operator_desc.BTensor = b_tensor_desc; \ - operator_desc.OutputTensor = output_tensor; \ + operator_desc.OutputTensor = output_tensor_desc; \ node = graph_desc_builder_->CreateOperatorNode( \ - DML_OPERATOR_ELEMENT_WISE_##type, &operator_desc); + DML_OPERATOR_##type, &operator_desc); -#define CREATE_UNARY_OPERATOR(type, input_tensor_desc, node) \ +#define CREATE_UNARY_OPERATOR(type, input_tensor_desc, \ + output_tensor_desc, node) \ DML_##type##_OPERATOR_DESC operator_desc{}; \ operator_desc.InputTensor = input_tensor_desc; \ - 
operator_desc.OutputTensor = input_tensor_desc; \ + operator_desc.OutputTensor = output_tensor_desc; \ node = graph_desc_builder_->CreateOperatorNode(DML_OPERATOR_##type, \ &operator_desc); -// Append IDENTITY to remove the strides of input tensor. Use this to implement -// Reshape, Squeeze, Transpose and avoid creating an invalid graph with input = -// output. -#define APPEND_IDENTITY(input_tensor, output_tensor, node) \ - DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC operator_desc{}; \ - operator_desc.InputTensor = input_tensor; \ - operator_desc.OutputTensor = output_tensor; \ - node = graph_desc_builder_->CreateOperatorNode( \ - DML_OPERATOR_ELEMENT_WISE_IDENTITY, &operator_desc); - // static void GraphDMLImpl::Create(mojo::PendingReceiver receiver, scoped_refptr execution_context) { @@ -338,69 +412,267 @@ void GraphDMLImpl::AddConstant(OperandDescriptorPtr desc, UINT64 index) { return; } -void GraphDMLImpl::AddElementWiseBinary(UINT64 a_index, +void GraphDMLImpl::AddElementWiseUnary(OperatorType operator_type, + UINT64 input_index, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(input_index)); + + auto* input_node_output = node_output_map_[input_index].get(); + auto output_dims = output_desc->dimensions; + + auto& input_tensor_desc = input_node_output->GetTensorDesc(); + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dims); + Node node; + + switch (operator_type) { + case OperatorType::kAbs: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_ABS, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kNeg: { + // Can't just say CREATE_UNARY_OPERATOR(ELEMENT_WISE_NEGATE...) as DML + // feature level 5.0 has a proper DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC + // that works properly for all data types (including int64), but the + // older Win11 DML lacks it. So use the work-around of multiplying by + // identity with a scale factor. 
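// Illustrative sketch (not part of this change): the identity-with-ScaleBias
// fallback below computes out = Scale * in + Bias, so Scale = -1 and Bias = 0
// reproduces elementwise negation for floating-point inputs. A minimal CPU
// reference of the same math, assuming a plain float buffer:

#include <cstddef>

void NegateViaScaleBiasReference(const float* in, float* out, size_t count) {
  constexpr float kScale = -1.0f;  // Mirrors DML_SCALE_BIAS{.Scale = -1.0f, .Bias = 0}.
  constexpr float kBias = 0.0f;
  for (size_t i = 0; i < count; ++i) {
    out[i] = kScale * in[i] + kBias;
  }
}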
+ + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC operator_desc = {}; + DML_SCALE_BIAS scale_bias = {.Scale = -1.0f, .Bias = 0}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.ScaleBias = &scale_bias; + node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_ELEMENT_WISE_IDENTITY, + &operator_desc); + } break; + case OperatorType::kCos: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_COS, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kErf: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_ERF, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kExp: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_EXP, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kLog: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_LOG, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kHardSwish: { + DAWN_INTERNAL_ERROR("Hard swish op is not implemented."); + // // TODO:::Implement - Compose from smaller ops: + // // x * max(0, min(6, (x + 3))) / 6 + // // x * clamp(x + 3, 0, 6) / 6 + // CREATE_UNARY_OPERATOR(..., input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kIdentity: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_IDENTITY, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kSin: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_SIN, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kTan: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_TAN, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kSqrt: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_SQRT, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kSigmoid: { + CREATE_UNARY_OPERATOR(ACTIVATION_SIGMOID, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kRelu: { + CREATE_UNARY_OPERATOR(ACTIVATION_RELU, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kFloor: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_FLOOR, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kCeil: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_CEIL, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kReciprocal: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_RECIP, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kLogicalNot: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_LOGICAL_NOT, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kTanh: { + CREATE_UNARY_OPERATOR(ELEMENT_WISE_TANH, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + case OperatorType::kSoftsign: { + CREATE_UNARY_OPERATOR(ACTIVATION_SOFTSIGN, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + } break; + + default: + DAWN_INTERNAL_ERROR("Unary elementwise op is not implemented."); + } + + graph_desc_builder_->Connect({input_node_output}, {node}); + auto node_output = + graph_desc_builder_->CreateNodeOutput(node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); + return; +} + +void GraphDMLImpl::AddElementWiseUnaryTwoParameter( + OperatorType operator_type, + UINT64 input_index, + float first_parameter, + float second_parameter, + OperandDescriptorPtr output_desc, + UINT64 output_index) 
{ + DCHECK(node_output_map_.contains(input_index)); + + auto* input_node_output = node_output_map_[input_index].get(); + auto output_dims = output_desc->dimensions; + + auto& input_tensor_desc = input_node_output->GetTensorDesc(); + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dims); + Node node; + + switch (operator_type) { + case OperatorType::kElu: { + DML_ACTIVATION_ELU_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.Alpha = first_parameter; + node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_ACTIVATION_ELU, + &operator_desc); + } break; + case OperatorType::kLeakyRelu: { + DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.Alpha = first_parameter; + node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_ACTIVATION_LEAKY_RELU, + &operator_desc); + } break; + case OperatorType::kLinear: { + DML_ACTIVATION_LINEAR_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.Alpha = first_parameter; + operator_desc.Beta = second_parameter; + node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_ACTIVATION_LINEAR, + &operator_desc); + } break; + case OperatorType::kHardSigmoid: { + DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.Alpha = first_parameter; + operator_desc.Beta = second_parameter; + node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_ACTIVATION_HARD_SIGMOID, + &operator_desc); + } break; + case OperatorType::kSoftplus: { + DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.Steepness = first_parameter; + node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_ACTIVATION_SOFTPLUS, + &operator_desc); + } break; + + default: + DAWN_INTERNAL_ERROR("Unary elementwise op is not implemented."); + } + + graph_desc_builder_->Connect({input_node_output}, {node}); + auto node_output = + graph_desc_builder_->CreateNodeOutput(node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); + return; +} + +void GraphDMLImpl::AddElementWiseBinary(OperatorType operator_type, + UINT64 a_index, UINT64 b_index, - ElementWiseBinaryType type, OperandDescriptorPtr output_desc, UINT64 output_index) { // TODO: return directly if BuildResult has error message. 
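// Illustrative sketch (not part of this change): reference math for the
// two-parameter activations registered above, assuming the usual definitions
// in which first_parameter/second_parameter map to Alpha/Beta (or Steepness):

#include <algorithm>
#include <cmath>

float EluReference(float x, float alpha) {
  return x < 0.0f ? alpha * (std::exp(x) - 1.0f) : x;
}
float LeakyReluReference(float x, float alpha) {
  return x < 0.0f ? alpha * x : x;
}
float LinearReference(float x, float alpha, float beta) {
  return alpha * x + beta;
}
float HardSigmoidReference(float x, float alpha, float beta) {
  return std::clamp(alpha * x + beta, 0.0f, 1.0f);
}
float SoftplusReference(float x, float steepness) {
  return std::log(std::exp(steepness * x) + 1.0f) / steepness;
}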
- DCHECK(node_output_map_.find(a_index) != node_output_map_.end()); - DCHECK(node_output_map_.find(b_index) != node_output_map_.end()); + DCHECK(node_output_map_.contains(a_index)); + DCHECK(node_output_map_.contains(b_index)); auto* a_node_output = node_output_map_[a_index].get(); auto* b_node_output = node_output_map_[b_index].get(); auto output_dims = output_desc->dimensions; - std::vector output_new_dims = output_dims; - - auto a_broadcasted_strides = - CalculateStridesForBroadcast(a_node_output, output_dims); - auto& a_tensor_desc = a_node_output->GetTensorDesc(); - TensorDesc a_broadcasted_tensor(a_tensor_desc.GetDataType(), - a_tensor_desc.GetFlags(), output_dims, - a_broadcasted_strides); - - auto b_broadcasted_strides = - CalculateStridesForBroadcast(b_node_output, output_dims); - auto& b_tensor_desc = b_node_output->GetTensorDesc(); - TensorDesc b_broadcasted_tensor(b_tensor_desc.GetDataType(), - b_tensor_desc.GetFlags(), output_dims, - b_broadcasted_strides); - - TensorDesc output_tensor(b_tensor_desc.GetDataType(), output_new_dims); + + TensorDesc a_broadcasted_tensor = + GetBroadcastedTensorDesc(a_node_output, output_dims); + TensorDesc b_broadcasted_tensor = + GetBroadcastedTensorDesc(b_node_output, output_dims); + + TensorDesc output_tensor(GetTensorDataType(output_desc->data_type), output_dims); Node node; - switch (type) { - case ElementWiseBinaryType::kAdd: { - CREATE_BINARY_OPERATOR(ADD, a_broadcasted_tensor.Get(), + switch (operator_type) { + case OperatorType::kAdd: { + CREATE_BINARY_OPERATOR(ELEMENT_WISE_ADD, a_broadcasted_tensor.Get(), b_broadcasted_tensor.Get(), output_tensor.Get(), node); } break; - case ElementWiseBinaryType::kDiv: { - CREATE_BINARY_OPERATOR(DIVIDE, a_broadcasted_tensor.Get(), + case OperatorType::kDiv: { + CREATE_BINARY_OPERATOR(ELEMENT_WISE_DIVIDE, a_broadcasted_tensor.Get(), b_broadcasted_tensor.Get(), output_tensor.Get(), node); } break; - case ElementWiseBinaryType::kMul: { - CREATE_BINARY_OPERATOR(MULTIPLY, a_broadcasted_tensor.Get(), + case OperatorType::kMul: { + CREATE_BINARY_OPERATOR(ELEMENT_WISE_MULTIPLY, a_broadcasted_tensor.Get(), b_broadcasted_tensor.Get(), output_tensor.Get(), node); } break; - case ElementWiseBinaryType::kSub: { - CREATE_BINARY_OPERATOR(SUBTRACT, a_broadcasted_tensor.Get(), + case OperatorType::kSub: { + CREATE_BINARY_OPERATOR(ELEMENT_WISE_SUBTRACT, a_broadcasted_tensor.Get(), b_broadcasted_tensor.Get(), output_tensor.Get(), node); } break; - case ElementWiseBinaryType::kMax: { - CREATE_BINARY_OPERATOR(MAX, a_broadcasted_tensor.Get(), + case OperatorType::kMax: { + CREATE_BINARY_OPERATOR(ELEMENT_WISE_MAX, a_broadcasted_tensor.Get(), b_broadcasted_tensor.Get(), output_tensor.Get(), node); } break; - case ElementWiseBinaryType::kMin: { - CREATE_BINARY_OPERATOR(MIN, a_broadcasted_tensor.Get(), + case OperatorType::kMin: { + CREATE_BINARY_OPERATOR(ELEMENT_WISE_MIN, a_broadcasted_tensor.Get(), b_broadcasted_tensor.Get(), output_tensor.Get(), node); } break; + case OperatorType::kEqual: { + CREATE_BINARY_OPERATOR(ELEMENT_WISE_LOGICAL_EQUALS, a_broadcasted_tensor.Get(), + b_broadcasted_tensor.Get(), output_tensor.Get(), + node); + } break; + case OperatorType::kGreater: { + CREATE_BINARY_OPERATOR(ELEMENT_WISE_LOGICAL_GREATER_THAN, a_broadcasted_tensor.Get(), + b_broadcasted_tensor.Get(), output_tensor.Get(), + node); + } break; + case OperatorType::kLesser: { + CREATE_BINARY_OPERATOR(ELEMENT_WISE_LOGICAL_LESS_THAN, a_broadcasted_tensor.Get(), + b_broadcasted_tensor.Get(), output_tensor.Get(), + node); + } break; + case 
OperatorType::kPow: { + DML_ELEMENT_WISE_POW_OPERATOR_DESC operator_desc{}; + operator_desc.InputTensor = a_broadcasted_tensor.Get(); + operator_desc.ExponentTensor = b_broadcasted_tensor.Get(); + operator_desc.OutputTensor = output_tensor.Get(); + node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_ELEMENT_WISE_POW, &operator_desc); + } break; + case OperatorType::kPrelu: { + DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC operator_desc{}; + operator_desc.InputTensor = a_broadcasted_tensor.Get(); + operator_desc.SlopeTensor = b_broadcasted_tensor.Get(); + operator_desc.OutputTensor = output_tensor.Get(); + node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU, &operator_desc); + } break; + default: - DAWN_INTERNAL_ERROR(" Binary op is not implemented."); + DAWN_INTERNAL_ERROR("Binary elementwise op is not implemented."); } graph_desc_builder_->Connect({a_node_output, b_node_output}, {node}); auto node_output = @@ -469,84 +741,112 @@ void GraphDMLImpl::TransposeOutputToNhwc( return; } -void GraphDMLImpl::AddConv2d(UINT64 input_index, +void GraphDMLImpl::AddConv2d(OperatorType operator_type, + UINT64 input_index, UINT64 filter_index, + UINT64 input_zero_point_index, + UINT64 filter_zero_point_index, Conv2dOptionsPtr options, OperandDescriptorPtr output_desc, UINT64 output_index) { // TODO: return directly if BuildResult has error message. - DCHECK(node_output_map_.find(input_index) != node_output_map_.end()); - DCHECK(node_output_map_.find(filter_index) != node_output_map_.end()); + DCHECK(node_output_map_.contains(input_index)); + DCHECK(node_output_map_.contains(filter_index)); + + bool is_backward_direction = + (operator_type == OperatorType::kConvTranspose2d); auto* input_node = node_output_map_[input_index].get(); auto* filter_node = node_output_map_[filter_index].get(); - - auto& input_node_desc = input_node->GetTensorDesc(); - auto input_dims = input_node_desc.GetDimensions(); - auto filterDims = filter_node->GetTensorDesc().GetDimensions(); - auto output_dims = output_desc->dimensions; - std::vector input_nchw_dims = input_dims, filter_nchw_dims = filterDims, - output_nchw_dims = output_dims; - - DML_TENSOR_DESC* input_tensor_desc = input_node_desc.Get(); - TensorDesc nhwc_tensor_desc; - if (options->inputLayout == InputOperandLayout::kNhwc) { - input_nchw_dims = transposeDimensions(NhwcToNchw, input_dims); - output_nchw_dims = transposeDimensions(NhwcToNchw, output_dims); - auto input_nchw_Strides = - transposeStridesToNchw(input_dims, input_tensor_desc); - - nhwc_tensor_desc = - TensorDesc(input_node_desc.GetDataType(), input_node_desc.GetFlags(), - input_nchw_dims, input_nchw_Strides); - input_tensor_desc = nhwc_tensor_desc.Get(); + TensorDesc& input_node_tensor_desc = input_node->GetTensorDesc(); + TensorDesc& filter_node_tensor_desc = filter_node->GetTensorDesc(); + TensorDesc output_node_tensor_desc( + GetTensorDataType(output_desc->data_type), output_desc->dimensions); + + std::vector& input_dims = input_node_tensor_desc.GetDimensions(); + std::vector& filterDims = filter_node_tensor_desc.GetDimensions(); + std::vector& output_dims = output_node_tensor_desc.GetDimensions(); + std::vector input_nchw_dims = input_dims; + std::vector filter_nchw_dims = filterDims; + + DML_TENSOR_DESC* input_tensor_desc = input_node_tensor_desc.Get(); + DML_TENSOR_DESC* output_tensor_desc = output_node_tensor_desc.Get(); + + TensorDesc nhwc_input_tensor_desc; + TensorDesc nhwc_output_tensor_desc; + if (options->inputLayout != 
InputOperandLayout::kNchw) { + nhwc_input_tensor_desc = input_node_tensor_desc; + nhwc_output_tensor_desc = output_node_tensor_desc; + + std::array permutation = + getLayoutPermutationToNchw(options->inputLayout); + nhwc_input_tensor_desc.PermuteDimensions(permutation, + TensorDesc::Alignment::kTrailing); + nhwc_output_tensor_desc.PermuteDimensions(permutation, + TensorDesc::Alignment::kTrailing); + input_nchw_dims = nhwc_input_tensor_desc.GetDimensions(); + input_tensor_desc = nhwc_input_tensor_desc.Get(); + output_tensor_desc = nhwc_output_tensor_desc.Get(); } + // convTranspose uses IOHW filter layout, where conv uses OIHW. + Conv2dFilterOperandLayout desired_filter_layout = is_backward_direction + ? Conv2dFilterOperandLayout::kIohw + : Conv2dFilterOperandLayout::kOihw; + DML_TENSOR_DESC* filter_tensor_desc = filter_node->GetTensorDesc().Get(); TensorDesc new_filter_tensor_desc; - if (options->filterLayout != Conv2dFilterOperandLayout::kOihw) { - filter_nchw_dims = - transposeFilterDimensionsAsOihw(options->filterLayout, filterDims); - auto filter_oihw_strides = - transposeFilterStridesAsOihw(options->filterLayout, filterDims); - - auto& fileter_desc = filter_node->GetTensorDesc(); - new_filter_tensor_desc = - TensorDesc(fileter_desc.GetDataType(), fileter_desc.GetFlags(), - filter_nchw_dims, filter_oihw_strides); + if (options->filterLayout != desired_filter_layout) { + new_filter_tensor_desc = filter_node->GetTensorDesc(); + std::array permutation = + is_backward_direction + ? getLayoutPermutationToIohw(options->filterLayout) + : getLayoutPermutationToOihw(options->filterLayout); + new_filter_tensor_desc.PermuteDimensions(permutation, + TensorDesc::Alignment::kTrailing); + filter_nchw_dims = new_filter_tensor_desc.GetDimensions(); filter_tensor_desc = new_filter_tensor_desc.Get(); } std::vector input_nodes = {input_node, filter_node}; TensorDesc bias_tensor_desc; if (options->bias_index != std::numeric_limits::max()) { - DCHECK(node_output_map_.find(options->bias_index) != - node_output_map_.end()); + // Read the bias tensor desc. + DCHECK(node_output_map_.contains(options->bias_index)); auto* bias_node = node_output_map_[options->bias_index].get(); - auto& bias_desc = bias_node->GetTensorDesc(); - auto bias_dims = bias_desc.GetDimensions(); - if (bias_dims[0] != filter_nchw_dims[0] || bias_dims.size() != 1) { + auto& original_bias_desc = bias_node->GetTensorDesc(); + auto& original_bias_dims = original_bias_desc.GetDimensions(); + bias_tensor_desc = original_bias_desc; + + // Sanity check the bias shape against the output shape. + uint32_t output_channel_count = + is_backward_direction ? filter_nchw_dims[1] : filter_nchw_dims[0]; + if (original_bias_dims.size() != 1 || + original_bias_dims[0] != output_channel_count) { DAWN_INTERNAL_ERROR( "The bias should be 1-D tensor with the shape of [output_channels]."); } - // Reshape bias from 1-D to 4-D for NCHW layout. - std::vector bias_expand_dims = {1, bias_dims[0], 1, 1}; - bias_tensor_desc = TensorDesc(bias_desc.GetDataType(), bias_desc.GetFlags(), - bias_expand_dims); + // Reshape bias from 1-D to 4-D for NCHW layout, moving channel to axis 1. + constexpr std::array aligned_channel_permutation = {0,3,0,0}; + bias_tensor_desc.PermuteDimensions(aligned_channel_permutation, + TensorDesc::Alignment::kTrailing); input_nodes.push_back(bias_node); } - // FIXME(nhu): strides, dilations, padding should be uint32_t - // need to fix the spec. 
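// Illustrative sketch (not part of this change): one common formulation of the
// implicit "same" padding computed next, for a single spatial dimension. The
// helper name and signature are hypothetical; sameUpper places the extra pixel
// of an odd total at the end, sameLower at the beginning.

#include <cstdint>
#include <utility>

std::pair<uint32_t, uint32_t> SamePadding1D(uint32_t input_size,
                                            uint32_t filter_size,
                                            uint32_t stride,
                                            uint32_t dilation,
                                            bool same_upper) {
  uint32_t output_size = (input_size + stride - 1) / stride;  // ceil(input / stride)
  uint32_t effective_filter = (filter_size - 1) * dilation + 1;
  int64_t needed = int64_t(output_size - 1) * stride + effective_filter - input_size;
  uint32_t total = needed > 0 ? uint32_t(needed) : 0u;
  uint32_t small_half = total / 2;
  uint32_t large_half = total - small_half;
  return same_upper ? std::make_pair(small_half, large_half)
                    : std::make_pair(large_half, small_half);
}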
std::vector strides = options->strides; std::vector dilations = options->dilations; + base::span input_nchw_dims_span(input_nchw_dims); + base::span filter_nchw_dims_span(filter_nchw_dims); std::vector padding = options->auto_pad == AutoPad::kExplicit ? ExplicitPadding(options.get()) - : ImplicitPadding(options.get(), input_nchw_dims, - filter_nchw_dims); + : ImplicitPadding( + options.get(), + input_nchw_dims_span, + filter_nchw_dims_span + ); std::vector startPadding = {padding[0], padding[2]}; std::vector endPadding = {padding[1], padding[3]}; std::vector defaultOutPadding = {0, 0}; @@ -557,35 +857,65 @@ void GraphDMLImpl::AddConv2d(UINT64 input_index, CreateFusedOperator(options->activation.get(), dmlActicationOperatorDesc, dmlFusedOperatorDesc); - TensorDesc output_tensor(input_node_desc.GetDataType(), output_nchw_dims); - DML_CONVOLUTION_OPERATOR_DESC operator_desc{}; - operator_desc.InputTensor = input_tensor_desc; - operator_desc.FilterTensor = filter_tensor_desc; - operator_desc.BiasTensor = bias_tensor_desc.Get(); - operator_desc.OutputTensor = output_tensor.Get(); + Node operator_node; - operator_desc.Mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION; - operator_desc.Direction = DML_CONVOLUTION_DIRECTION_FORWARD; - operator_desc.DimensionCount = input_dims.size() - 2; - operator_desc.Strides = strides.data(); - operator_desc.Dilations = dilations.data(); - operator_desc.StartPadding = startPadding.data(); - operator_desc.EndPadding = endPadding.data(); - operator_desc.OutputPadding = defaultOutPadding.data(); - operator_desc.GroupCount = static_cast(options->groups); - operator_desc.FusedActivation = fusedActivation; + std::unique_ptr output_node; + if (operator_type == OperatorType::kConv2dInteger) { + auto* input_zero_point_node = + node_output_map_[input_zero_point_index].get(); + auto* filter_zero_point_node = + node_output_map_[filter_zero_point_index].get(); + TensorDesc& input_zero_point_tensor_desc = + input_zero_point_node->GetTensorDesc(); + TensorDesc& filter_zero_point_tensor_desc = + filter_zero_point_node->GetTensorDesc(); + + input_nodes.insert(input_nodes.begin() + 1, input_zero_point_node); + input_nodes.push_back(filter_zero_point_node); + + DML_CONVOLUTION_INTEGER_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc; + operator_desc.InputZeroPointTensor = input_zero_point_tensor_desc.Get(); + operator_desc.FilterTensor = filter_tensor_desc; + operator_desc.FilterZeroPointTensor = filter_zero_point_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc; + operator_desc.DimensionCount = input_dims.size() - 2; + operator_desc.Strides = strides.data(); + operator_desc.Dilations = dilations.data(); + operator_desc.StartPadding = startPadding.data(); + operator_desc.EndPadding = endPadding.data(); + operator_desc.GroupCount = static_cast(options->groups); - Node operator_node = graph_desc_builder_->CreateOperatorNode( - DML_OPERATOR_CONVOLUTION, &operator_desc); - graph_desc_builder_->Connect(std::move(input_nodes), operator_node); - auto output_node = graph_desc_builder_->CreateNodeOutput( - operator_node, 0, std::move(output_tensor)); + operator_node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_CONVOLUTION_INTEGER, &operator_desc); + } else { + DML_CONVOLUTION_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc; + operator_desc.FilterTensor = filter_tensor_desc; + operator_desc.BiasTensor = bias_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc; + + 
operator_desc.Mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION; + operator_desc.Direction = is_backward_direction + ? DML_CONVOLUTION_DIRECTION_BACKWARD + : DML_CONVOLUTION_DIRECTION_FORWARD; + operator_desc.DimensionCount = input_dims.size() - 2; + operator_desc.Strides = strides.data(); + operator_desc.Dilations = dilations.data(); + operator_desc.StartPadding = startPadding.data(); + operator_desc.EndPadding = endPadding.data(); + operator_desc.OutputPadding = defaultOutPadding.data(); + operator_desc.GroupCount = static_cast(options->groups); + operator_desc.FusedActivation = fusedActivation; - // Transpose output from nchw->nhwc. - if (options->inputLayout == InputOperandLayout::kNhwc) { - TransposeOutputToNhwc(output_node, output_nchw_dims); + operator_node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_CONVOLUTION, &operator_desc); } + graph_desc_builder_->Connect(std::move(input_nodes), operator_node); + output_node = graph_desc_builder_->CreateNodeOutput( + operator_node, 0, std::move(output_node_tensor_desc)); + EmulateFusedOperator(options->activation.get(), output_node, output_dims); node_output_map_[output_index] = std::move(output_node); return; @@ -610,8 +940,11 @@ void GraphDMLImpl::AddReshape(UINT64 input_index, return; } -void GraphDMLImpl::AddGemm(UINT64 a_index, +void GraphDMLImpl::AddGemm(OperatorType operator_type, + UINT64 a_index, UINT64 b_index, + UINT64 a_index_zero_point, + UINT64 b_index_zero_point, GemmOptionsPtr options, OperandDescriptorPtr output_desc, UINT64 output_index) { @@ -619,77 +952,82 @@ void GraphDMLImpl::AddGemm(UINT64 a_index, DCHECK(node_output_map_.find(a_index) != node_output_map_.end()); DCHECK(node_output_map_.find(b_index) != node_output_map_.end()); - // The shape of a tensor is 2D definited in WebNN Spec, but DML only support - // 4D, so expand dimensions to 4D. - // TODO: DML_FEATURE_LEVEL_4_0 and above support 2D. - // DCHECK(a_dims.size() == 2); auto* a_node_output = node_output_map_[a_index].get(); - auto& a_tensor_desc = a_node_output->GetTensorDesc(); - auto a_expand_dims = ExpandDimensions(a_tensor_desc.GetDimensions(), 4); - TensorDesc a_expand_tensor(a_tensor_desc.GetDataType(), - a_tensor_desc.GetFlags(), a_expand_dims); - - // DCHECK(b_dims.size() == 2); auto* b_node_output = node_output_map_[b_index].get(); - auto& b_tensor_desc = b_node_output->GetTensorDesc(); - auto b_expand_dims = ExpandDimensions(b_tensor_desc.GetDimensions(), 4); - TensorDesc b_expand_tensor(b_tensor_desc.GetDataType(), - b_tensor_desc.GetFlags(), b_expand_dims); + auto& output_dims = output_desc->dimensions; - auto output_dims = output_desc->dimensions; - DCHECK(output_dims.size() == 2); - auto output_expand_dims = ExpandDimensions(output_dims, 4); - TensorDesc output_expand_tensor(b_tensor_desc.GetDataType(), - output_expand_dims); + TensorDesc a_tensor_desc = + GetBroadcastedTensorDesc(a_node_output, output_dims, 2); + TensorDesc b_tensor_desc = + GetBroadcastedTensorDesc(b_node_output, output_dims, 2); + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dims); + + DCHECK(a_tensor_desc.GetDimensions().size() == b_tensor_desc.GetDimensions().size()); + DCHECK(b_tensor_desc.GetDimensions().size() == output_tensor_desc.GetDimensions().size()); // The operand c is optional. 
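// Illustrative sketch (not part of this change): the optional C operand is
// accepted as a scalar or any shape unidirectionally broadcastable to the
// output, and the broadcast is expressed by giving size-1 dimensions a stride
// of 0. A hypothetical shape check capturing that rule:

#include <cstddef>
#include <cstdint>
#include <vector>

bool IsUnidirectionallyBroadcastable(const std::vector<uint32_t>& from,
                                     const std::vector<uint32_t>& to) {
  if (from.size() > to.size())
    return false;
  size_t offset = to.size() - from.size();
  for (size_t i = 0; i < from.size(); ++i) {
    if (from[i] != to[offset + i] && from[i] != 1)
      return false;
  }
  return true;
}
// e.g. {}, {1}, {N} and {M, N} are all broadcastable to {M, N}.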
- TensorDesc c_expand_tensor; + TensorDesc c_tensor_desc; std::vector input_nodes = {a_node_output, b_node_output}; if (options->c_index != std::numeric_limits::max()) { DCHECK(node_output_map_.find(options->c_index) != node_output_map_.end()); auto* c_node_output = node_output_map_[options->c_index].get(); - // It is either a scalar, or of the shape that is unidirectionally - // broadcastable to the shape [M, N] definited in WebNN Spec, DML only - // support 4D, so broadCast the Shape of optional C to {1, 1, M, N } - // supported in DML. - auto c_broadcasted_strides = - CalculateStridesForBroadcast(c_node_output, output_expand_dims); - auto& c_tensor_desc = c_node_output->GetTensorDesc(); - c_expand_tensor = - TensorDesc(c_tensor_desc.GetDataType(), c_tensor_desc.GetFlags(), - output_expand_dims, c_broadcasted_strides); + + // Broadcast C's shape up to the output rank for DML. It enters as either + // a scalar or as a shape that is unidirectionally broadcastable to the + // shape [M, N] as defined in WebNN Spec. + c_tensor_desc = GetBroadcastedTensorDesc(c_node_output, output_dims); input_nodes.push_back(c_node_output); } - DML_MATRIX_TRANSFORM aTranspose = options->a_transpose - ? DML_MATRIX_TRANSFORM_TRANSPOSE - : DML_MATRIX_TRANSFORM_NONE; - DML_MATRIX_TRANSFORM bTranspose = options->b_transpose - ? DML_MATRIX_TRANSFORM_TRANSPOSE - : DML_MATRIX_TRANSFORM_NONE; - DML_GEMM_OPERATOR_DESC gemm_desc = {}; - gemm_desc.ATensor = a_expand_tensor.Get(); - gemm_desc.BTensor = b_expand_tensor.Get(); - gemm_desc.CTensor = c_expand_tensor.Get(); - gemm_desc.OutputTensor = output_expand_tensor.Get(); - gemm_desc.TransA = aTranspose; - gemm_desc.TransB = bTranspose; - gemm_desc.Alpha = options->alpha; - gemm_desc.Beta = options->beta; - - Node operator_node = - graph_desc_builder_->CreateOperatorNode(DML_OPERATOR_GEMM, &gemm_desc); + Node operator_node; + + if (operator_type == OperatorType::kMatmulInteger) { + auto* a_zero_point_node = node_output_map_[a_index_zero_point].get(); + auto* b_zero_point_node = node_output_map_[b_index_zero_point].get(); + TensorDesc& a_zero_point_tensor_desc = a_zero_point_node->GetTensorDesc(); + TensorDesc& b_zero_point_tensor_desc = b_zero_point_node->GetTensorDesc(); + + input_nodes.insert(input_nodes.begin() + 1, a_zero_point_node); + input_nodes.push_back(b_zero_point_node); + + DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_DESC gemm_desc = {}; + gemm_desc.ATensor = a_tensor_desc.Get(); + gemm_desc.AZeroPointTensor = a_zero_point_tensor_desc.Get(); + gemm_desc.BTensor = b_tensor_desc.Get(); + gemm_desc.BZeroPointTensor = b_zero_point_tensor_desc.Get(); + gemm_desc.OutputTensor = output_tensor_desc.Get(); + + operator_node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_MATRIX_MULTIPLY_INTEGER, &gemm_desc); + } else { + DML_MATRIX_TRANSFORM a_transpose = options->a_transpose + ? DML_MATRIX_TRANSFORM_TRANSPOSE + : DML_MATRIX_TRANSFORM_NONE; + DML_MATRIX_TRANSFORM b_transpose = options->b_transpose + ? 
DML_MATRIX_TRANSFORM_TRANSPOSE + : DML_MATRIX_TRANSFORM_NONE; + DML_GEMM_OPERATOR_DESC gemm_desc = {}; + gemm_desc.ATensor = a_tensor_desc.Get(); + gemm_desc.BTensor = b_tensor_desc.Get(); + gemm_desc.CTensor = c_tensor_desc.Get(); + gemm_desc.OutputTensor = output_tensor_desc.Get(); + gemm_desc.TransA = a_transpose; + gemm_desc.TransB = b_transpose; + gemm_desc.Alpha = options->alpha; + gemm_desc.Beta = options->beta; + + operator_node = + graph_desc_builder_->CreateOperatorNode(DML_OPERATOR_GEMM, &gemm_desc); + } graph_desc_builder_->Connect(std::move(input_nodes), {operator_node}); - DCHECK_LT(output_dims.size(), output_expand_dims.size()); - TensorDesc output_tensor(b_tensor_desc.GetDataType(), output_dims); node_output_map_[output_index] = graph_desc_builder_->CreateNodeOutput( - operator_node, 0, std::move(output_tensor)); + operator_node, 0, std::move(output_tensor_desc)); return; } -void GraphDMLImpl::AddPool2d(UINT64 input_index, +void GraphDMLImpl::AddPool2d(OperatorType operator_type, + UINT64 input_index, Pool2dOptionsPtr options, - Pool2dType type, OperandDescriptorPtr output_desc, UINT64 output_index) { // TODO: return directly if BuildResult has error message. @@ -699,8 +1037,8 @@ void GraphDMLImpl::AddPool2d(UINT64 input_index, auto& input_node_desc = input_node->GetTensorDesc(); auto input_dims = input_node_desc.GetDimensions(); auto output_dims = output_desc->dimensions; - std::vector input_nchw_dims = input_dims, - output_nchw_dims = output_dims; + std::vector input_nchw_dims = input_dims; + std::vector output_nchw_dims = output_dims; DML_TENSOR_DESC* input_tensor_desc = input_node_desc.Get(); TensorDesc nhwc_input_tensor; @@ -736,7 +1074,8 @@ void GraphDMLImpl::AddPool2d(UINT64 input_index, TensorDesc output_tensor(input_node_desc.GetDataType(), output_nchw_dims); Node operator_node; - if (type == Pool2dType::kAveragePool2d) { + + if (operator_type == OperatorType::kAveragePool2d) { if (dilations[0] != 1 || dilations[1] != 1) { DAWN_INTERNAL_ERROR("The dilations of average pool2d are not supported."); } @@ -751,7 +1090,8 @@ void GraphDMLImpl::AddPool2d(UINT64 input_index, dml_desc.IncludePadding = false; operator_node = graph_desc_builder_->CreateOperatorNode( DML_OPERATOR_AVERAGE_POOLING, &dml_desc); - } else if (type == Pool2dType::kL2Pool2d) { + + } else if (operator_type == OperatorType::kL2Pool2d) { if (dilations[0] != 1 || dilations[1] != 1) { DAWN_INTERNAL_ERROR("The dilations of L2 pool2d are not supported."); } @@ -767,7 +1107,8 @@ void GraphDMLImpl::AddPool2d(UINT64 input_index, dml_desc.P = 2; operator_node = graph_desc_builder_->CreateOperatorNode( DML_OPERATOR_LP_POOLING, &dml_desc); - } else if (type == Pool2dType::kMaxPool2d) { + + } else if (operator_type == OperatorType::kMaxPool2d) { if (dilations[0] != 1 || dilations[1] != 1) { for (size_t i = 0; i < windowSizes.size(); ++i) { uint32_t paddedInputSize = @@ -793,8 +1134,9 @@ void GraphDMLImpl::AddPool2d(UINT64 input_index, desc.Dilations = dilations.data(); operator_node = graph_desc_builder_->CreateOperatorNode( DML_OPERATOR_MAX_POOLING2, &desc); + } else { - DAWN_INTERNAL_ERROR("This pool2d type is not supported."); + DAWN_INTERNAL_ERROR("This pool2d operator type is not supported."); } graph_desc_builder_->Connect({input_node}, operator_node); auto output_node = graph_desc_builder_->CreateNodeOutput( @@ -813,20 +1155,8 @@ void GraphDMLImpl::AddPool2d(UINT64 input_index, void GraphDMLImpl::AddRelu(UINT64 input_index, OperandDescriptorPtr output_desc, UINT64 output_index) { - // TODO: return directly 
if BuildResult has error message. - DCHECK(node_output_map_.find(input_index) != node_output_map_.end()); - - auto* input_node = node_output_map_[input_index].get(); - auto& input_tensor_desc = input_node->GetTensorDesc(); - Node node; - CREATE_UNARY_OPERATOR(ACTIVATION_RELU, input_tensor_desc.Get(), node); - graph_desc_builder_->Connect({input_node}, {node}); - TensorDesc output_tensor_desc(input_tensor_desc.GetDataType(), - input_tensor_desc.GetDimensions()); - auto node_output = graph_desc_builder_->CreateNodeOutput( - node, 0, std::move(output_tensor_desc)); - node_output_map_[output_index] = std::move(node_output); - return; + return AddElementWiseUnary(OperatorType::kRelu, input_index, + std::move(output_desc), output_index); } void GraphDMLImpl::AddSoftmax(UINT64 input_index, @@ -838,205 +1168,785 @@ void GraphDMLImpl::AddSoftmax(UINT64 input_index, auto* input_node = node_output_map_[input_index].get(); auto& input_tensor_desc = input_node->GetTensorDesc(); Node node; - CREATE_UNARY_OPERATOR(ACTIVATION_SOFTMAX, input_tensor_desc.Get(), node); + CREATE_UNARY_OPERATOR(ACTIVATION_SOFTMAX, input_tensor_desc.Get(), input_tensor_desc.Get(), node); graph_desc_builder_->Connect({input_node}, {node}); - TensorDesc output_tensor_desc(input_tensor_desc.GetDataType(), - input_tensor_desc.GetDimensions()); + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), + output_desc->dimensions); auto node_output = graph_desc_builder_->CreateNodeOutput( node, 0, std::move(output_tensor_desc)); node_output_map_[output_index] = std::move(node_output); return; } -void GraphDMLImpl::AddOutput(const std::string& name, UINT64 index) { - DCHECK(node_output_map_.find(index) != node_output_map_.end()); - auto* output_node = node_output_map_[index].get(); - DCHECK(output_node != nullptr); - - // Append identity to avoid directly using graph input as output, and - // avoid lack of considering the impacts of strides if there are. - auto node = output_node->GetNode(); - if (node.type == NodeType::kInput || node.type == NodeType::kConstant || - output_node->GetTensorDesc().GetStrides()) { - auto& input_tensor = output_node->GetTensorDesc(); +void GraphDMLImpl::AddElementWiseIf(UINT64 condition_index, + UINT64 true_value_index, + UINT64 false_value_index, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(condition_index)); + DCHECK(node_output_map_.contains(true_value_index)); + DCHECK(node_output_map_.contains(false_value_index)); + + NodeOutput* condition_node = node_output_map_[condition_index].get(); + NodeOutput* true_value_node = node_output_map_[true_value_index].get(); + NodeOutput* false_value_node = node_output_map_[false_value_index].get(); + + // Broadcast each of the inputs to the output. 
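// Illustrative sketch (not part of this change): elementwise "if" (select)
// computes out[i] = condition[i] ? a[i] : b[i] once all three inputs have been
// broadcast to the output shape, which is what the code below sets up.

#include <cstddef>
#include <cstdint>

void SelectReference(const uint8_t* condition, const float* true_values,
                     const float* false_values, float* out, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    out[i] = condition[i] != 0 ? true_values[i] : false_values[i];
  }
}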
+ auto output_dimensions = output_desc->dimensions; + TensorDesc condition_broadcasted_tensor = + GetBroadcastedTensorDesc(condition_node, output_dimensions); + TensorDesc true_value_broadcasted_tensor = + GetBroadcastedTensorDesc(true_value_node, output_dimensions); + TensorDesc false_value_broadcasted_tensor = + GetBroadcastedTensorDesc(false_value_node, output_dimensions); + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), + output_dimensions); + + DML_ELEMENT_WISE_IF_OPERATOR_DESC operator_desc = {}; + operator_desc.ConditionTensor = condition_broadcasted_tensor.Get(); + operator_desc.ATensor = true_value_broadcasted_tensor.Get(); + operator_desc.BTensor = false_value_broadcasted_tensor.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_ELEMENT_WISE_IF, &operator_desc); + + graph_desc_builder_->Connect( + {condition_node, true_value_node, false_value_node}, {node}); - TensorDesc output_tensor(input_tensor.GetDataType(), - input_tensor.GetDimensions()); - APPEND_IDENTITY(input_tensor.Get(), output_tensor.Get(), node); - graph_desc_builder_->Connect({output_node}, {node}); - std::unique_ptr identity_output_node = - graph_desc_builder_->CreateNodeOutput(node, 0, - std::move(output_tensor)); - graph_desc_builder_->AddOutputEdge(identity_output_node.get(), name); - } else { - graph_desc_builder_->AddOutputEdge(output_node, name); - } - return; + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); } -void GraphDMLImpl::Build(ModelInfoPtr model_info, BuildCallback callback) { - // Add Input - for (auto& input : model_info->inputs) { - auto& operand_desc = model_info->operands[input->index]; - AddInput(std::move(input->name), std::move(operand_desc), input->index); +void GraphDMLImpl::AddArgMinMax(OperatorType operator_type, + UINT64 input_index, + uint32_t axis, + bool select_last_index, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + NodeOutput* input_node = node_output_map_[input_index].get(); + auto& input_tensor_desc = input_node->GetTensorDesc(); + auto& input_dimensions = input_tensor_desc.GetDimensions(); + DCHECK(node_output_map_.contains(input_index)); + + // Determine output sizes. Ignore output_desc->dimensions for the dimensions, + // since DirectML expects the output dimensions to have the same rank as the + // input, and output_desc->dimensions may have removed dimensions if + // keepDimensions was false. + std::vector output_dimensions = input_dimensions; + DCHECK(axis < output_dimensions.size()); + output_dimensions[axis] = 1u; + auto output_data_type = GetTensorDataType(output_desc->data_type); + TensorDesc output_tensor_desc(output_data_type, output_dimensions); + TensorDesc original_output_tensor_desc(output_data_type, + output_desc->dimensions); + + // DML accepts multiple axes. So pass the single index along. + std::array axes = {axis}; + + // Note DML_ARGMIN_OPERATOR_DESC and DML_ARGMAX_OPERATOR_DESC are + // identical in structure layout. So we can use for either. 
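// Illustrative sketch (not part of this change): argMin/argMax reduce one axis
// to size 1, so the DML output keeps the input rank; dropping the reduced axis
// when keepDimensions is false only changes the reported shape. A 2-D argMax
// reference (first index wins on ties; select_last_index would use >=):

#include <cstdint>
#include <vector>

std::vector<uint32_t> ArgMax2DReference(const std::vector<float>& data,
                                        uint32_t rows, uint32_t cols,
                                        uint32_t axis) {
  std::vector<uint32_t> result(axis == 0 ? cols : rows, 0u);
  for (uint32_t r = 0; r < rows; ++r) {
    for (uint32_t c = 0; c < cols; ++c) {
      uint32_t out_i = axis == 0 ? c : r;
      uint32_t best = result[out_i];
      float best_value = axis == 0 ? data[best * cols + c] : data[r * cols + best];
      if (data[r * cols + c] > best_value) {
        result[out_i] = axis == 0 ? r : c;
      }
    }
  }
  return result;
}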
+ DML_ARGMIN_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.AxisCount = 1u; + operator_desc.Axes = axes.data(); + operator_desc.AxisDirection = + static_cast(select_last_index); + + DML_OPERATOR_TYPE dml_operator_type = DML_OPERATOR_INVALID; + switch (operator_type) { + case OperatorType::kArgMin: + dml_operator_type = DML_OPERATOR_ARGMIN; + break; + case OperatorType::kArgMax: + dml_operator_type = DML_OPERATOR_ARGMAX; + break; + default: + NOTREACHED(); } + Node node = graph_desc_builder_->CreateOperatorNode(dml_operator_type, + &operator_desc); - // Add Constant - std::unique_ptr uploader = - std::make_unique(execution_context_.get()); - ComPtr constants_resource = nullptr; - auto constants_info = std::move(model_info->constants); - if (constants_info.get() != nullptr) { - for (auto& [index, _] : constants_info->memory_info) { - auto& operand_desc = model_info->operands[index]; - AddConstant(std::move(operand_desc), index); - } - // Upload the data to GPU so that the constant data are not saved as member - // variable. - base::ReadOnlySharedMemoryRegion& shared_memory_region = - constants_info->shared_memory; - size_t constants_byte_length = shared_memory_region.GetSize(); - ExecutionResources* execution_resources = - execution_context_->GetExecutionResources(); - constants_resource = execution_resources->Allocate(constants_byte_length); - uploader->UploadConstants(constants_resource->GetResource(), - constants_info); + graph_desc_builder_->Connect({input_node}, {node}); + + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(original_output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddCast(UINT64 input_index, + OperandType data_type, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(input_index)); + + NodeOutput* input_node = node_output_map_[input_index].get(); + auto& input_tensor_desc = input_node->GetTensorDesc(); + auto& output_dimensions = output_desc->dimensions; + auto output_data_type = GetTensorDataType(data_type); + TensorDesc output_tensor_desc(output_data_type, output_dimensions); + + DML_CAST_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_CAST, &operator_desc); + + graph_desc_builder_->Connect({input_node}, {node}); + + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddConcat(base::span input_indices, + uint32_t axis, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + + std::vector input_nodes; + std::vector input_tensor_descs; + input_nodes.reserve(input_indices.size()); + input_tensor_descs.reserve(input_indices.size()); + + for (uint64_t input_index : input_indices) + { + DCHECK(node_output_map_.contains(input_index)); + NodeOutput* input_node = node_output_map_[input_index].get(); + input_nodes.push_back(input_node); + auto& input_tensor_desc = input_node->GetTensorDesc(); + input_tensor_descs.push_back(*input_tensor_desc.Get()); } - // Add operations - for (auto& operation : model_info->operations) { - switch (operation->which()) { - case OperationInfo::Tag::kClamp: { - auto& clamp = operation->get_clamp(); - 
AddClamp(clamp->input_index, std::move(clamp->options), - clamp->output_index); - break; - } - case OperationInfo::Tag::kConv2d: { - auto& conv2d = operation->get_conv2d(); - auto& output_operand = model_info->operands[conv2d->output_index]; - AddConv2d(conv2d->input_index, conv2d->filter_index, - std::move(conv2d->options), std::move(output_operand), - conv2d->output_index); - break; - } - case OperationInfo::Tag::kElementWiseBinary: { - auto& binary = operation->get_element_wise_binary(); - auto& output_operand = model_info->operands[binary->output_index]; - AddElementWiseBinary(binary->a_index, binary->b_index, binary->type, - std::move(output_operand), binary->output_index); - break; - } - case OperationInfo::Tag::kGemm: { - auto& gemm = operation->get_gemm(); - auto& output_operand = model_info->operands[gemm->output_index]; - AddGemm(gemm->a_index, gemm->b_index, std::move(gemm->options), - std::move(output_operand), gemm->output_index); - break; - } - case OperationInfo::Tag::kPool2d: { - auto& pool2d = operation->get_pool2d(); - auto& output_operand = model_info->operands[pool2d->output_index]; - AddPool2d(pool2d->input_index, std::move(pool2d->options), pool2d->type, - std::move(output_operand), pool2d->output_index); - break; - } - case OperationInfo::Tag::kRelu: { - auto& relu = operation->get_relu(); - auto& output_operand = model_info->operands[relu->output_index]; - AddRelu(relu->input_index, std::move(output_operand), - relu->output_index); - break; - } - case OperationInfo::Tag::kReshape: { - auto& reshape = operation->get_reshape(); - auto& output_operand = model_info->operands[reshape->output_index]; - AddReshape(reshape->input_index, std::move(output_operand), - reshape->output_index); - break; - } - case OperationInfo::Tag::kSoftmax: { - auto& softmax = operation->get_softmax(); - auto& output_operand = model_info->operands[softmax->output_index]; - AddSoftmax(softmax->input_index, std::move(output_operand), - softmax->output_index); - break; - } - default: - NOTREACHED(); - } + auto& output_dimensions = output_desc->dimensions; + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dimensions); + + DML_JOIN_OPERATOR_DESC operator_desc = {}; + operator_desc.InputCount = static_cast(input_tensor_descs.size()); + operator_desc.InputTensors = input_tensor_descs.data(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.Axis = axis; + + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_JOIN, &operator_desc); + + graph_desc_builder_->Connect(input_nodes, {node}); + + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddSlice(UINT64 input_index, + base::span starts, + base::span sizes, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(input_index)); + + NodeOutput* input_node = node_output_map_[input_index].get(); + auto& input_tensor_desc = input_node->GetTensorDesc(); + auto& output_dimensions = output_desc->dimensions; + size_t output_rank = output_dimensions.size(); + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dimensions); + + // WebNN v1 only supports steps of 1. 
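// Illustrative sketch (not part of this change): with the fixed step of 1 the
// slice simply copies a contiguous window per dimension, i.e. in 1-D:

#include <cstdint>

void Slice1DReference(const float* input, float* output, uint32_t start,
                      uint32_t size) {
  for (uint32_t i = 0; i < size; ++i) {
    output[i] = input[start + i];  // step (stride) is always 1 in WebNN v1
  }
}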
+ std::vector strides(output_rank, 1u); + CHECK(starts.size() == output_rank); + CHECK(sizes.size() == output_rank); + + DML_SLICE_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.DimensionCount = output_rank; + operator_desc.Offsets = starts.data(); + operator_desc.Sizes = sizes.data(); + operator_desc.Strides = strides.data(); + + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_SLICE, &operator_desc); + + graph_desc_builder_->Connect({input_node}, {node}); + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddSplit(UINT64 input_index, + uint32_t axis, + base::flat_map& operands, + base::span output_indices + ) { + std::vector output_nodes; + std::vector output_tensor_descs; + std::vector output_tensor_dml_descs; + output_nodes.reserve(output_indices.size()); + output_tensor_descs.reserve(output_indices.size()); + output_tensor_dml_descs.reserve(output_indices.size()); + + DCHECK(node_output_map_.contains(input_index)); + NodeOutput* input_node = node_output_map_[input_index].get(); + TensorDesc& input_tensor_desc = input_node->GetTensorDesc(); + + for (uint64_t output_index : output_indices) + { + OperandDescriptorPtr const& output_desc = operands[output_index]; + auto& output_dimensions = output_desc->dimensions; + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dimensions); + output_tensor_descs.push_back(std::move(output_tensor_desc)); + output_tensor_dml_descs.push_back(*output_tensor_descs.back().Get()); } - // Add Output with named operands. - for (auto& output : model_info->outputs) { - AddOutput(std::move(output->name), output->index); + DML_SPLIT_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputCount = static_cast(output_tensor_dml_descs.size()); + operator_desc.OutputTensors = output_tensor_dml_descs.data(); + operator_desc.Axis = axis; + + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_SPLIT, &operator_desc); + + graph_desc_builder_->Connect({input_node}, node); + + for (uint64_t i = 0, output_count = output_indices.size(); i < output_count; ++i) + { + std::unique_ptr node_output = graph_desc_builder_->CreateNodeOutput( + node, i, std::move(output_tensor_descs[i])); + uint64_t output_index = output_indices[i]; + node_output_map_[output_index] = std::move(node_output); } +} - // Finish the graph build. 
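// Illustrative sketch (not part of this change): the split outputs partition
// the input along the split axis; every other dimension matches and the sizes
// along the axis sum to the input size. A hypothetical validity check:

#include <cstddef>
#include <cstdint>
#include <vector>

bool IsValidSplit(const std::vector<uint32_t>& input_dims,
                  const std::vector<std::vector<uint32_t>>& output_dims,
                  uint32_t axis) {
  uint32_t axis_total = 0;
  for (const auto& dims : output_dims) {
    if (dims.size() != input_dims.size())
      return false;
    for (size_t d = 0; d < dims.size(); ++d) {
      if (d != axis && dims[d] != input_dims[d])
        return false;
    }
    axis_total += dims[axis];
  }
  return axis_total == input_dims[axis];
}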
- mCompiledOperator = graph_desc_builder_->Compile(DML_EXECUTION_FLAG_NONE); +void GraphDMLImpl::AddExpand(UINT64 input_index, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(input_index)); - auto input_nodes = graph_desc_builder_->GetInputNodes(); - std::vector input_buffer_binding(input_nodes.size()); - for (size_t i = 0; i < input_nodes.size(); ++i) { - auto input = input_nodes[i]; - if (input.type == NodeType::kConstant) { - input_buffer_binding[i].Buffer = constants_resource->GetResource(); - auto& memory_info = constants_info->memory_info[input.object_id]; - input_buffer_binding[i].Offset = memory_info->byte_offset; - input_buffer_binding[i].SizeInBytes = memory_info->byte_length; - } + NodeOutput* input_node = node_output_map_[input_index].get(); + auto& output_dimensions = output_desc->dimensions; + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dimensions); + TensorDesc input_tensor_desc = + GetBroadcastedTensorDesc(input_node, output_dimensions); + + Node node; + CREATE_UNARY_OPERATOR(ELEMENT_WISE_IDENTITY, input_tensor_desc.Get(), output_tensor_desc.Get(), node); + + graph_desc_builder_->Connect({input_node}, {node}); + + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddGather(UINT64 input_index, + UINT64 indices_index, + uint32_t axis, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(input_index)); + + NodeOutput* input_node = node_output_map_[input_index].get(); + auto& input_tensor_desc = input_node->GetTensorDesc(); + NodeOutput* indices_node = node_output_map_[indices_index].get(); + auto& indices_tensor_desc = indices_node->GetTensorDesc(); + auto& output_dimensions = output_desc->dimensions; + + size_t maximum_rank = std::max({input_tensor_desc.GetDimensions().size(), + indices_tensor_desc.GetDimensions().size(), + output_dimensions.size()}); + + // Expand all tensor ranks to match ranks (which DML validation requires). 
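// Illustrative sketch (not part of this change): trailing alignment keeps the
// existing dimensions right-most and prepends 1's until the requested rank is
// reached, e.g. {4, 5} expanded to rank 4 becomes {1, 1, 4, 5}:

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<uint32_t> ExpandRankWithLeadingOnes(std::vector<uint32_t> dims,
                                                size_t minimum_rank) {
  if (dims.size() < minimum_rank) {
    dims.insert(dims.begin(), minimum_rank - dims.size(), 1u);
  }
  return dims;
}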
+ TensorDesc input_expanded_desc(input_tensor_desc); + TensorDesc indices_expanded_desc(indices_tensor_desc); + TensorDesc output_expanded_desc(GetTensorDataType(output_desc->data_type), output_dimensions); + TensorDesc original_output_desc(GetTensorDataType(output_desc->data_type), output_dimensions); + input_expanded_desc.EnsureMinimumRank(maximum_rank, TensorDesc::Alignment::kTrailing); + indices_expanded_desc.EnsureMinimumRank(maximum_rank, TensorDesc::Alignment::kTrailing); + output_expanded_desc.EnsureMinimumRank(maximum_rank, TensorDesc::Alignment::kTrailing); + + DML_GATHER_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_expanded_desc.Get(); + operator_desc.IndicesTensor = indices_expanded_desc.Get(); + operator_desc.OutputTensor = output_expanded_desc.Get(); + operator_desc.IndexDimensions = indices_tensor_desc.GetDimensions().size(); + operator_desc.Axis = axis; + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_GATHER, &operator_desc); + + graph_desc_builder_->Connect({input_node, indices_node}, {node}); + + auto node_output = + graph_desc_builder_->CreateNodeOutput(node, 0, std::move(original_output_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddInstanceNormalization(uint64_t input_index, + uint64_t scale_index, + uint64_t bias_index, + float epsilon, + InputOperandLayout operand_layout, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + NodeOutput* input_node = node_output_map_[input_index].get(); + TensorDesc input_tensor_desc = input_node->GetTensorDesc(); + TensorDesc scale_tensor_desc; + TensorDesc bias_tensor_desc; + auto& output_dimensions = output_desc->dimensions; + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dimensions); + TensorDesc original_output_desc(input_tensor_desc.GetDataType(), output_dimensions); + + // DirectML expects NCHW. So permute the dimension's sizes and strides accordingly + // if the input is anything else (e.g. NHWC). + std::array tensor_dimensions_permutation = getLayoutPermutationToNchw(operand_layout); + input_tensor_desc.PermuteDimensions(tensor_dimensions_permutation, TensorDesc::Alignment::kTrailing); + output_tensor_desc.PermuteDimensions(tensor_dimensions_permutation, TensorDesc::Alignment::kTrailing); + + // DirectML expects the channel dimension to be at NCHW. + // So move the C dimension in 1D from XXXC to XCXX. 
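// Illustrative sketch (not part of this change): a 1-D per-channel tensor of
// shape {C} participates in an NCHW operator as {1, C, 1, 1}, so the three
// size-1 dimensions broadcast over N, H and W (the conv2d bias above is
// reshaped the same way).

#include <cstdint>
#include <vector>

std::vector<uint32_t> PerChannelDimsForNchw(uint32_t channel_count) {
  return {1u, channel_count, 1u, 1u};
}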
+ constexpr std::array scale_bias_dimensions_permutation = {0,3,0,0}; + + std::vector input_nodes; + input_nodes.reserve(3); + input_nodes.push_back(input_node); + + if (scale_index != std::numeric_limits::max()) { + NodeOutput* scale_node = node_output_map_[scale_index].get(); + scale_tensor_desc = scale_node->GetTensorDesc(); + scale_tensor_desc.PermuteDimensions(scale_bias_dimensions_permutation, TensorDesc::Alignment::kTrailing); + input_nodes.push_back(scale_node); + } + if (bias_index != std::numeric_limits::max()) { + NodeOutput* bias_node = node_output_map_[bias_index].get(); + bias_tensor_desc = bias_node->GetTensorDesc(); + bias_tensor_desc.PermuteDimensions(scale_bias_dimensions_permutation, TensorDesc::Alignment::kTrailing); + input_nodes.push_back(bias_node); } - DML_BUFFER_ARRAY_BINDING input_buffer_array_binding = {}; - input_buffer_array_binding.BindingCount = input_buffer_binding.size(); - input_buffer_array_binding.Bindings = input_buffer_binding.data(); - DML_BINDING_DESC input_binding_desc{DML_BINDING_TYPE_BUFFER_ARRAY, - &input_buffer_array_binding}; + std::array axes = {2, 3}; + DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.ScaleTensor = scale_tensor_desc.Get(); + operator_desc.BiasTensor = bias_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.AxisCount = static_cast(axes.size()); + operator_desc.Axes = axes.data(); + operator_desc.NormalizeVariance = true; + operator_desc.Epsilon = epsilon; + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operator_desc); - execution_context_->InitializeGraph(this, mCompiledOperator.Get(), - input_binding_desc); + graph_desc_builder_->Connect(input_nodes, {node}); - execution_context_->Flush(); - execution_context_->WaitForSignal(); - execution_context_->ReleaseCompletedResources(); + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(original_output_desc)); + node_output_map_[output_index] = std::move(node_output); +} - auto& named_outputs = graph_desc_builder_->GetNamedOutputs(); - HRESULT hr = output_resource_readback_->InitializeResource(named_outputs); - if (FAILED(hr)) { - std::move(callback).Run(BuildResult::kUnknownError); - return; +void GraphDMLImpl::AddMeanVarianceNormalization(uint64_t input_index, + uint64_t mean_index, + uint64_t variance_index, + uint64_t scale_index, + uint64_t bias_index, + float epsilon, + base::span axes, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + NodeOutput* input_node = node_output_map_[input_index].get(); + TensorDesc input_tensor_desc = input_node->GetTensorDesc(); + TensorDesc mean_tensor_desc; + TensorDesc variance_tensor_desc; + TensorDesc scale_tensor_desc; + TensorDesc bias_tensor_desc; + auto& output_dimensions = output_desc->dimensions; + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dimensions); + + std::vector input_nodes; + input_nodes.reserve(5); + input_nodes.push_back(input_node); + + bool has_scale = scale_index != std::numeric_limits::max(); + bool has_bias = bias_index != std::numeric_limits::max(); + bool has_mean = mean_index != std::numeric_limits::max(); + bool has_variance = variance_index != std::numeric_limits::max(); + + if (has_scale) { + NodeOutput* scale_node = node_output_map_[scale_index].get(); + scale_tensor_desc = scale_node->GetTensorDesc(); + input_nodes.push_back(scale_node); + } + if (has_bias) { 
+ NodeOutput* bias_node = node_output_map_[bias_index].get(); + bias_tensor_desc = bias_node->GetTensorDesc(); + input_nodes.push_back(bias_node); + } + if (has_mean) { + NodeOutput* mean_node = node_output_map_[mean_index].get(); + mean_tensor_desc = mean_node->GetTensorDesc(); + input_nodes.push_back(mean_node); + } + if (has_variance) { + NodeOutput* variance_node = node_output_map_[variance_index].get(); + variance_tensor_desc = variance_node->GetTensorDesc(); + input_nodes.push_back(variance_node); } - std::move(callback).Run(BuildResult::kOk); + Node node; + if (has_mean || has_variance) { + // The precomputed mean and variance are already supplied. + // So there's no need to specify axes or whether to normalize variance. + DML_BATCH_NORMALIZATION_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.MeanTensor = mean_tensor_desc.Get(); + operator_desc.VarianceTensor = variance_tensor_desc.Get(); + operator_desc.ScaleTensor = scale_tensor_desc.Get(); + operator_desc.BiasTensor = bias_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.Spatial = true; + operator_desc.Epsilon = epsilon; + node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_BATCH_NORMALIZATION, &operator_desc); + + } else { + // Compute the mean and variance from the input. + DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.ScaleTensor = scale_tensor_desc.Get(); + operator_desc.BiasTensor = bias_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.AxisCount = static_cast(axes.size()); + operator_desc.Axes = axes.data(); + operator_desc.NormalizeVariance = true; + operator_desc.Epsilon = epsilon; + node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &operator_desc); + } + + graph_desc_builder_->Connect(input_nodes, {node}); + + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddPad(UINT64 input_index, + PaddingMode padding_mode, + float value, + base::span beginning_padding, + base::span ending_padding, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(input_index)); + DCHECK(beginning_padding.size() == ending_padding.size()); + + NodeOutput* input_node = node_output_map_[input_index].get(); + auto& input_tensor_desc = input_node->GetTensorDesc(); + auto& output_dimensions = output_desc->dimensions; + + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), + output_dimensions); + + DML_PADDING_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.PaddingMode = MapPaddingModeToDml(padding_mode); + operator_desc.PaddingValue = value; + operator_desc.DimensionCount = + static_cast(beginning_padding.size()); + operator_desc.StartPadding = beginning_padding.data(); + operator_desc.EndPadding = ending_padding.data(); + Node node = graph_desc_builder_->CreateOperatorNode(DML_OPERATOR_PADDING, + &operator_desc); + + graph_desc_builder_->Connect({input_node}, {node}); + + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void 
GraphDMLImpl::AddFillSequence(float start, + float step, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + auto& output_dimensions = output_desc->dimensions; + DML_TENSOR_DATA_TYPE dml_data_type = + GetTensorDataType(output_desc->data_type); + TensorDesc output_tensor_desc(dml_data_type, output_dimensions); + + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC operator_desc = {}; + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.ValueDataType = dml_data_type; + operator_desc.ValueStart = GetScalarUnion(operator_desc.ValueDataType, start); + operator_desc.ValueDelta = GetScalarUnion(operator_desc.ValueDataType, step); + + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_FILL_VALUE_SEQUENCE, &operator_desc); + + graph_desc_builder_->Connect({}, {node}); + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddReduce(UINT64 input_index, + OperatorType operator_type, + base::span axes, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(input_index)); + + NodeOutput* input_node = node_output_map_[input_index].get(); + auto& input_tensor_desc = input_node->GetTensorDesc(); + auto& input_dimensions = input_tensor_desc.GetDimensions(); + auto& original_output_dimensions = output_desc->dimensions; + + // Determine output sizes. Ignore output_desc->dimensions for the dimensions, + // since DirectML expects the output dimensions to have the same rank as the + // input, and output_desc->dimensions may have removed dimensions if + // keepDimensions was false. + std::vector output_dimensions = input_dimensions; + for (uint32_t axis : axes) + { + DCHECK(axis < output_dimensions.size()); + output_dimensions[axis] = 1u; + } + + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), + output_dimensions); + TensorDesc original_output_tensor_desc( + GetTensorDataType(output_desc->data_type), original_output_dimensions); + + DML_REDUCE_OPERATOR_DESC operator_desc = {}; + operator_desc.Function = MapOperatorTypeToReductionFuntion(operator_type); + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.AxisCount = static_cast(axes.size()); + operator_desc.Axes = axes.data(); + + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_REDUCE, &operator_desc); + + graph_desc_builder_->Connect({input_node}, {node}); + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(original_output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddResample2d(UINT64 input_index, + ml::webnn::mojom::InterpolationMode interpolation_mode, + base::span scales, + base::span axes, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(input_index)); + DCHECK(scales.size() == axes.size()); + + NodeOutput* input_node = node_output_map_[input_index].get(); + auto& input_tensor_desc = input_node->GetTensorDesc(); + auto& output_dimensions = output_desc->dimensions; + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dimensions); + + std::vector full_scales(output_dimensions.size(), 1u); + for (size_t i = 0; i < axes.size(); ++i) + { + auto axis = axes[i]; + DCHECK(axis < full_scales.size()); // The JS layer and mojom layer validated it. 
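// Illustrative worked example (not from this patch): for a rank-4 input with
// axes = {2, 3} and scales = {2.0f, 2.0f}, the surrounding loop expands the
// per-axis scales into a per-dimension array, leaving unlisted dimensions at
// the default scale of 1 (assuming float scales, since
// DML_RESAMPLE_OPERATOR_DESC::Scales expects a FLOAT array):
//
//   std::vector<float> full_scales(4, 1.0f);
//   // after the loop: full_scales == {1.0f, 1.0f, 2.0f, 2.0f}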
+ full_scales[axis] = scales[i]; + } + + static_assert(uint32_t(ml::webnn::mojom::InterpolationMode::kMaxValue) == + 1); // Update assert. + static_assert( + uint32_t(DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR) == + uint32_t(ml::webnn::mojom::InterpolationMode::kNearestNeighbor)); + static_assert(uint32_t(DML_INTERPOLATION_MODE_LINEAR) == + uint32_t(ml::webnn::mojom::InterpolationMode::kLinear)); + + DML_RESAMPLE_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.InterpolationMode = static_cast(interpolation_mode); + operator_desc.ScaleCount = static_cast(full_scales.size()); + operator_desc.Scales = full_scales.data(); + + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_RESAMPLE, &operator_desc); + + graph_desc_builder_->Connect({input_node}, {node}); + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddTranspose(UINT64 input_index, + base::span permutation, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(input_index)); + + NodeOutput* input_node = node_output_map_[input_index].get(); + auto& input_tensor_desc = input_node->GetTensorDesc(); + auto& input_dimensions = input_tensor_desc.GetDimensions(); + auto& output_dimensions = output_desc->dimensions; + DCHECK(input_dimensions.size() == output_dimensions.size()); + input_tensor_desc.EnsureStridesExist(); + + auto rearranged_input_strides = + transposeStrides(*input_tensor_desc.GetStrides(), permutation); + + // Construct a new input tensor description based on the outputs's rearranged + // dimensions, which is identical except for the remapped strides. Then both + // input and output have the same sizes, just different memory mappings when + // reading elements. 
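// Illustrative sketch (not from this patch): the `transposeStrides` helper used
// above is not shown in this diff. A stride remapping consistent with the
// comment above could look like the following, where element i of the result
// is the stride of the input dimension that moves to output position i:
//
//   std::vector<uint32_t> TransposeStridesSketch(
//       const std::vector<uint32_t>& strides,
//       base::span<const uint32_t> permutation) {
//     std::vector<uint32_t> result(permutation.size());
//     for (size_t i = 0; i < permutation.size(); ++i)
//       result[i] = strides[permutation[i]];
//     return result;
//   }
//
//   // e.g. input dimensions {2,3,4} (strides {12,4,1}) with permutation
//   // {2,0,1} give output dimensions {4,2,3} and remapped strides {1,12,4}.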
+ TensorDesc remapped_input_tensor_desc( + input_tensor_desc.GetDataType(), input_tensor_desc.GetFlags(), + output_dimensions, rearranged_input_strides); + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), + output_dimensions); + + Node node; + CREATE_UNARY_OPERATOR(ELEMENT_WISE_IDENTITY, remapped_input_tensor_desc.Get(), + output_tensor_desc.Get(), node); + + graph_desc_builder_->Connect({input_node}, {node}); + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddTriangularMatrix(UINT64 input_index, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + // TODO: +} + +void GraphDMLImpl::AddDequantizeLinear(UINT64 input_index, + UINT64 scale_index, + UINT64 zero_point_index, + OperandDescriptorPtr output_desc, + UINT64 output_index) { + DCHECK(node_output_map_.contains(input_index)); + DCHECK(node_output_map_.contains(scale_index)); + DCHECK(node_output_map_.contains(zero_point_index)); + + NodeOutput* input_node = node_output_map_[input_index].get(); + NodeOutput* scale_node = node_output_map_[scale_index].get(); + NodeOutput* zero_point_node = node_output_map_[zero_point_index].get(); + + auto& output_dimensions = output_desc->dimensions; + auto& input_tensor_desc = input_node->GetTensorDesc(); + TensorDesc scale_tensor_desc = GetBroadcastedTensorDesc(scale_node, output_dimensions); + TensorDesc zero_point_tensor_desc = GetBroadcastedTensorDesc(zero_point_node, output_dimensions); + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_dimensions); + + DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.ScaleTensor = scale_tensor_desc.Get(); + operator_desc.ZeroPointTensor = zero_point_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &operator_desc); + + graph_desc_builder_->Connect({input_node, scale_node, zero_point_node}, {node}); + auto node_output = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_index] = std::move(node_output); +} + +void GraphDMLImpl::AddDynamicQuantizeLinear( + UINT64 input_index, + base::flat_map& operands, + //--base::span output_indices + UINT64 output_index, + UINT64 output_scale_index, + UINT64 output_zero_point_index) { + DCHECK(node_output_map_.contains(input_index)); + NodeOutput* input_node = node_output_map_[input_index].get(); + TensorDesc& input_tensor_desc = input_node->GetTensorDesc(); + + OperandDescriptorPtr const& output_desc = operands[output_index]; + OperandDescriptorPtr const& output_scale_desc = operands[output_scale_index]; + OperandDescriptorPtr const& output_zero_point_desc = operands[output_zero_point_index]; + TensorDesc output_tensor_desc(GetTensorDataType(output_desc->data_type), output_desc->dimensions); + TensorDesc output_scale_tensor_desc(GetTensorDataType(output_scale_desc->data_type), output_scale_desc->dimensions); + TensorDesc output_zero_point_tensor_desc(GetTensorDataType(output_zero_point_desc->data_type), output_zero_point_desc->dimensions); + + // Ensure output scalars have the same rank as main output. 
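// Illustrative worked example (not from this patch): if the quantized output
// has rank-4 dimensions, e.g. {1, 3, 4, 5}, while the scale and zero-point
// outputs are scalars with dimensions {1}, the two calls below pad them with
// leading 1's (trailing alignment) so DirectML sees matching ranks:
//
//   // scale / zero-point dimensions: {1} -> {1, 1, 1, 1}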
+ output_scale_tensor_desc.EnsureMinimumRank(output_tensor_desc.GetDimensions().size(), TensorDesc::Alignment::kTrailing); + output_zero_point_tensor_desc.EnsureMinimumRank(output_tensor_desc.GetDimensions().size(), TensorDesc::Alignment::kTrailing); + + DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC operator_desc = {}; + operator_desc.InputTensor = input_tensor_desc.Get(); + operator_desc.OutputTensor = output_tensor_desc.Get(); + operator_desc.OutputScaleTensor = output_scale_tensor_desc.Get(); + operator_desc.OutputZeroPointTensor = output_zero_point_tensor_desc.Get(); + + Node node = graph_desc_builder_->CreateOperatorNode( + DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, &operator_desc); + + graph_desc_builder_->Connect({input_node}, {node}); + node_output_map_[output_index] = graph_desc_builder_->CreateNodeOutput( + node, 0, std::move(output_tensor_desc)); + node_output_map_[output_scale_index] = graph_desc_builder_->CreateNodeOutput( + node, 1, std::move(output_scale_tensor_desc)); + node_output_map_[output_zero_point_index] = graph_desc_builder_->CreateNodeOutput( + node, 2, std::move(output_zero_point_tensor_desc)); +} + +void GraphDMLImpl::AddOutput(const std::string& name, UINT64 index) { + DCHECK(node_output_map_.find(index) != node_output_map_.end()); + auto* output_node = node_output_map_[index].get(); + DCHECK(output_node != nullptr); + + // Append identity to avoid directly using graph input as output, and + // avoid lack of considering the impacts of strides if there are. + auto node = output_node->GetNode(); + if (node.type == NodeType::kInput || node.type == NodeType::kConstant || + output_node->GetTensorDesc().GetStrides()) { + auto& input_tensor = output_node->GetTensorDesc(); + + TensorDesc output_tensor(input_tensor.GetDataType(), + input_tensor.GetDimensions()); + + CREATE_UNARY_OPERATOR(ELEMENT_WISE_IDENTITY, input_tensor.Get(), output_tensor.Get(), node); + + graph_desc_builder_->Connect({output_node}, {node}); + std::unique_ptr identity_output_node = + graph_desc_builder_->CreateNodeOutput(node, 0, + std::move(output_tensor)); + graph_desc_builder_->AddOutputEdge(identity_output_node.get(), name); + } else { + graph_desc_builder_->AddOutputEdge(output_node, name); + } return; } -bool GraphDMLImpl::Build(ModelInfoPtr model_info, BuildResult* out_result) { - // Add Input +void GraphDMLImpl::Build(ModelInfoPtr model_info, BuildCallback callback) { for (auto& input : model_info->inputs) { auto& operand_desc = model_info->operands[input->index]; AddInput(std::move(input->name), std::move(operand_desc), input->index); } // Add Constant - std::unique_ptr uploader = + constant_resource_uploader_ = std::make_unique(execution_context_.get()); - ComPtr constants_resource = nullptr; + ComPtr constant_resource = nullptr; auto constants_info = std::move(model_info->constants); + if (constants_info.get() != nullptr) { for (auto& [index, _] : constants_info->memory_info) { auto& operand_desc = model_info->operands[index]; AddConstant(std::move(operand_desc), index); } - // Upload the data to GPU so that the constant data are not saved as member - // variable. 
base::ReadOnlySharedMemoryRegion& shared_memory_region = constants_info->shared_memory; size_t constants_byte_length = shared_memory_region.GetSize(); ExecutionResources* execution_resources = execution_context_->GetExecutionResources(); - constants_resource = execution_resources->Allocate(constants_byte_length); - uploader->UploadConstants(constants_resource->GetResource(), + constant_resource = execution_resources->Allocate(constants_byte_length); + constant_resource_uploader_->UploadConstants(constant_resource->GetResource(), constants_info); } @@ -1052,29 +1962,49 @@ bool GraphDMLImpl::Build(ModelInfoPtr model_info, BuildResult* out_result) { case OperationInfo::Tag::kConv2d: { auto& conv2d = operation->get_conv2d(); auto& output_operand = model_info->operands[conv2d->output_index]; - AddConv2d(conv2d->input_index, conv2d->filter_index, - std::move(conv2d->options), std::move(output_operand), - conv2d->output_index); + AddConv2d(conv2d->operator_type, conv2d->input_index, + conv2d->filter_index, conv2d->input_zero_point_index, + conv2d->filter_zero_point_index, std::move(conv2d->options), + std::move(output_operand), conv2d->output_index); + break; + } + case OperationInfo::Tag::kElementWiseUnary: { + auto& mojom_operator = operation->get_element_wise_unary(); + auto& output_operand = model_info->operands[mojom_operator->output_index]; + AddElementWiseUnary(mojom_operator->operator_type, mojom_operator->input_index, + std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kElementWiseUnaryTwoParameter: { + auto& mojom_operator = operation->get_element_wise_unary_two_parameter(); + auto& output_operand = model_info->operands[mojom_operator->output_index]; + AddElementWiseUnaryTwoParameter(mojom_operator->operator_type, mojom_operator->input_index, + mojom_operator->first_parameter, mojom_operator->second_parameter, + std::move(output_operand), + mojom_operator->output_index); break; } case OperationInfo::Tag::kElementWiseBinary: { auto& binary = operation->get_element_wise_binary(); auto& output_operand = model_info->operands[binary->output_index]; - AddElementWiseBinary(binary->a_index, binary->b_index, binary->type, + AddElementWiseBinary(binary->operator_type, binary->a_index, binary->b_index, std::move(output_operand), binary->output_index); break; } case OperationInfo::Tag::kGemm: { + // GEMM and MatMul. 
auto& gemm = operation->get_gemm(); auto& output_operand = model_info->operands[gemm->output_index]; - AddGemm(gemm->a_index, gemm->b_index, std::move(gemm->options), + AddGemm(gemm->operator_type, gemm->a_index, gemm->b_index, gemm->a_zero_point_index, + gemm->b_zero_point_index, std::move(gemm->options), std::move(output_operand), gemm->output_index); break; } case OperationInfo::Tag::kPool2d: { auto& pool2d = operation->get_pool2d(); auto& output_operand = model_info->operands[pool2d->output_index]; - AddPool2d(pool2d->input_index, std::move(pool2d->options), pool2d->type, + AddPool2d(pool2d->operator_type, pool2d->input_index, std::move(pool2d->options), std::move(output_operand), pool2d->output_index); break; } @@ -1099,6 +2029,184 @@ bool GraphDMLImpl::Build(ModelInfoPtr model_info, BuildResult* out_result) { softmax->output_index); break; } + case OperationInfo::Tag::kArgMinMax: { + auto& mojom_operator = operation->get_arg_min_max(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddArgMinMax(mojom_operator->operator_type, mojom_operator->input_index, + mojom_operator->axis, + mojom_operator->select_last_index, + std::move(output_operand), mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kCast: { + auto& mojom_operator = operation->get_cast(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddCast(mojom_operator->input_index, mojom_operator->data_type, + std::move(output_operand), mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kConcat: { + auto& mojom_operator = operation->get_concat(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddConcat(mojom_operator->input_indices, mojom_operator->axis, + std::move(output_operand), mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kSlice: { + auto& mojom_operator = operation->get_slice(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddSlice(mojom_operator->input_index, mojom_operator->starts, + mojom_operator->sizes, std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kSplit: { + auto& mojom_operator = operation->get_split(); + AddSplit(mojom_operator->input_index, mojom_operator->axis, + model_info->operands, mojom_operator->output_indices); + break; + } + case OperationInfo::Tag::kExpand: { + auto& mojom_operator = operation->get_expand(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddExpand(mojom_operator->input_index, std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kGather: { + auto& mojom_operator = operation->get_gather(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddGather(mojom_operator->input_index, mojom_operator->indices_index, + mojom_operator->axis, std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kInstanceNormalization: { + auto& mojom_operator = operation->get_instance_normalization(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddInstanceNormalization( + mojom_operator->input_index, mojom_operator->scale_index, + mojom_operator->bias_index, mojom_operator->epsilon, + mojom_operator->layout, std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kMeanVarianceNormalization: { + auto& mojom_operator = 
operation->get_mean_variance_normalization(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddMeanVarianceNormalization( + mojom_operator->input_index, mojom_operator->mean_index, + mojom_operator->variance_index, mojom_operator->scale_index, + mojom_operator->bias_index, mojom_operator->epsilon, + mojom_operator->axes, std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kPad: { + auto& mojom_operator = operation->get_pad(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddPad(mojom_operator->input_index, + mojom_operator->mode, + mojom_operator->value, + mojom_operator->beginningPadding, + mojom_operator->endingPadding, + std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kFillSequence: { + auto& mojom_operator = operation->get_fill_sequence(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddFillSequence(mojom_operator->start, mojom_operator->delta, + std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kReduce: { + auto& mojom_operator = operation->get_reduce(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddReduce(mojom_operator->input_index, mojom_operator->operator_type, + mojom_operator->axes, std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kResample2d: { + auto& mojom_operator = operation->get_resample2d(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddResample2d(mojom_operator->input_index, + mojom_operator->interpolation_mode, + mojom_operator->scales, mojom_operator->axes, + std::move(output_operand), mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kTranspose: { + auto& mojom_operator = operation->get_transpose(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddTranspose(mojom_operator->input_index, mojom_operator->permutation, + std::move(output_operand), mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kTriangularMatrix: + DCHECK(false); + break; + case OperationInfo::Tag::kElementWiseIf: { + auto& mojom_operator = operation->get_element_wise_if(); + auto& output_operand = + model_info->operands[mojom_operator->output_index]; + AddElementWiseIf( + mojom_operator->condition_index, mojom_operator->true_value_index, + mojom_operator->false_value_index, std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kDequantizeLinear: { + auto& mojom_operator = operation->get_dequantize_linear(); + auto& output_operand = model_info->operands[mojom_operator->output_index]; + AddDequantizeLinear(mojom_operator->input_index, + mojom_operator->scale_index, + mojom_operator->zero_point_index, + std::move(output_operand), + mojom_operator->output_index); + break; + } + case OperationInfo::Tag::kDynamicQuantizeLinear: { + auto& mojom_operator = operation->get_dynamic_quantize_linear(); + AddDynamicQuantizeLinear( + mojom_operator->input_index, model_info->operands, + mojom_operator->output_index, mojom_operator->output_scale_index, + mojom_operator->output_zero_point_index); + break; + } + /* + // TODO:::Implement + + case OperationInfo::Tag::kGru: + model_info->AddGru(op); + break; + case OperationInfo::Tag::kGruCell: + model_info->AddGruCell(op); + break; + case OperationInfo::Tag::kLstm: + 
model_info->AddLstm(op);
+        break;
+      case OperationInfo::Tag::kLstmCell:
+        model_info->AddLstmCell(op);
+        break;
+      */
+
      default:
        NOTREACHED();
    }
@@ -1109,15 +2217,37 @@ bool GraphDMLImpl::Build(ModelInfoPtr model_info, BuildResult* out_result) {
    AddOutput(std::move(output->name), output->index);
  }
-  // Finish the graph build.
+  // Post the CompileGraph task to the thread pool rather than running it on
+  // the GPU main thread, to avoid blocking. The OnGraphCompiled reply will run
+  // back on the current GPU main thread.
+  base::ThreadPool::PostTaskAndReply(
+      FROM_HERE,
+      base::BindOnce(&GraphDMLImpl::CompileGraph, base::Unretained(this)),
+      base::BindOnce(&GraphDMLImpl::OnGraphCompiled, base::Unretained(this),
+                     std::move(callback), std::move(constant_resource),
+                     std::move(constants_info)));
+}
+
+// Since IDMLDevice1::CompileGraph, called in this method, may take a long
+// time to compile shaders (if they are not already cached), this method may
+// block the current thread.
+void GraphDMLImpl::CompileGraph() {
  mCompiledOperator = graph_desc_builder_->Compile(DML_EXECUTION_FLAG_NONE);
+}
+
+void GraphDMLImpl::OnGraphCompiled(
+    BuildCallback callback,
+    ComPtr constant_resource,
+    ConstantsInfoPtr constants_info) {
+  if (!mCompiledOperator) {
+    std::move(callback).Run(BuildResult::kUnknownError);
+    return;
+  }
  auto input_nodes = graph_desc_builder_->GetInputNodes();
  std::vector input_buffer_binding(input_nodes.size());
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    auto input = input_nodes[i];
    if (input.type == NodeType::kConstant) {
-      input_buffer_binding[i].Buffer = constants_resource->GetResource();
+      input_buffer_binding[i].Buffer = constant_resource->GetResource();
      auto& memory_info = constants_info->memory_info[input.object_id];
      input_buffer_binding[i].Offset = memory_info->byte_offset;
      input_buffer_binding[i].SizeInBytes = memory_info->byte_length;
@@ -1132,43 +2262,47 @@
  execution_context_->InitializeGraph(this, mCompiledOperator.Get(),
                                      input_binding_desc);
-  execution_context_->Flush();
-  execution_context_->WaitForSignal();
-  execution_context_->ReleaseCompletedResources();
+  execution_context_->WaitForSignal(
+      base::BindOnce(&GraphDMLImpl::OnWaitForBuildSignal,
+                     base::Unretained(this), std::move(callback)));
+}
+
+void GraphDMLImpl::OnWaitForBuildSignal(BuildCallback callback) {
+  execution_context_->ReleaseCompletedResources();
  auto& named_outputs = graph_desc_builder_->GetNamedOutputs();
  HRESULT hr = output_resource_readback_->InitializeResource(named_outputs);
  if (FAILED(hr)) {
-    *out_result = BuildResult::kUnknownError;
-    return false;
+    std::move(callback).Run(BuildResult::kUnknownError);
+    return;
  }
-  *out_result = BuildResult::kOk;
-  return true;
+  std::move(callback).Run(BuildResult::kOk);
}

void GraphDMLImpl::Compute(NamedResourcesPtr named_inputs,
                           ComputeCallback callback) {
+  TRACE_EVENT0("gpu", "GraphDMLImpl::Compute");
  ExecutionResources* execution_resources =
      execution_context_->GetExecutionResources();
-  ID3D12Resource* inputs_resource =
+  ID3D12Resource* input_resource =
      execution_resources->GetResource(this, ResourceType::kInput);
-  if (inputs_resource == nullptr) {
+  if (input_resource == nullptr) {
    base::ReadOnlySharedMemoryRegion& shared_memory_region =
        named_inputs->shared_memory;
    DCHECK(shared_memory_region.IsValid());
    size_t inputs_byte_length = shared_memory_region.GetSize();
-    inputs_resource = execution_resources->Allocate(ResourceType::kInput,
+    input_resource = execution_resources->Allocate(ResourceType::kInput,
inputs_byte_length, this); } - input_resource_uploader_->UploadInputs(inputs_resource, named_inputs); + input_resource_uploader_->UploadInputs(input_resource, named_inputs); auto input_nodes = graph_desc_builder_->GetInputNodes(); std::vector input_buffer_binding(input_nodes.size()); std::vector input_binding_desc(input_nodes.size()); for (size_t i = 0; i < input_nodes.size(); ++i) { auto input = input_nodes[i]; if (input.type == NodeType::kInput) { - input_buffer_binding[i].Buffer = inputs_resource; + input_buffer_binding[i].Buffer = input_resource; auto& memory_info = named_inputs->resources[input.name]; input_buffer_binding[i].Offset = memory_info->byte_offset; input_buffer_binding[i].SizeInBytes = memory_info->byte_length; @@ -1180,119 +2314,88 @@ void GraphDMLImpl::Compute(NamedResourcesPtr named_inputs, ID3D12Resource* outputs_resource = execution_resources->GetResource(this, ResourceType::kOutput); + if (outputs_resource == nullptr) { - size_t outputs_resource_size = - output_resource_readback_->GetOutputsResourceSize(); outputs_resource = execution_resources->Allocate( - ResourceType::kOutput, outputs_resource_size, this); + ResourceType::kOutput, output_resource_readback_->outputs_resource_size(), this); } - auto& output_length_map = graph_desc_builder_->GetNamedOutputs(); - std::vector output_binding_desc(output_length_map.size()); + auto& output_info_map = graph_desc_builder_->GetNamedOutputs(); + std::vector output_binding_desc(output_info_map.size()); // The sort of the outputs from Graph Compute is different from the // outputs from Graph Build, so the offset need to be found the correct output // with name to read back from GPU buffer. base::flat_map output_buffer_binding; + // Reseve the map capacity to avoid reallocation. + output_buffer_binding.reserve(output_info_map.size()); uint64_t aligned_offset = 0; - size_t i = 0; - for (auto& [name, byte_length] : output_length_map) { + for (auto& [name, output_info] : output_info_map) { DML_BUFFER_BINDING buffer_binding; buffer_binding.Buffer = outputs_resource; buffer_binding.Offset = aligned_offset; - buffer_binding.SizeInBytes = byte_length; + buffer_binding.SizeInBytes = output_info.byte_length; output_buffer_binding[name] = buffer_binding; - output_binding_desc[i] = {DML_BINDING_TYPE_BUFFER, - &output_buffer_binding[name]}; - aligned_offset += Align(byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT); - ++i; + output_binding_desc[output_info.index] = {DML_BINDING_TYPE_BUFFER, + &output_buffer_binding[name]}; + aligned_offset += + Align(output_info.byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT); } execution_context_->ExecuteGraph(this, mCompiledOperator.Get(), input_binding_desc, output_binding_desc); - auto named_outputs = ml::webnn::mojom::NamedResources::New(); - HRESULT hr = output_resource_readback_->ReadResourceFromGpu(named_outputs, - outputs_resource); - if (FAILED(hr)) { - std::move(callback).Run(ComputeResult::kUnknownError, nullptr); - return; - } + // Copy buffer from GPU resource to CPU data. 
+ execution_context_->CopyBufferRegion(output_resource_readback_->readback_resource_->GetResource(), + outputs_resource, output_resource_readback_->outputs_resource_size(), + D3D12_RESOURCE_STATE_COPY_SOURCE); - std::move(callback).Run(ComputeResult::kOk, std::move(named_outputs)); + execution_context_->Flush(); + execution_context_->WaitForSignal( + base::BindOnce(&GraphDMLImpl::OnWaitForComputeSignal, + base::Unretained(this), std::move(callback))); } -bool GraphDMLImpl::Compute(NamedResourcesPtr named_inputs, - ComputeResult* out_result, - NamedResourcesPtr* out_named_outputs) { - ExecutionResources* execution_resources = - execution_context_->GetExecutionResources(); - ID3D12Resource* inputs_resource = - execution_resources->GetResource(this, ResourceType::kInput); - if (inputs_resource == nullptr) { - base::ReadOnlySharedMemoryRegion& shared_memory_region = - named_inputs->shared_memory; - DCHECK(shared_memory_region.IsValid()); - size_t inputs_byte_length = shared_memory_region.GetSize(); - inputs_resource = execution_resources->Allocate(ResourceType::kInput, - inputs_byte_length, this); - } - input_resource_uploader_->UploadInputs(inputs_resource, named_inputs); - auto input_nodes = graph_desc_builder_->GetInputNodes(); - std::vector input_buffer_binding(input_nodes.size()); - std::vector input_binding_desc(input_nodes.size()); - for (size_t i = 0; i < input_nodes.size(); ++i) { - auto input = input_nodes[i]; - if (input.type == NodeType::kInput) { - input_buffer_binding[i].Buffer = inputs_resource; - auto& memory_info = named_inputs->resources[input.name]; - input_buffer_binding[i].Offset = memory_info->byte_offset; - input_buffer_binding[i].SizeInBytes = memory_info->byte_length; +void GraphDMLImpl::OnWaitForComputeSignal(ComputeCallback callback) { + named_outputs_ = ml::webnn::mojom::NamedResources::New(); + execution_context_->ReleaseCompletedResources(); - input_binding_desc[i] = {DML_BINDING_TYPE_BUFFER, - &input_buffer_binding[i]}; - } + D3D12_RANGE tensorBufferRange{ + 0, output_resource_readback_->outputs_resource_size()}; + int8_t* readBackBuffer; + ID3D12Resource* readback_resource = + output_resource_readback_->readback_resource_->GetResource(); + HRESULT hr = readback_resource->Map( + 0, &tensorBufferRange, reinterpret_cast(&readBackBuffer)); + if (FAILED(hr)) { + std::move(callback).Run(ComputeResult::kUnknownError, + std::move(named_outputs_)); + return; } - ID3D12Resource* outputs_resource = - execution_resources->GetResource(this, ResourceType::kOutput); - if (outputs_resource == nullptr) { - size_t outputs_resource_size = - output_resource_readback_->GetOutputsResourceSize(); - outputs_resource = execution_resources->Allocate( - ResourceType::kOutput, outputs_resource_size, this); - } - auto& output_length_map = graph_desc_builder_->GetNamedOutputs(); - std::vector output_binding_desc(output_length_map.size()); - // The sort of the outputs from Graph Compute is different from the - // outputs from Graph Build, so the offset need to be found the correct output - // with name to read back from GPU buffer. 
- base::flat_map output_buffer_binding; - uint64_t aligned_offset = 0; - size_t i = 0; - for (auto& [name, byte_length] : output_length_map) { - DML_BUFFER_BINDING buffer_binding; - buffer_binding.Buffer = outputs_resource; - buffer_binding.Offset = aligned_offset; - buffer_binding.SizeInBytes = byte_length; - output_buffer_binding[name] = buffer_binding; - output_binding_desc[i] = {DML_BINDING_TYPE_BUFFER, - &output_buffer_binding[name]}; - aligned_offset += Align(byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT); - ++i; + for (auto& [name, memory_info] : + output_resource_readback_->outputs_info_map_) { + auto mojo_memory_info = ml::webnn::mojom::MemoryInfo::New(); + size_t byte_offset = memory_info.byte_offset; + size_t byte_length = memory_info.byte_length; + mojo_memory_info->byte_offset = byte_offset; + mojo_memory_info->byte_length = byte_length; + named_outputs_->resources[name] = std::move(mojo_memory_info); + + uint8_t* address = output_resource_readback_->outputs_shm_region_.mapping + .GetMemoryAs() + + byte_offset; + memcpy(address, readBackBuffer + byte_offset, byte_length); } + named_outputs_->shared_memory = + output_resource_readback_->outputs_shm_region_.region.Duplicate(); - execution_context_->ExecuteGraph(this, mCompiledOperator.Get(), - input_binding_desc, output_binding_desc); + readback_resource->Unmap(0, nullptr); - *out_named_outputs = ml::webnn::mojom::NamedResources::New(); - HRESULT hr = output_resource_readback_->ReadResourceFromGpu( - *out_named_outputs, outputs_resource); if (FAILED(hr)) { - *out_result = ComputeResult::kUnknownError; - *out_named_outputs = nullptr; - return false; + std::move(callback).Run(ComputeResult::kUnknownError, + std::move(named_outputs_)); + return; } - *out_result = ComputeResult::kOk; - return true; + std::move(callback).Run(ComputeResult::kOk, std::move(named_outputs_)); } } // namespace content::webnn diff --git a/content/browser/ml/webnn/dml/graph_dml_impl.h b/content/browser/ml/webnn/dml/graph_dml_impl.h index 88b172c245490d..6131fc4b22590d 100644 --- a/content/browser/ml/webnn/dml/graph_dml_impl.h +++ b/content/browser/ml/webnn/dml/graph_dml_impl.h @@ -31,13 +31,14 @@ namespace content::webnn { namespace { using Microsoft::WRL::ComPtr; +using ml::webnn::mojom::OperandType; +using ml::webnn::mojom::OperatorType; using ml::webnn::mojom::BuildResult; using ml::webnn::mojom::ClampOptions; using ml::webnn::mojom::ClampOptionsPtr; using ml::webnn::mojom::ComputeResult; using ml::webnn::mojom::ConstantsInfoPtr; using ml::webnn::mojom::Conv2dOptionsPtr; -using ml::webnn::mojom::ElementWiseBinaryType; using ml::webnn::mojom::GemmOptionsPtr; using ml::webnn::mojom::ModelInfoPtr; using ml::webnn::mojom::NamedResourcesPtr; @@ -46,7 +47,8 @@ using ml::webnn::mojom::OperationInfo; using ml::webnn::mojom::OperationInfoPtr; using ml::webnn::mojom::Pool2dOptions; using ml::webnn::mojom::Pool2dOptionsPtr; -using ml::webnn::mojom::Pool2dType; +using ml::webnn::mojom::InputOperandLayout; +using ml::webnn::mojom::PaddingMode; } // namespace @@ -71,24 +73,40 @@ class GraphDMLImpl : public ml::webnn::mojom::Graph { void AddClamp(UINT64 input_index, ClampOptionsPtr options, UINT64 output_index); - void AddConv2d(UINT64 input_index, + void AddConv2d(OperatorType operator_type, + UINT64 input_index, UINT64 filter_index, + UINT64 input_zero_point_index, + UINT64 filter_zero_point_index, Conv2dOptionsPtr options, OperandDescriptorPtr desc, UINT64 output_index); - void AddElementWiseBinary(UINT64, - UINT64, - ElementWiseBinaryType, - 
OperandDescriptorPtr, + void AddElementWiseUnary(OperatorType operator_type, + UINT64 input_index, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddElementWiseUnaryTwoParameter(OperatorType operator_type, + UINT64 input_index, + float first_parameter, + float second_parameter, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddElementWiseBinary(OperatorType operator_type, + UINT64 a_index, + UINT64 b_index, + OperandDescriptorPtr output_desc, UINT64 output_index); - void AddGemm(UINT64, - UINT64, + void AddGemm(OperatorType operator_type, + UINT64 a_index, + UINT64 b_index, + UINT64 a_index_zero_point, + UINT64 b_index_zero_point, GemmOptionsPtr, OperandDescriptorPtr, UINT64 output_index); - void AddPool2d(UINT64 input_index, + void AddPool2d(OperatorType operator_type, + UINT64 input_index, Pool2dOptionsPtr options, - Pool2dType type, OperandDescriptorPtr desc, UINT64 output_index); void AddRelu(UINT64 input_index, @@ -100,14 +118,98 @@ class GraphDMLImpl : public ml::webnn::mojom::Graph { void AddSoftmax(UINT64 input_index, OperandDescriptorPtr desc, UINT64 output_index); + void AddArgMinMax(OperatorType operator_type, + UINT64 input_index, + uint32_t axis, + bool select_last_index, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddCast(UINT64 input_index, + OperandType data_type, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddConcat(base::span input_indices, + uint32_t axis, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddSlice(UINT64 input_index, + base::span starts, + base::span sizes, + OperandDescriptorPtr desc, + UINT64 output_index); + void AddSplit(UINT64 output_index, + uint32_t axis, + base::flat_map& operands, + base::span output_indices); + void AddExpand(UINT64 input_index, OperandDescriptorPtr output_desc, UINT64 output_index); + void AddGather(UINT64 input_index, + UINT64 indices_index, + uint32_t axis, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddInstanceNormalization(uint64_t input_index, + uint64_t scale_index, + uint64_t bias_index, + float epsilon, + InputOperandLayout operand_layout, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddMeanVarianceNormalization(uint64_t input_index, + uint64_t mean_index, + uint64_t variance_index, + uint64_t scale_index, + uint64_t bias_index, + float epsilon, + base::span axes, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddPad(UINT64 input_index, + PaddingMode padding_mode, + float value, + base::span beginning_padding, + base::span ending_padding, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddFillSequence(float start, + float step, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddReduce(UINT64 input_index, + OperatorType operator_type, + base::span axes, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddResample2d(UINT64 input_index, + ml::webnn::mojom::InterpolationMode interpolation_mode, + base::span scales, + base::span axes, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddTranspose(UINT64 input_index, + base::span permutation, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddTriangularMatrix(UINT64 input_index, OperandDescriptorPtr desc, UINT64 output_index); + void AddElementWiseIf(UINT64 condition_index, + UINT64 true_value_index, + UINT64 false_value_index, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddDequantizeLinear(UINT64 
input_index, + UINT64 scale_index, + UINT64 zero_point_index, + OperandDescriptorPtr output_desc, + UINT64 output_index); + void AddDynamicQuantizeLinear( + UINT64 input_index, + base::flat_map& operands, + UINT64 output_index, + UINT64 output_scale_index, + UINT64 output_zero_point_index); void Build(ModelInfoPtr model_info, BuildCallback callback) override; - bool Build(ModelInfoPtr model_info, BuildResult* out_result) override; void Compute(NamedResourcesPtr named_inputs, ComputeCallback callback) override; - bool Compute(NamedResourcesPtr named_inputs, - ComputeResult* out_result, - NamedResourcesPtr* out_named_outputs) override; std::unique_ptr Clamp(NodeOutput* input_node, const ClampOptions* options); void EmulateFusedOperator(const OperationInfo* activation, @@ -117,7 +219,16 @@ class GraphDMLImpl : public ml::webnn::mojom::Graph { const std::vector& nchwOutputDims); void AddOutput(const std::string&, UINT64); + void CompileGraph(); + void OnGraphCompiled(BuildCallback callback, + ComPtr constant_resource, + ConstantsInfoPtr constants_info); + + void OnWaitForBuildSignal(BuildCallback callback); + void OnWaitForComputeSignal(ComputeCallback callback); + scoped_refptr execution_context_; + std::unique_ptr constant_resource_uploader_; std::unique_ptr input_resource_uploader_; std::unique_ptr output_resource_readback_; std::unique_ptr graph_desc_builder_; @@ -130,6 +241,8 @@ class GraphDMLImpl : public ml::webnn::mojom::Graph { std::string error_messages_; BuildResult build_result_; + + NamedResourcesPtr named_outputs_; }; } // namespace content::webnn diff --git a/content/browser/ml/webnn/dml/graph_node_output.h b/content/browser/ml/webnn/dml/graph_node_output.h index ecde5bee45c13c..423b8487ea8240 100644 --- a/content/browser/ml/webnn/dml/graph_node_output.h +++ b/content/browser/ml/webnn/dml/graph_node_output.h @@ -20,7 +20,7 @@ enum class NodeType { // Constant operand is also input for DirectML Graph builder. kConstant = 1, kOperator = 2, - kUnknow = 3, + kUnknown = 3, }; struct Node { @@ -44,7 +44,7 @@ struct InputNode final : public Node { std::string name; // The object id identify constant node which is to find memory info when // binding input for initializing graph. 
- uint32_t object_id; + uint32_t object_id = 0; }; struct OperatorNode final : public Node { diff --git a/content/browser/ml/webnn/dml/graph_tensor_desc.cc b/content/browser/ml/webnn/dml/graph_tensor_desc.cc index cbcd7514b3a8f2..7e8a9b212c8b88 100644 --- a/content/browser/ml/webnn/dml/graph_tensor_desc.cc +++ b/content/browser/ml/webnn/dml/graph_tensor_desc.cc @@ -6,21 +6,39 @@ #include "base/check_op.h" #include "base/numerics/checked_math.h" +#include "base/containers/span.h" namespace content::webnn { namespace { size_t GetBytesOfDataType(DML_TENSOR_DATA_TYPE data_type) { + static_assert(sizeof(float) == 4, "DirectML expects to run on machines with 4-byte floats."); + static_assert(sizeof(double) == 8, "DirectML expects to run on machines with 8-byte doubles."); + switch (data_type) { + case DML_TENSOR_DATA_TYPE_FLOAT16: + return sizeof(uint16_t); case DML_TENSOR_DATA_TYPE_FLOAT32: return sizeof(float); - case DML_TENSOR_DATA_TYPE_FLOAT16: + case DML_TENSOR_DATA_TYPE_FLOAT64: + return sizeof(double); + case DML_TENSOR_DATA_TYPE_UINT8: + return sizeof(uint8_t); + case DML_TENSOR_DATA_TYPE_INT8: + return sizeof(int8_t); + case DML_TENSOR_DATA_TYPE_UINT16: return sizeof(uint16_t); - case DML_TENSOR_DATA_TYPE_INT32: - return sizeof(int32_t); + case DML_TENSOR_DATA_TYPE_INT16: + return sizeof(int16_t); case DML_TENSOR_DATA_TYPE_UINT32: return sizeof(uint32_t); + case DML_TENSOR_DATA_TYPE_INT32: + return sizeof(int32_t); + case DML_TENSOR_DATA_TYPE_UINT64: + return sizeof(uint64_t); + case DML_TENSOR_DATA_TYPE_INT64: + return sizeof(int64_t); default: return 0; } @@ -79,6 +97,24 @@ absl::optional TotalTensorSizeInBytes( return checked_total_size_in_bytes.ValueOrDie(); } +void InsertPaddingOnes(/*inout*/ std::vector& values, + size_t minimum_size, + TensorDesc::Alignment alignment) { + // Insert's enough 1's to satisfy the minimum size. + // If already large enough, no additional 1's are added. + size_t old_size = values.size(); + size_t new_size = std::max(minimum_size, old_size); + size_t filler_count = new_size - old_size; + + // Insert filler values on: + // the leading edge if trailing aligned: [4,5] -> [1,1,1,4,5] + // the trailing edge if leading aligned: [4,5] -> [4,5,1,1,1] + auto insertion_point = (alignment == TensorDesc::Alignment::kTrailing) + ? values.begin() + : values.end(); + values.insert(insertion_point, filler_count, 1u); +} + } // namespace TensorDesc::TensorDesc() = default; @@ -106,28 +142,68 @@ void TensorDesc::Initialize(DML_TENSOR_DATA_TYPE data_type, DML_TENSOR_FLAGS flags, std::vector dimensions, absl::optional> strides) { - DCHECK(!dimensions.empty() && - dimensions.size() < DML_TENSOR_DIMENSION_COUNT_MAX); + DCHECK(dimensions.size() <= DML_TENSOR_DIMENSION_COUNT_MAX1); DCHECK(!strides || dimensions.size() == strides->size()); dimensions_ = std::move(dimensions); strides_ = std::move(strides); + // DML (as of at least 1.11) requires dimension count to be at least 1 + // because otherwise validation during operator creation will complain with + // E_INVALIDARG. So scalars must be conveyed with dimensions = [1]. + EnsureMinimumRank(1u, Alignment::kTrailing); + + // Round up to the nearest 4 bytes. + // The buffer allocation already aligned chunks up to + // DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT. 
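// Illustrative worked example (not from this patch) of the round-up below:
//
//   (10 + 3) & ~3ull == 12   // sizes are padded up to the next multiple of 4
//   (16 + 3) & ~3ull == 16   // sizes already on a 4-byte boundary are kept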
+ uint64_t minimum_implied_size_in_bytes = + TotalTensorSizeInBytes(data_type, dimensions_, strides_).value(); + minimum_implied_size_in_bytes = (minimum_implied_size_in_bytes + 3) & ~3ull; + buffer_desc_.DimensionCount = dimensions_.size(); buffer_desc_.Sizes = dimensions_.data(); buffer_desc_.Strides = strides_ ? strides_.value().data() : nullptr; - buffer_desc_.TotalTensorSizeInBytes = - TotalTensorSizeInBytes(data_type, dimensions_, strides_).value(); + buffer_desc_.TotalTensorSizeInBytes = minimum_implied_size_in_bytes; buffer_desc_.GuaranteedBaseOffsetAlignment = 0; buffer_desc_.DataType = data_type; buffer_desc_.Flags = flags; } +TensorDesc::TensorDesc(TensorDesc const& other) + : dimensions_(other.dimensions_), + strides_(other.strides_), + buffer_desc_(other.buffer_desc_), + tensor_desc_(other.tensor_desc_) { + // Update the internal self-referential pointers. + buffer_desc_.Sizes = dimensions_.data(); + buffer_desc_.Strides = strides_ ? strides_.value().data() : nullptr; +} + TensorDesc::TensorDesc(TensorDesc&& other) = default; + TensorDesc& TensorDesc::operator=(TensorDesc&& other) = default; +TensorDesc& TensorDesc::operator=(const TensorDesc& other) { + dimensions_ = other.dimensions_; + strides_ = other.strides_; + buffer_desc_ = other.buffer_desc_; + tensor_desc_ = other.tensor_desc_; + + // Update the internal self-referential pointers. + buffer_desc_.Sizes = dimensions_.data(); + buffer_desc_.Strides = strides_ ? strides_.value().data() : nullptr; + + return *this; +} + TensorDesc::~TensorDesc() = default; DML_TENSOR_DESC* TensorDesc::Get() { + DCHECK(buffer_desc_.Sizes == dimensions_.data()); + DCHECK(buffer_desc_.Strides == nullptr || buffer_desc_.Strides == strides_.value().data()); + + // Refresh the pointers to avoid them being stale after move + // or copy construction. + if (buffer_desc_.DataType == DML_TENSOR_DATA_TYPE_UNKNOWN) { return nullptr; } @@ -151,8 +227,150 @@ absl::optional>& TensorDesc::GetStrides() { return strides_; } +std::vector TensorDesc::GetStridesOrDefaultStrides() const { + return strides_ ? *strides_ : ComputeDecreasingStrides(dimensions_); +} + +void TensorDesc::EnsureStridesExist() { + if (!strides_) + { + std::vector new_strides = ComputeDecreasingStrides(dimensions_); + strides_ = std::move(new_strides); + buffer_desc_.Strides = strides_.value().data(); + } +} + UINT64 TensorDesc::GetTotalTensorSizeInBytes() { return buffer_desc_.TotalTensorSizeInBytes; } +void TensorDesc::EnsureMinimumRank(size_t minimum_rank, Alignment alignment) +{ + if (dimensions_.size() < minimum_rank) { + // Note this does not change the TotalTensorSizeInBytes, since leading 1's + // make no difference, nor any other field. 
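// Illustrative worked example (not from this patch): with minimum_rank == 4
// and Alignment::kTrailing, a tensor with dimensions {5, 6} and strides {6, 1}
// is padded on the leading side:
//
//   dimensions: {5, 6} -> {1, 1, 5, 6}
//   strides:    {6, 1} -> {1, 1, 6, 1}   // strides of size-1 dimensions are
//                                        // never walked, so 1 is harmless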
+    InsertPaddingOnes(/*inout*/ dimensions_, minimum_rank, alignment);
+    buffer_desc_.DimensionCount = dimensions_.size();
+    buffer_desc_.Sizes = dimensions_.data();
+
+    if (strides_) {
+      InsertPaddingOnes(/*inout*/ *strides_, minimum_rank, alignment);
+      buffer_desc_.Strides = strides_.value().data();
+    }
+  }
+}
+
+std::vector TensorDesc::ComputeDecreasingStrides(base::span dimensions)
+{
+  auto dimension_count = dimensions.size();
+  std::vector strides(dimension_count);
+
+  uint32_t stride = 1;
+  for (auto i = dimension_count; i-- > 0; )
+  {
+    strides[i] = stride;
+    stride *= dimensions[i];
+  }
+
+  return strides;
+}
+
+void TensorDesc::PermuteDimensions(base::span permutation, Alignment alignment)
+{
+  size_t permutation_rank = permutation.size();
+
+  // Ensure there are enough elements to apply the permutation by padding
+  // with 1's if necessary.
+  EnsureMinimumRank(permutation_rank, alignment);
+
+  // Compute strides *before* the reordering, since callers use this function
+  // to rearrange the dimensions (depending on the affinity of the backend
+  // for a certain preference - e.g. NHWC vs NCHW) and need the strides to
+  // be adjusted accordingly. Otherwise strides would be computed using
+  // the permuted dimensions and read the wrong elements.
+  EnsureStridesExist();
+
+  // If there are more dimensions than the permutation size (e.g. maybe
+  // they were already padded out to some limit like 4D with leading 1's),
+  // then the permutation only applies to the significant portion,
+  // depending on alignment. Any additional leading/trailing batch dimensions
+  // can be ignored.
+
+  size_t subset_offset = (alignment == Alignment::kLeading)
+                             ? 0
+                             : dimensions_.size() - permutation_rank;
+  auto dimensions_span = base::span(dimensions_);
+  auto strides_span = base::span(*strides_);
+  dimensions_span = dimensions_span.subspan(subset_offset, permutation_rank);
+  strides_span = strides_span.subspan(subset_offset, permutation_rank);
+
+  auto permute = [](/*inout*/ base::span values,
+                    base::span permutation) {
+    DCHECK(values.size() == permutation.size());
+
+    // Gather the original values via the permutation.
+    std::vector temporary_values(values.begin(), values.end());
+    for (size_t i = 0, count = values.size(); i < count; ++i) {
+      values[i] = temporary_values[permutation[i]];
+    }
+  };
+
+  permute(/*inout*/ dimensions_span, permutation);
+  permute(/*inout*/ strides_span, permutation);
+}
+
+void TensorDesc::BroadcastTo(base::span broadcast_dimensions,
+                             Alignment alignment,
+                             size_t ignorable_tail_count) {
+  size_t broadcast_rank = broadcast_dimensions.size();
+  EnsureMinimumRank(broadcast_rank, alignment);
+  EnsureStridesExist();
+
+  // Determine the window of dimensions and strides that are to be modified.
+  // e.g.
+  //   Alignment = trailing/right
+  //   Ignorable tail count = 0
+  //   Original dimensions = [2,1,4]
+  //   Original strides = [4,4,1]
+  //   Broadcast dimensions = [5,2,3,4]
+  //   New dimensions = [5,2,3,4]
+  //   New strides = [0,4,0,1]
+  // e.g.
+  //   Alignment = trailing/right
+  //   Ignorable tail count = 2
+  //   Original dimensions = [2,1,4]
+  //   Original strides = [4,4,1]
+  //   Broadcast dimensions = [5,2,3,4]
+  //   New dimensions = [5,2,1,4]
+  //   New strides = [0,4,4,1]
+  // e.g.
+  //   Alignment = leading/left
+  //   Ignorable tail count = 0
+  //   Original dimensions = [3,1,4]
+  //   Original strides = [4,4,1]
+  //   Broadcast dimensions = [3,2]
+  //   New dimensions = [3,2,4]
+  //   New strides = [4,0,1]
+
+  size_t subset_offset = (alignment == Alignment::kLeading)
+                             ?
0 + : dimensions_.size() - broadcast_rank; + auto dimensions_span = base::span(dimensions_); + auto strides_span = base::span(*strides_); + dimensions_span = dimensions_span.subspan(subset_offset, broadcast_rank); + strides_span = strides_span.subspan(subset_offset, broadcast_rank); + size_t clamped_rank = + broadcast_rank - std::min(broadcast_rank, ignorable_tail_count); + + for (size_t i = 0; i < clamped_rank; ++i) { + // Any 1-size dimensions get promoted to the target broadcast dimension + // and have their stride set to 0 for projection. + if (dimensions_span[i] == 1u) { + dimensions_span[i] = broadcast_dimensions[i]; + strides_span[i] = 0; + } + } +} + + } // namespace content::webnn diff --git a/content/browser/ml/webnn/dml/graph_tensor_desc.h b/content/browser/ml/webnn/dml/graph_tensor_desc.h index f196b820b9bf84..d77a4ce73255a1 100644 --- a/content/browser/ml/webnn/dml/graph_tensor_desc.h +++ b/content/browser/ml/webnn/dml/graph_tensor_desc.h @@ -10,6 +10,7 @@ #include #include "DirectML.h" #include "third_party/abseil-cpp/absl/types/optional.h" +#include "base/containers/span.h" namespace content::webnn { @@ -17,6 +18,12 @@ using Microsoft::WRL::ComPtr; class TensorDesc final { public: + enum Alignment : uint32_t + { + kLeading, // Align to leading/left edge. e.g. the 1 in [1,2,3] + kTrailing, // Align to trailing/right edge. e.g. the 3 in [1,2,3] + }; + TensorDesc(); TensorDesc(DML_TENSOR_DATA_TYPE data_type, std::vector dimensions); TensorDesc(DML_TENSOR_DATA_TYPE data_type, @@ -26,26 +33,62 @@ class TensorDesc final { DML_TENSOR_FLAGS flags, std::vector dimensions, absl::optional> strides); + TensorDesc(const TensorDesc& other); TensorDesc(TensorDesc&& other); TensorDesc& operator=(TensorDesc&& other); + TensorDesc& operator=(const TensorDesc& other); ~TensorDesc(); DML_TENSOR_DESC* Get(); DML_TENSOR_DATA_TYPE GetDataType() const; DML_TENSOR_FLAGS GetFlags() const; std::vector& GetDimensions(); + + // Returns the strides, or empty if none exist. absl::optional>& GetStrides(); + + // Returns valid strides, either the explicit ones contained or the generated + // ones (packed with no padding in descending order left-to-right) if empty. + std::vector GetStridesOrDefaultStrides() const; + + // Ensures strides are not empty, computing them from dimensions if needed. + void EnsureStridesExist(); + + // Ensures the rank is at least the minimum rank, filling the opposite side + // (depending on alignment) with 1's when needed. + // e.g. [5,6] with minimum rank of 4 yields [1,1,5,6]. + void EnsureMinimumRank(size_t minimum_rank, Alignment alignment); + + // Permute the original dimensions/strides to the given remapping. + // e.g. dimensions [5,6,7,8] with permutation [3,2,0,1] yields dimensions + // of [8,7,5,6]. All indices must be within [0, permutation.size() - 1]. + // A permutation larger than the current rank will increase the rank first. + void PermuteDimensions(base::span permutation, + Alignment alignment); + + void BroadcastTo(base::span dimensions, + Alignment alignment, + size_t ignorable_tail_count = 0); + UINT64 GetTotalTensorSizeInBytes(); + // Returns the default decreasing order packed strides for the given dimensions. + // e.g. dimensions [1,2,3,4] yields strides [24,12,4,1]. + // See https://docs.microsoft.com/en-us/windows/win32/direct3d12/dml-helper-functions#calculatestrides. 
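// Illustrative derivation (not from this patch) of the example above: walking
// the dimensions {1, 2, 3, 4} right to left, each stride is the product of the
// dimensions to its right:
//
//   stride[3] = 1, stride[2] = 4, stride[1] = 3 * 4 = 12, stride[0] = 2 * 12 = 24
//   => {24, 12, 4, 1}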
+ static std::vector ComputeDecreasingStrides( + base::span dimensions); + private: void Initialize(DML_TENSOR_DATA_TYPE data_type, DML_TENSOR_FLAGS flags, std::vector dimensions, absl::optional> strides); - // DML_BUFFER_TENSOR_DESC only has a pointer of dimensions and strides, the - // data hold in dimensions_ and strides_. + + // DML_BUFFER_TENSOR_DESC only has a pointer to dimensions and strides, + // which points to dimensions_ and strides_. std::vector dimensions_; absl::optional> strides_; + // Describes a tensor that will be stored in a Direct3D 12 buffer resource. DML_BUFFER_TENSOR_DESC buffer_desc_ = {}; DML_TENSOR_DESC tensor_desc_; diff --git a/content/browser/ml/webnn/dml/mojo_server_dml_impl.cc b/content/browser/ml/webnn/dml/mojo_server_dml_impl.cc index e2703f595bca39..130a1220277ca2 100644 --- a/content/browser/ml/webnn/dml/mojo_server_dml_impl.cc +++ b/content/browser/ml/webnn/dml/mojo_server_dml_impl.cc @@ -35,7 +35,7 @@ MojoServerDMLImpl::MojoServerDMLImpl(WebnnServiceDMLImpl* webnn_service) void MojoServerDMLImpl::CreateContext( ContextOptionsPtr options, MojoServer::CreateContextCallback callback) { - auto adapter = webnn_service_->RequestAdapter(options->power_preference); + auto adapter = webnn_service_->RequestAdapter(options->device_preference, options->power_preference); if (!adapter) { std::move(callback).Run(mojo::NullRemote()); return; @@ -63,7 +63,7 @@ void MojoServerDMLImpl::CreateContext( bool MojoServerDMLImpl::CreateContext( ContextOptionsPtr options, ::mojo::PendingRemote<::ml::webnn::mojom::Context>* out_remote) { - auto adapter = webnn_service_->RequestAdapter(options->power_preference); + auto adapter = webnn_service_->RequestAdapter(options->device_preference, options->power_preference); if (!adapter) { DLOG(ERROR) << "Failed to request the adapter."; return false; diff --git a/content/browser/ml/webnn/dml/readback_resource.cc b/content/browser/ml/webnn/dml/readback_resource.cc index 348ea93f56c873..112cab06102f34 100644 --- a/content/browser/ml/webnn/dml/readback_resource.cc +++ b/content/browser/ml/webnn/dml/readback_resource.cc @@ -6,6 +6,8 @@ #include +#include "base/trace_event/trace_event.h" +#include "base/trace_event/typed_macros.h" #include "content/browser/ml/webnn/dml/execution_context.h" namespace content::webnn { @@ -16,16 +18,17 @@ ReadbackResource::ReadbackResource(ExecutionContext* execution_context) ReadbackResource::~ReadbackResource() = default; HRESULT ReadbackResource::InitializeResource( - std::map& named_outputs) { + std::map& named_outputs) { uint64_t aligned_offset = 0; - for (auto& [name, byte_length] : named_outputs) { + for (auto& [name, output_info] : named_outputs) { MemoryInfo memory_info = {}; memory_info.byte_offset = aligned_offset; - memory_info.byte_length = byte_length; + memory_info.byte_length = output_info.byte_length; outputs_info_map_[name] = memory_info; // Only offset need to be algnement, the byte length keep original value. - aligned_offset += Align(byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT); + aligned_offset += + Align(output_info.byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT); } outputs_resource_size_ = aligned_offset; outputs_shm_region_ = @@ -38,49 +41,6 @@ HRESULT ReadbackResource::InitializeResource( return S_OK; } -// Readback inference result from GPU that is stored in named_outputs. -HRESULT ReadbackResource::ReadResourceFromGpu(NamedResourcesPtr& named_outputs, - ID3D12Resource* src_resource) { - // Copy buffer from GPU resource to CPU data. 
- execution_context_->CopyBufferRegion(readback_resource_->GetResource(), - src_resource, outputs_resource_size_, - D3D12_RESOURCE_STATE_COPY_SOURCE); - - execution_context_->Flush(); - execution_context_->WaitForSignal(); - execution_context_->ReleaseCompletedResources(); - - D3D12_RANGE tensorBufferRange{0, outputs_resource_size_}; - int8_t* readBackBuffer; - HRESULT hr = readback_resource_->Map( - 0, &tensorBufferRange, reinterpret_cast(&readBackBuffer)); - if (FAILED(hr)) { - return hr; - } - - for (auto& [name, memory_info] : outputs_info_map_) { - auto mojo_memory_info = ml::webnn::mojom::MemoryInfo::New(); - size_t byte_offset = memory_info.byte_offset; - size_t byte_length = memory_info.byte_length; - mojo_memory_info->byte_offset = byte_offset; - mojo_memory_info->byte_length = byte_length; - named_outputs->resources[name] = std::move(mojo_memory_info); - - std::vector output_buffer(byte_length); - uint8_t* address = - outputs_shm_region_.mapping.GetMemoryAs() + byte_offset; - memcpy(address, readBackBuffer + byte_offset, byte_length); - } - named_outputs->shared_memory = outputs_shm_region_.region.Duplicate(); - - readback_resource_->Unmap(0, nullptr); - return S_OK; -} - -size_t ReadbackResource::GetOutputsResourceSize() const { - return outputs_resource_size_; -} - ReadbackResource::MemoryInfo::MemoryInfo() = default; ReadbackResource::MemoryInfo::~MemoryInfo() = default; diff --git a/content/browser/ml/webnn/dml/readback_resource.h b/content/browser/ml/webnn/dml/readback_resource.h index 9fdc73ab72bb9b..a6cd5776882d85 100644 --- a/content/browser/ml/webnn/dml/readback_resource.h +++ b/content/browser/ml/webnn/dml/readback_resource.h @@ -11,24 +11,30 @@ #include "DirectML.h" #include "components/ml/mojom/webnn_graph.mojom.h" #include "content/browser/ml/webnn/dml/gpgmm_d3d12.h" +#include "content/browser/ml/webnn/dml/graph_desc_builder.h" #include "content/browser/ml/webnn/dml/utils_dml.h" namespace content::webnn { using Microsoft::WRL::ComPtr; using ml::webnn::mojom::NamedResourcesPtr; +using ml::webnn::mojom::ComputeResult; + +using ComputeCallback = base::OnceCallback; class ExecutionContext; +class GraphDMLImpl; class ReadbackResource final { public: explicit ReadbackResource(ExecutionContext* execution_context); ~ReadbackResource(); - HRESULT InitializeResource(std::map& named_outputs); - HRESULT ReadResourceFromGpu(NamedResourcesPtr& named_outputs, - ID3D12Resource* src_resource); - size_t GetOutputsResourceSize() const; + HRESULT InitializeResource(std::map& named_outputs); + + size_t outputs_resource_size() const { return outputs_resource_size_; } + + friend class GraphDMLImpl; private: struct MemoryInfo { diff --git a/content/browser/ml/webnn/dml/upload_resource.cc b/content/browser/ml/webnn/dml/upload_resource.cc index 00bd035bf28580..cc6160711073d4 100644 --- a/content/browser/ml/webnn/dml/upload_resource.cc +++ b/content/browser/ml/webnn/dml/upload_resource.cc @@ -6,6 +6,8 @@ #include +#include "base/trace_event/trace_event.h" +#include "base/trace_event/typed_macros.h" #include "content/browser/ml/webnn/dml/execution_context.h" namespace content::webnn { @@ -14,13 +16,12 @@ namespace { using ml::webnn::mojom::MemoryInfoPtr; -template HRESULT UploadResourceToGpu( ExecutionContext* execution_context, ID3D12Resource* dst_resource, ID3D12Resource* src_resource, - base::ReadOnlySharedMemoryRegion& shared_memory_region, - T& named_inputs) { + base::ReadOnlySharedMemoryMapping& shared_memory_mapping, + size_t byte_length) { // Map the upload heap and copy the source data 
into it. A null pointer // indicates the entire subresource might be read by the CPU. void* upload_data = nullptr; @@ -28,22 +29,12 @@ HRESULT UploadResourceToGpu( if (FAILED(hr)) { return hr; } - - for (auto& [_, memory_info] : named_inputs) { - uint64_t byte_length = memory_info->byte_length; - uint64_t byte_offset = memory_info->byte_offset; - DCHECK(byte_offset % DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT == 0); - DCHECK(shared_memory_region.IsValid()); - base::ReadOnlySharedMemoryMapping shared_memory_mapping = - shared_memory_region.MapAt(memory_info->byte_offset, byte_length); - memcpy(static_cast(upload_data) + memory_info->byte_offset, - shared_memory_mapping.GetMemoryAs(), byte_length); - } + memcpy(static_cast(upload_data), + shared_memory_mapping.GetMemoryAs(), byte_length); src_resource->Unmap(0, nullptr); // Copy from the upload heap into the destination resource - execution_context->CopyBufferRegion(dst_resource, src_resource, - shared_memory_region.GetSize(), + execution_context->CopyBufferRegion(dst_resource, src_resource, byte_length, D3D12_RESOURCE_STATE_COPY_DEST); return S_OK; @@ -60,9 +51,14 @@ UploadResource::~UploadResource() = default; // need to transition. HRESULT UploadResource::UploadConstants(ID3D12Resource* dst_resource, ConstantsInfoPtr& constants_info) { + TRACE_EVENT0("gpu", "UploadResource::UploadConstants"); base::ReadOnlySharedMemoryRegion& shared_memory_region = constants_info->shared_memory; size_t constants_byte_length = shared_memory_region.GetSize(); + if (!shm_mapping_.IsValid()) { + shm_mapping_ = shared_memory_region.Map(); + DCHECK(shm_mapping_.IsValid()); + } HRESULT hr = S_OK; if (upload_resource_ == nullptr) { @@ -73,16 +69,21 @@ HRESULT UploadResource::UploadConstants(ID3D12Resource* dst_resource, } DCHECK(upload_resource_ != nullptr); - return UploadResourceToGpu>( - execution_context_, dst_resource, upload_resource_->GetResource(), - shared_memory_region, constants_info->memory_info); + return UploadResourceToGpu(execution_context_, dst_resource, + upload_resource_->GetResource(), shm_mapping_, + constants_byte_length); } HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource, NamedResourcesPtr& named_inputs) { + TRACE_EVENT0("gpu", "UploadResource::UploadInputs"); base::ReadOnlySharedMemoryRegion& shared_memory_region = named_inputs->shared_memory; size_t inputs_byte_length = shared_memory_region.GetSize(); + if (!shm_mapping_.IsValid()) { + shm_mapping_ = shared_memory_region.Map(); + DCHECK(shm_mapping_.IsValid()); + } HRESULT hr = S_OK; if (upload_resource_ == nullptr) { @@ -93,9 +94,9 @@ HRESULT UploadResource::UploadInputs(ID3D12Resource* dst_resource, } DCHECK(upload_resource_ != nullptr); - return UploadResourceToGpu>( - execution_context_, dst_resource, upload_resource_->GetResource(), - shared_memory_region, named_inputs->resources); + return UploadResourceToGpu(execution_context_, dst_resource, + upload_resource_->GetResource(), shm_mapping_, + inputs_byte_length); } // Create entire memory for uploading resource that will be uploaded piece by diff --git a/content/browser/ml/webnn/dml/upload_resource.h b/content/browser/ml/webnn/dml/upload_resource.h index 1342df8ca9b3c8..64ae8a2d951873 100644 --- a/content/browser/ml/webnn/dml/upload_resource.h +++ b/content/browser/ml/webnn/dml/upload_resource.h @@ -35,6 +35,7 @@ class UploadResource final { ExecutionContext* execution_context_; ComPtr upload_resource_; + base::ReadOnlySharedMemoryMapping shm_mapping_; }; } // namespace content::webnn diff --git 
a/content/browser/ml/webnn/dml/utils_dml.h b/content/browser/ml/webnn/dml/utils_dml.h index 189024324e50c3..7d367eaaaa17f0 100644 --- a/content/browser/ml/webnn/dml/utils_dml.h +++ b/content/browser/ml/webnn/dml/utils_dml.h @@ -66,8 +66,8 @@ void ComputeImplicitPaddingForAutoPad(AutoPad auto_pad, template std::vector ComputeImplicitPaddingForAutoPad(const S* options, - std::vector inputSize, - std::vector filterSize) { + base::span inputSize, + base::span filterSize) { std::vector padding(4); ComputeImplicitPaddingForAutoPad( options->auto_pad, options->dilations[0], inputSize[0], filterSize[0], @@ -80,11 +80,11 @@ std::vector ComputeImplicitPaddingForAutoPad(const S* options, template std::vector ImplicitPadding(const T* options, - const std::vector& inputDims, - const std::vector& filterDims) { - return ComputeImplicitPaddingForAutoPad( - options, {inputDims[2], inputDims[3]}, - {filterDims[filterDims.size() - 2], filterDims[filterDims.size() - 1]}); + base::span inputDims, + base::span filterDims) { + auto inputSize = {inputDims[2], inputDims[3]}; + auto filterSize = {filterDims[filterDims.size() - 2], filterDims[filterDims.size() - 1]}; + return ComputeImplicitPaddingForAutoPad(options, inputSize, filterSize); } template diff --git a/content/browser/ml/webnn/dml/webnn_service_dml_impl.cc b/content/browser/ml/webnn/dml/webnn_service_dml_impl.cc index 10742e71558b6f..a114774127045a 100644 --- a/content/browser/ml/webnn/dml/webnn_service_dml_impl.cc +++ b/content/browser/ml/webnn/dml/webnn_service_dml_impl.cc @@ -6,32 +6,108 @@ #include #include +#include #include "base/logging.h" #include "base/no_destructor.h" #include "content/browser/ml/webnn/dml/mojo_server_dml_impl.h" +// TODO::: +#pragma optimize("", off) + namespace content::webnn { namespace { -std::map> EumerateAdapters() { +std::map> EnumerateAdapters(base::ScopedNativeLibrary& dxcore_library) { std::map> adapter_map = {}; - ComPtr dxgi_factory = nullptr; - HRESULT hr = CreateDXGIFactory1(IID_PPV_ARGS(&dxgi_factory)); + + HRESULT hr; + + // DXCore is optional as it was missing in older Windows 10 versions. + // It's needed for MDCM devices (NPU's) which are not enumerable via DXGI. + if (dxcore_library.is_valid()) { + // Disambiguate primary function from the templated overload. + using DXCoreCreateAdapterFactoryProc = + decltype(static_cast( + DXCoreCreateAdapterFactory)); + + DXCoreCreateAdapterFactoryProc dxcore_factory_proc = + reinterpret_cast( + dxcore_library.GetFunctionPointer("DXCoreCreateAdapterFactory")); + + if (!dxcore_factory_proc) { + DLOG(ERROR) << "DXCoreCreateAdapterFactory export is missing"; + return adapter_map; + } + + ComPtr dxcore_factory; + hr = dxcore_factory_proc(IID_PPV_ARGS(&dxcore_factory)); + + // DXCore.dll exists in Windows 19H1/19H2, and it exports + // DXCoreCreateAdapterFactory, but it instantiates a different version + // of IDXCoreAdapterFactory (same name, different IID) than the one we + // expect. So it's possible and expected to get E_NOINTERFACE here if + // running DirectML on Windows 19H1/19H2. + if (hr == E_NOINTERFACE) { + // Do nothing as this is expected and ignorable. 
+ } else if (FAILED(hr)) { + DLOG(ERROR) << "DXCoreCreateAdapterFactory failed: " + << logging::SystemErrorCodeToString(hr); + return adapter_map; + } else { + const GUID dx_guids[] = {DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE}; + ComPtr adapter_list; + hr = dxcore_factory->CreateAdapterList( + ARRAYSIZE(dx_guids), dx_guids, IID_PPV_ARGS(&adapter_list)); + + if (FAILED(hr)) { + DLOG(ERROR) << "IDXCoreAdapterFactory::CreateAdapterList failed: " + << logging::SystemErrorCodeToString(hr); + return adapter_map; + } + + // Enumerate all DXCore adapters. + uint32_t adapter_count = adapter_list->GetAdapterCount(); + for (uint32_t adapter_index = 0; adapter_index < adapter_count; ++adapter_index) { + ComPtr dxcore_adapter; + hr = adapter_list->GetAdapter(adapter_index, IID_PPV_ARGS(&dxcore_adapter)); + if (FAILED(hr)) { + DLOG(ERROR) << "IDXCoreAdapterList::GetAdapter failed: " + << logging::SystemErrorCodeToString(hr); + return adapter_map; + } + + auto dml_adapter = + base::MakeRefCounted(std::move(dxcore_adapter)); + hr = dml_adapter->Initialize(); + if (FAILED(hr)) { + DLOG(ERROR) << "AdapterDML::Initialize failed: " + << logging::SystemErrorCodeToString(hr); + return adapter_map; + } + adapter_map[dml_adapter->GetAdapterType()] = dml_adapter; + } + + return adapter_map; + } + } // end DXCore enumeration + + ComPtr dxgi_factory; + hr = CreateDXGIFactory1(IID_PPV_ARGS(&dxgi_factory)); if (FAILED(hr)) { DLOG(ERROR) << "Create DXGI factory failed: " << logging::SystemErrorCodeToString(hr); return adapter_map; } - // Eumerate all available adapters, DXGI_ERROR_NOT_FOUND means there are no - // more adapters to enumerate. + // Enumerate all available DXGI adapters. + // DXGI_ERROR_NOT_FOUND means there are no more adapters to enumerate. ComPtr dxgi_adapter; uint32_t adapter_index = 0; while (dxgi_factory->EnumAdapters1(adapter_index++, &dxgi_adapter) != DXGI_ERROR_NOT_FOUND) { - ComPtr dxgi_adapter3 = nullptr; + ComPtr dxgi_adapter3; hr = dxgi_adapter.As(&dxgi_adapter3); if (FAILED(hr)) { DLOG(ERROR) << "Get adapter3 failed: " @@ -39,7 +115,8 @@ std::map> EumerateAdapters() { return adapter_map; } auto adapter = base::MakeRefCounted(std::move(dxgi_adapter3)); - if (FAILED(adapter->Initialize())) { + hr = adapter->Initialize(); + if (FAILED(hr)) { DLOG(ERROR) << "Initialize adapter failed: " << logging::SystemErrorCodeToString(hr); return adapter_map; @@ -60,7 +137,10 @@ void WebnnServiceDMLImpl::Create( WebnnServiceDMLImpl::WebnnServiceDMLImpl( mojo::PendingReceiver receiver) - : receiver_(this, std::move(receiver)), adapter_map_(EumerateAdapters()) {} + : receiver_(this, std::move(receiver)), + dxcore_library_( + base::ScopedNativeLibrary(base::LoadSystemLibrary(L"dxcore.dll"))), + adapter_map_(EnumerateAdapters(dxcore_library_)) {} WebnnServiceDMLImpl::~WebnnServiceDMLImpl() = default; @@ -70,16 +150,22 @@ void WebnnServiceDMLImpl::BindMojoServer( } scoped_refptr WebnnServiceDMLImpl::RequestAdapter( - PowerPreference power_preference) { - AdapterType preferred_type; - switch (power_preference) { - case PowerPreference::kLowPower: - preferred_type = AdapterType::kIntegratedGPU; - break; - case PowerPreference::kDefault: - case PowerPreference::kHighPerformance: - preferred_type = AdapterType::kDiscreteGPU; - break; + DevicePreference device_preference, + PowerPreference power_preference) +{ + AdapterType preferred_type = AdapterType::kUnknown; + if (device_preference == DevicePreference::kNpu) { + preferred_type = AdapterType::kNPU; + } else { + switch (power_preference) { + case 
PowerPreference::kLowPower: + preferred_type = AdapterType::kIntegratedGPU; + break; + case PowerPreference::kDefault: + case PowerPreference::kHighPerformance: + preferred_type = AdapterType::kDiscreteGPU; + break; + } } auto iter = adapter_map_.find(preferred_type); if (iter != adapter_map_.end()) { @@ -87,11 +173,17 @@ scoped_refptr WebnnServiceDMLImpl::RequestAdapter( } // Select device sequentially if there is no preferred type. + iter = adapter_map_.find(AdapterType::kDiscreteGPU); if (iter != adapter_map_.end()) { return iter->second; } + iter = adapter_map_.find(AdapterType::kNPU); + if (iter != adapter_map_.end()) { + return iter->second; + } + iter = adapter_map_.find(AdapterType::kIntegratedGPU); if (iter != adapter_map_.end()) { return iter->second; diff --git a/content/browser/ml/webnn/dml/webnn_service_dml_impl.h b/content/browser/ml/webnn/dml/webnn_service_dml_impl.h index 62ae8546dc82e9..12080489705d9e 100644 --- a/content/browser/ml/webnn/dml/webnn_service_dml_impl.h +++ b/content/browser/ml/webnn/dml/webnn_service_dml_impl.h @@ -11,11 +11,14 @@ #include "content/browser/ml/webnn/dml/adapter_dml.h" #include "mojo/public/cpp/bindings/pending_receiver.h" #include "mojo/public/cpp/bindings/receiver.h" +#include "base/native_library.h" +#include "base/scoped_native_library.h" namespace content::webnn { namespace { +using ml::model_loader::mojom::DevicePreference; using ml::model_loader::mojom::PowerPreference; } @@ -37,10 +40,12 @@ class WebnnServiceDMLImpl : public ml::webnn::mojom::WebnnService { void BindMojoServer( mojo::PendingReceiver receiver) override; - scoped_refptr RequestAdapter(PowerPreference power_preference); + scoped_refptr RequestAdapter(DevicePreference device_preference, + PowerPreference power_preference); private: mojo::Receiver receiver_; + base::ScopedNativeLibrary dxcore_library_; std::map> adapter_map_; }; diff --git a/content/gpu/gpu_child_thread_receiver_bindings.cc b/content/gpu/gpu_child_thread_receiver_bindings.cc index 9d73ffc7236509..db23795924271e 100644 --- a/content/gpu/gpu_child_thread_receiver_bindings.cc +++ b/content/gpu/gpu_child_thread_receiver_bindings.cc @@ -62,9 +62,8 @@ void GpuChildThread::BindServiceInterface( #if BUILDFLAG(ENABLE_MOJO_WEBNN_IN_GPU_PROCESS) if (auto webnn_receiver = receiver.As()) { - scoped_refptr task_runner; - task_runner = base::ThreadPool::CreateSingleThreadTaskRunner( - {base::TaskPriority::USER_BLOCKING}); + scoped_refptr task_runner = + base::SingleThreadTaskRunner::GetCurrentDefault(); task_runner->PostTask( FROM_HERE, base::BindOnce( [](mojo::PendingReceiver diff --git a/third_party/blink/renderer/bindings/generated_in_modules.gni b/third_party/blink/renderer/bindings/generated_in_modules.gni index f1c4e3db3eab8d..98beafbfebed76 100644 --- a/third_party/blink/renderer/bindings/generated_in_modules.gni +++ b/third_party/blink/renderer/bindings/generated_in_modules.gni @@ -669,24 +669,80 @@ generated_dictionary_sources_in_modules = [ "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_midi_options.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_midi_permission_descriptor.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_midi_permission_descriptor.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_arg_min_max_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_arg_min_max_options.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.cc", 
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_concat_options_internal.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_concat_options_internal.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_integer_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_integer_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_options_internal.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_options_internal.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_elu_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_elu_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_fill_sequence_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_fill_sequence_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_float_parameter_options_internal.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_float_parameter_options_internal.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_options.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_cell_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_cell_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_hard_sigmoid_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_hard_sigmoid_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_instance_normalization_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_instance_normalization_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_leaky_relu_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_leaky_relu_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_linear_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_linear_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_cell_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_cell_options.h", + 
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_mean_variance_normalization_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_mean_variance_normalization_options.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_descriptor.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_descriptor.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options_internal.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options_internal.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_reduce_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_reduce_options.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_resample_2d_options.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_resample_2d_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_slice_options_internal.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_slice_options_internal.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_softplus_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_softplus_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options_internal.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options_internal.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_squeeze_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_squeeze_options.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_tensor.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_tensor.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_tensor_info.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_tensor_info.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_transpose_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_transpose_options.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_matrix_options.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_matrix_options.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_multi_cache_query_options.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_multi_cache_query_options.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_navigation_preload_state.cc", @@ -1308,22 +1364,36 @@ generated_enumeration_sources_in_modules = [ 
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_auto_pad.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_filter_operand_layout.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_filter_operand_layout.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_filter_operand_layout_internal.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_filter_operand_layout_internal.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_filter_operand_layout.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_filter_operand_layout.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_data_type.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_data_type.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_device_preference.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_device_preference.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_weight_layout.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_weight_layout.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_input_operand_layout.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_input_operand_layout.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_interpolation_mode.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_interpolation_mode.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_weight_layout.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_weight_layout.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_model_format.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_model_format.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_type.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_type.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_padding_mode.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_padding_mode.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_power_preference.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_power_preference.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_recurrent_network_direction.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_recurrent_network_direction.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_rounding_type.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_rounding_type.h", - "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_interpolation_mode.cc", - "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_interpolation_mode.h", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_part.cc", + "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_part.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_notification_action_type.cc", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_notification_action_type.h", "$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_notification_direction.cc", 
diff --git a/third_party/blink/renderer/modules/ml/ml_context.cc b/third_party/blink/renderer/modules/ml/ml_context.cc index 71da4630b8040b..8781df57d50f4b 100644 --- a/third_party/blink/renderer/modules/ml/ml_context.cc +++ b/third_party/blink/renderer/modules/ml/ml_context.cc @@ -11,6 +11,8 @@ #include "third_party/blink/renderer/modules/ml/webnn/ml_graph.h" #include "third_party/blink/renderer/platform/bindings/exception_state.h" +#pragma optimize("", off) // TODO:::DELETE + namespace blink { MLContext::MLContext(const V8MLDevicePreference device_preference, @@ -23,7 +25,8 @@ MLContext::MLContext(const V8MLDevicePreference device_preference, model_format_(model_format), num_threads_(num_threads), ml_(ml), - webnn_context_(ml->GetExecutionContext()) {} + webnn_context_(ml->GetExecutionContext()) { +} MLContext::~MLContext() = default; @@ -65,11 +68,31 @@ bool MLContext::IsWebnnMojoContextEnabled() const { device_preference_ != V8MLDevicePreference::Enum::kCpu; } +bool MLContext::IsDedicatedHardwareDevice() const { + // If no device preference is set, then use the GPU for high performance or + // an NPU for low power. + if (device_preference_ == V8MLDevicePreference::Enum::kAuto) { + return power_preference_ == V8MLPowerPreference::Enum::kLowPower || + power_preference_ == V8MLPowerPreference::Enum::kHighPerformance; + } + + return device_preference_ == V8MLDevicePreference::Enum::kGpu || + device_preference_ == V8MLDevicePreference::Enum::kNpu; +} + void MLContext::CreateWebnnMojoContext(ScriptPromiseResolver* resolver) { auto options = ml::webnn::mojom::blink::ContextOptions::New(); // TODO(crbug.com/1273291): Set power preference in the context option. + + static_assert(uint32_t(V8MLDevicePreference::Enum::kAuto) == uint32_t(ml::model_loader::mojom::blink::DevicePreference::kAuto)); + static_assert(uint32_t(V8MLDevicePreference::Enum::kCpu) == uint32_t(ml::model_loader::mojom::blink::DevicePreference::kCpu)); + static_assert(uint32_t(V8MLDevicePreference::Enum::kGpu) == uint32_t(ml::model_loader::mojom::blink::DevicePreference::kGpu)); + static_assert(uint32_t(V8MLDevicePreference::Enum::kNpu) == uint32_t(ml::model_loader::mojom::blink::DevicePreference::kNpu)); + options->device_preference = - ml::model_loader::mojom::blink::DevicePreference::kGpu; + static_cast(device_preference_.AsEnum()); + options->power_preference = + static_cast(power_preference_.AsEnum()); ml_->CreateWebnnMojoContext( resolver, std::move(options), WTF::BindOnce(&MLContext::OnWebnnContextCreated, WrapPersistent(this), @@ -81,7 +104,10 @@ void MLContext::CreateWebnnMojoContextSync(ScriptState* script_state, auto options = ml::webnn::mojom::blink::ContextOptions::New(); // TODO(crbug.com/1273291): Set power preference in the context option. 
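  // Note: as in CreateWebnnMojoContext() above, the static_casts below assume the
  // Blink V8 enums and the ml::model_loader mojo enums share the same numeric
  // ordering; the static_asserts above spell this out for the device preference
  // values, while the power preference values are assumed to line up the same way.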
options->device_preference = - ml::model_loader::mojom::blink::DevicePreference::kGpu; + static_cast(device_preference_.AsEnum()); + options->power_preference = + static_cast(power_preference_.AsEnum()); + ::mojo::PendingRemote<::ml::webnn::mojom::blink::Context> pending_remote; ml_->CreateWebnnMojoContextSync(std::move(options), &pending_remote, exception_state); diff --git a/third_party/blink/renderer/modules/ml/ml_context.h b/third_party/blink/renderer/modules/ml/ml_context.h index a8defc225a50d8..3215ff17f0edc6 100644 --- a/third_party/blink/renderer/modules/ml/ml_context.h +++ b/third_party/blink/renderer/modules/ml/ml_context.h @@ -51,6 +51,8 @@ class MODULES_EXPORT MLContext : public ScriptWrappable { // process, the runtime enable feature is used to disable the cross process // hardware acceleration by default. bool IsWebnnMojoContextEnabled() const; + // Returns true for GPU and NPU devices (false for CPU). + bool IsDedicatedHardwareDevice() const; // Create WebNN mojo context in server side and await the callback to resolve // the ml context. void CreateWebnnMojoContext(ScriptPromiseResolver* resolver); diff --git a/third_party/blink/renderer/modules/ml/ml_context_options.idl b/third_party/blink/renderer/modules/ml/ml_context_options.idl index a1e9b1cd4e9798..06c5c19adac719 100644 --- a/third_party/blink/renderer/modules/ml/ml_context_options.idl +++ b/third_party/blink/renderer/modules/ml/ml_context_options.idl @@ -11,11 +11,13 @@ enum MLDevicePreference { // Let the backend selects the most suitable device. "auto", + // The backend will use CPU to do model inference. + "cpu", // The backend will use GPU to do model inference. If some operator is not // supported by GPU, it will fall back to CPU. "gpu", - // The backend will use CPU to do model inference. - "cpu" + // The backend will use NPU to do model inference. + "npu", }; enum MLPowerPreference { diff --git a/third_party/blink/renderer/modules/ml/ml_model_loader.cc b/third_party/blink/renderer/modules/ml/ml_model_loader.cc index 40ce9d153cfa35..075dca52dc1601 100644 --- a/third_party/blink/renderer/modules/ml/ml_model_loader.cc +++ b/third_party/blink/renderer/modules/ml/ml_model_loader.cc @@ -126,6 +126,8 @@ DevicePreference ConvertBlinkDevicePreferenceToMojo( return DevicePreference::kCpu; case V8MLDevicePreference::Enum::kGpu: return DevicePreference::kGpu; + case V8MLDevicePreference::Enum::kNpu: + return DevicePreference::kNpu; } } diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc b/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc index b3324511a683db..b4f2108739cde0 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc +++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph.cc @@ -203,6 +203,9 @@ bool MLGraph::ValidateAndInitializeResourcesInfo( operators_queue.push_back(operand->Operator()); } + // An input MLOperand may be used by more than one MLOperators. This set + // ensures an input MLOperand won't be validated multiple times. + HeapHashSet> visited_input_operands; while (operators_queue.size() > 0) { // If the queue is not empty, dequeue an operator from the queue. const auto current_operator = operators_queue.TakeFirst(); @@ -220,6 +223,12 @@ bool MLGraph::ValidateAndInitializeResourcesInfo( } break; case MLOperand::OperandKind::kInput: + // If the operand has been validated, it doesn't need to be verified + // multiple times. 
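        // For example, if the same builder input operand feeds two different
        // operators, it is visited once per consuming operator during this
        // traversal; without the set, the second visit would wrongly report the
        // operand's name as a duplicate input.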
+ if (visited_input_operands.Contains(operand)) { + continue; + } + visited_input_operands.insert(operand); // If the operand is an input operand, validate whether its name is // unique. if (input_resources_info_.Contains(operand->Name())) { diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc b/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc index 7ccdd416cdb86e..7b3eb60b4e00ac 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc +++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc @@ -5,15 +5,52 @@ #include "third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h" #include +#include +#include #include "base/numerics/checked_math.h" #include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h" + +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_arg_min_max_options.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_concat_options_internal.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_integer_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_options_internal.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_elu_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_fill_sequence_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_options.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_graph.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_cell_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_hard_sigmoid_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_instance_normalization_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_leaky_relu_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_linear_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_cell_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_mean_variance_normalization_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_descriptor.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_type.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operator.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options_internal.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_reduce_options.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_resample_2d_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_slice_options_internal.h" +#include 
"third_party/blink/renderer/bindings/modules/v8/v8_ml_softplus_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options_internal.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_squeeze_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_transpose_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_matrix_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_float_parameter_options_internal.h" + #include "third_party/blink/renderer/core/dom/dom_exception.h" #include "third_party/blink/renderer/core/inspector/console_message.h" #include "third_party/blink/renderer/modules/ml/ml.h" @@ -24,6 +61,10 @@ #include "third_party/blink/renderer/modules/ml/webnn/mojo_graph.h" #include "third_party/blink/renderer/platform/bindings/exception_state.h" #include "third_party/blink/renderer/platform/heap/collection_support/heap_deque.h" +#include "third_party/blink/renderer/bindings/core/v8/v8_union_unsignedlong_unsignedlongsequence.h" + +// TODO::: +#pragma optimize("", off) #if BUILDFLAG(BUILD_WEBNN_WITH_XNNPACK) #include "third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.h" @@ -35,6 +76,24 @@ namespace { MLGraphBuilder::BackendForTesting* g_backend_for_testing = nullptr; +namespace MLOperandTypeMask { +// Use as bitmask, and so avoid enum class which inhibits boolean operations. +enum Enum : uint32_t { + kFloat32 = 1 << uint32_t(V8MLOperandType::Enum::kFloat32), + kFloat16 = 1 << uint32_t(V8MLOperandType::Enum::kFloat16), + kInt8 = 1 << uint32_t(V8MLOperandType::Enum::kInt8), + kUint8 = 1 << uint32_t(V8MLOperandType::Enum::kUint8), + kInt32 = 1 << uint32_t(V8MLOperandType::Enum::kInt32), + kUint32 = 1 << uint32_t(V8MLOperandType::Enum::kUint32), + kInt64 = 1 << uint32_t(V8MLOperandType::Enum::kInt64), + kUint64 = 1 << uint32_t(V8MLOperandType::Enum::kUint64), +}; +} // namespace MLOperandTypeMask + +bool IsAllowedType(V8MLOperandType::Enum operand_type, MLOperandTypeMask::Enum operand_type_mask) { + return (1 << uint32_t(operand_type)) & operand_type_mask; +} + bool IsFloatingPointType(V8MLOperandType::Enum operand_type) { switch (operand_type) { case V8MLOperandType::Enum::kFloat32: @@ -44,8 +103,158 @@ bool IsFloatingPointType(V8MLOperandType::Enum operand_type) { case V8MLOperandType::Enum::kUint32: case V8MLOperandType::Enum::kInt8: case V8MLOperandType::Enum::kUint8: + case V8MLOperandType::Enum::kInt64: + case V8MLOperandType::Enum::kUint64: + return false; + } +} + +bool IsBooleanType(V8MLOperandType::Enum operand_type) { + // Boolean types are unsigned 8-bit values. + switch (operand_type) { + case V8MLOperandType::Enum::kFloat32: + case V8MLOperandType::Enum::kFloat16: + case V8MLOperandType::Enum::kInt32: + case V8MLOperandType::Enum::kUint32: + case V8MLOperandType::Enum::kInt8: + case V8MLOperandType::Enum::kInt64: + case V8MLOperandType::Enum::kUint64: + return false; + case V8MLOperandType::Enum::kUint8: + return true; + } +} + +bool IsIndexType(V8MLOperandType::Enum operand_type) { + // Index types are integers, signed or unsigned. 
+ switch (operand_type) { + case V8MLOperandType::Enum::kFloat32: + case V8MLOperandType::Enum::kFloat16: + case V8MLOperandType::Enum::kInt8: + case V8MLOperandType::Enum::kUint8: + return false; + case V8MLOperandType::Enum::kInt32: + case V8MLOperandType::Enum::kUint32: + case V8MLOperandType::Enum::kInt64: + case V8MLOperandType::Enum::kUint64: + return true; + } +} + +bool ValidateValueMatchesExpected(uint32_t actual_value, + uint32_t expected_value, + const char* value_name, + ExceptionState& exception_state) { + if (actual_value != expected_value) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("The %s (%u) does not match the expected value (%u).", + value_name, actual_value, expected_value)); + return false; + } + return true; +} + +bool ValidateValueMatchesExpected(size_t actual_value, + size_t expected_value, + const char* value_name, + ExceptionState& exception_state) { + if (actual_value != expected_value) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("The %s (%zu) does not match the expected value (%zu).", + value_name, actual_value, expected_value)); + return false; + } + return true; +} + +bool ValidateAxis(uint32_t axis, + uint32_t dimension_count, + const char* operator_name, + ExceptionState& exception_state) { + if (axis >= dimension_count) + { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "The %s axis (%u) must be within the dimension count (%u).", + operator_name, + axis, dimension_count)); + return false; + } + return true; +} + +bool ValidateAxes(base::span axes, + uint32_t dimension_count, + const char* operator_name, + ExceptionState& exception_state) { + Vector seen_axes(dimension_count); + + for (auto axis : axes) + { + if (axis >= dimension_count) + { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "The %s axis (%u) must be less than the dimension count (%u).", + operator_name, + axis, dimension_count)); + return false; + } + if (seen_axes[axis]) + { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "Each %s axis (%u) must only occur once.", + operator_name, axis)); + return false; + } + seen_axes[axis] = true; + } + return true; +} + +// Generates a 32-bit mask, validating all axes fit within 32 dimensions. +bool ValidateAxesMask(base::span axes, + const char* operator_name, + ExceptionState& exception_state, + /*out*/ uint32_t& axes_mask) { + uint32_t current_mask = 0x00000000; + axes_mask = current_mask; + + // Use up to 32 bits. Although the WebNN spec does not give a limit, + // leaving it up to the implementations, 32 is a good practical limit, + // and other libraries often support less. 
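  // For example, axes {0, 2, 3} produce the mask 0b1101 (0xD); a repeated axis
  // is rejected because its bit is already set in current_mask.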
+ const uint32_t maximum_rank = 32; + + for (auto axis : axes) + { + if (axis >= maximum_rank) + { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "%s axis (%u) is beyond the maximum (%u).", + operator_name, + axis, maximum_rank)); + return false; + } + uint32_t single_axis_mask = 1 << axis; + if (single_axis_mask & current_mask) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("%s axis (%u) appears more than once.", operator_name, + axis)); return false; + } + current_mask |= single_axis_mask; } + axes_mask = current_mask; + return true; } bool ValidateClampOptions(const MLClampOptions* options, @@ -66,24 +275,112 @@ bool ValidateClampOptions(const MLClampOptions* options, return true; } +MLFloatParameterOptionsInternal* ConvertEluOptions( + const MLEluOptions* options, + ExceptionState& exception_state) { + + MLFloatParameterOptionsInternal* internal_options = MLFloatParameterOptionsInternal::Create(); + internal_options->setFirstParameter(options->alpha()); + return internal_options; +} + +MLFloatParameterOptionsInternal* ConvertHardSigmoidOptions( + const MLHardSigmoidOptions* options, + ExceptionState& exception_state) { + + MLFloatParameterOptionsInternal* internal_options = MLFloatParameterOptionsInternal::Create(); + internal_options->setFirstParameter(options->alpha()); + internal_options->setSecondParameter(options->beta()); + return internal_options; +} + +MLFloatParameterOptionsInternal* ConvertLeakyReluOptions( + const MLLeakyReluOptions* options, + ExceptionState& exception_state) { + + MLFloatParameterOptionsInternal* internal_options = MLFloatParameterOptionsInternal::Create(); + internal_options->setFirstParameter(options->alpha()); + return internal_options; +} + +MLFloatParameterOptionsInternal* ConvertLinearOptions( + const MLLinearOptions* options, + ExceptionState& exception_state) { + + MLFloatParameterOptionsInternal* internal_options = MLFloatParameterOptionsInternal::Create(); + internal_options->setFirstParameter(options->alpha()); + internal_options->setSecondParameter(options->beta()); + return internal_options; +} + +MLFloatParameterOptionsInternal* ConvertSoftplusOptions( + const MLSoftplusOptions* options, + ExceptionState& exception_state) { + + MLFloatParameterOptionsInternal* internal_options = MLFloatParameterOptionsInternal::Create(); + internal_options->setFirstParameter(options->steepness()); + return internal_options; +} + +// Computes the number of elements given the dimensions. +// Note this expects dimensions that are already known and validated +// and thus cannot overflow, rather than untrusted parameters like +// reshape's new shape before validation. +uint32_t ComputeElementCount(base::span dimensions) +{ + return std::accumulate(dimensions.begin(), dimensions.end(), 1u, std::multiplies{}); +} + +// Increases the rank to a minimum count by padding with leading ones. +Vector ExpandDimensions( + const base::span original_dimensions, + wtf_size_t minimum_rank) { + + wtf_size_t old_rank = static_cast(original_dimensions.size()); + wtf_size_t new_rank = std::max(minimum_rank, static_cast(old_rank)); + wtf_size_t leading_filler_count = new_rank - old_rank; + + Vector expanded_dimensions(new_rank, 1u); + std::copy(original_dimensions.begin(), original_dimensions.end(), + expanded_dimensions.begin() + leading_filler_count); + return expanded_dimensions; +} + // Broadcast the input shapes and return the output shape. 
// If bidirectional is true, its behavior follows the numpy-broadcasting-rule: // https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules. // Otherwise, it unidirectionally broadcasts the lhs to the rhs. +// The ignorable tail count is useful for cases like MatMul, where you want +// to ignore the trailing tail of dimensions and only broadcast the leading +// ones, because the trailing part (returned as 0's) will be filled in later. absl::optional> BroadcastShapes( - const Vector& dims_lhs, - const Vector& dims_rhs, - bool bidirectional = true) { + base::span dims_lhs, + base::span dims_rhs, + bool bidirectional = true, + wtf_size_t ignorable_tail_count = 0 +) { // If bidirectional is true, the rank of the output shape is the maximum // rank of the input shapes. Otherwise it is as the same as the rhs' rank. - auto rank_lhs = dims_lhs.size(), rank_rhs = dims_rhs.size(); + auto rank_lhs = static_cast(dims_lhs.size()); + auto rank_rhs = static_cast(dims_rhs.size()); auto rank_output = bidirectional ? std::max(rank_lhs, rank_rhs) : rank_rhs; Vector dims_output(rank_output); - for (wtf_size_t i = 0; i < rank_output; ++i) { + + // Note the loop effectively works backwards from the end of the dimensions + // array (the counter is forward, but accesses are relative the end). + for (wtf_size_t i = ignorable_tail_count; i < rank_output; ++i) { + auto dim_lhs = i < rank_lhs ? dims_lhs[rank_lhs - i - 1] : 1; +#if 0 // TODO:::DELETE - Broadcasting needs to work correctly with zero dimensions. + // Ultimately this should handle empty tensors in the lower layer via nop. + // A DCHECK is not appropriate here. DCHECK_GT(dim_lhs, uint32_t(0)); +#endif auto dim_rhs = i < rank_rhs ? dims_rhs[rank_rhs - i - 1] : 1; +#if 0 // TODO:::DELETE - Broadcasting needs to work correctly with zero dimensions. DCHECK_GT(dim_rhs, uint32_t(0)); +#endif + // If bidirectional is true, two dimensions are compatible when they are // equal, or one of them is 1. Otherwise, two dimensions are compatible // when they are equal, or the lhs dimension is 1. @@ -94,6 +391,7 @@ absl::optional> BroadcastShapes( } else if (dim_lhs != dim_rhs && dim_lhs != 1) { return absl::nullopt; } + // If bidirectional is true, for each dimension of the output tensor, its // size is the maximum size along that dimension of the input shapes. // Otherwise, its size is the same as the rhs. 
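    // e.g. bidirectional broadcasting of [3, 1, 5] with [4, 5] yields [3, 4, 5];
    // with ignorable_tail_count = 2 (the matmul case), [2, 1, 3, 4] with [5, 4, 6]
    // yields [2, 5, 0, 0], leaving the caller to fill in the trailing two values.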
@@ -103,10 +401,75 @@ absl::optional> BroadcastShapes( return dims_output; } -MLOperand* BuildElementWiseBinary(MLGraphBuilder* builder, +bool ValidateUnidirectionalBroadcastability( + base::span validating_dimensions, + base::span target_dimensions, + const char* operator_name, + const char* validating_tensor_name, + const char* target_tensor_name, + ExceptionState& exception_state) { + absl::optional> broadcasted_dimensions = + BroadcastShapes(validating_dimensions, target_dimensions, false); + if (!broadcasted_dimensions) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("The %s tensor %s is not a broadcastable shape to %s.", + operator_name, validating_tensor_name, + target_tensor_name)); + return false; + } + return true; +} + +MLOperand* BuildUnaryOperator(MLGraphBuilder* builder, + MLOperator::OperatorKind kind, + const MLOperand* input, + ExceptionState& exception_state) { + String error_message; + auto* ml_operator = MakeGarbageCollected(builder, kind); + + Vector output_dimensions = input->Dimensions(); + auto* output = MLOperand::ValidateAndCreateOutput( + builder, input->Type(), std::move(output_dimensions), ml_operator, /*out*/ error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return nullptr; + } + + ml_operator->Connect({input}, {output}); + return output; +} + +MLOperand* BuildUnaryOperator( + MLGraphBuilder* builder, + MLOperator::OperatorKind kind, + const MLOperand* input, + Vector output_dimensions, + V8MLOperandType::Enum output_data_type, + const bindings::DictionaryBase* options, + ExceptionState& exception_state) { + String error_message; + auto* ml_operator = MakeGarbageCollected(builder, kind, options); + + auto* output = MLOperand::ValidateAndCreateOutput( + builder, output_data_type, std::move(output_dimensions), ml_operator, + /*out*/ error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return nullptr; + } + + ml_operator->Connect({input}, {output}); + return output; +} + +MLOperand* BuildElementwiseBinary(MLGraphBuilder* builder, MLOperator::OperatorKind kind, const MLOperand* a, const MLOperand* b, + V8MLOperandType::Enum output_data_type, ExceptionState& exception_state) { if (a->Type() != b->Type()) { exception_state.ThrowDOMException(DOMExceptionCode::kDataError, @@ -121,10 +484,11 @@ MLOperand* BuildElementWiseBinary(MLGraphBuilder* builder, "The input shapes are not broadcastable."); return nullptr; } + auto* binary = MakeGarbageCollected(builder, kind); String error_message; auto* output = MLOperand::ValidateAndCreateOutput( - builder, a->Type(), dims_output.value(), binary, error_message); + builder, output_data_type, std::move(dims_output.value()), binary, error_message); if (!output) { exception_state.ThrowDOMException(DOMExceptionCode::kDataError, error_message); @@ -134,6 +498,101 @@ MLOperand* BuildElementWiseBinary(MLGraphBuilder* builder, return output; } +MLOperand* BuildArgMinMax(MLGraphBuilder* graph_builder, + MLOperator::OperatorKind operator_kind, + const char* operator_name, + const MLOperand* input, + const MLArgMinMaxOptions* options, + ExceptionState& exception_state) { + // Validate axis. + uint32_t axis = options->axis(); + auto& input_dimensions = input->Dimensions(); + if (!ValidateAxis(axis, + input_dimensions.size(), + operator_name, + exception_state)) { + return nullptr; + } + + // Determine output size, eliminating the active axis or keeping it with size 1. 
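  // e.g. an input of shape [2, 3, 4] with axis = 1 becomes [2, 1, 4] when
  // keepDimensions is true, or [2, 4] when it is false.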
+ Vector output_dimensions = input_dimensions; + if (options->keepDimensions()) + { + output_dimensions[axis] = 1; + } + else + { + output_dimensions.EraseAt(axis); + } + + return BuildUnaryOperator(graph_builder, operator_kind, input, + output_dimensions, V8MLOperandType::Enum::kInt64, options, + exception_state); +} + +MLOperand* BuildReductionOperator(MLGraphBuilder* graph_builder, + MLOperator::OperatorKind operator_kind, + const char* operator_name, + const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + const auto& input_dimensions = input->Dimensions(); + const wtf_size_t input_rank = input_dimensions.size(); + + Vector axes; + uint32_t axes_mask = 0xFFFFFFFF; // Reduce all axes by default if none are passed. + + // Verify the axes are within the input rank and not duplicated. + if (options->hasAxes()) + { + axes = options->axes(); + if (!ValidateAxes(axes, input_rank, operator_name, exception_state)) { + return nullptr; + } + if (!ValidateAxesMask(options->axes(), + operator_name, + exception_state, + /*out*/ axes_mask)) { + return nullptr; + } + } + else // Reduce over all dimensions if no axes are given. + { + axes.resize(input_rank); + std::iota(axes.begin(), axes.end(), 0u); + // axes_mask already 0xFFFFFFFF. + } + + // Set each reduced dimension to 1, + // or erase it entirely if MLReduceOptions::keepDimensions is false. + Vector output_dimensions = input_dimensions; + wtf_size_t output_rank = input_rank; + bool keep_dimensions = options->keepDimensions(); + for (wtf_size_t i = 0; i < output_rank; /*increment in loop*/) + { + wtf_size_t advance_count = 1; + if (axes_mask & (1 << i)) { + if (keep_dimensions) { + output_dimensions[i] = 1u; // Reduce dimension. + } + else { + output_dimensions.EraseAt(i); // Remove reduced dimension. + advance_count = 0; // Stay at the current index. + --output_rank; + } + } + i += advance_count; + } + + // Pass the normalized options onward, simplifying the lower level's job. + MLReduceOptions* normalized_options = MLReduceOptions::Create(); + normalized_options->setAxes(axes); + + return BuildUnaryOperator(graph_builder, operator_kind, input, + output_dimensions, input->Type(), normalized_options, + exception_state); +} + struct PaddingSizes { uint32_t begin; uint32_t end; @@ -192,7 +651,17 @@ absl::optional CalculateConv2dOutputSize( const uint32_t ending_padding, const uint32_t stride, const uint32_t dilation, + const uint32_t output_padding, + const bool is_backward_direction, String& error_message) { + // Forward direction calculation: + // output size = 1 + (input size - (filter size - 1) * dilation - 1 + + // beginning padding + ending padding) / stride + // + // Backward direction calculation: + // output size = (input size - 1) * stride + (filter size - 1) * dilation + + // 1 - beginning padding - ending padding + output padding + // + // Calculate the dilated filter sizes. auto checked_effective_filter_size = (base::MakeCheckedNum(filter_size) - 1) * dilation + 1; @@ -206,11 +675,19 @@ absl::optional CalculateConv2dOutputSize( // https://en.wikipedia.org/wiki/Double-precision_floating-point_format#Precision_limitations_on_integer_values // The max value of checked_output_size should be 3 * UINT_MAX + 1, // which is smaller than the max safe integer value for double type. 
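  // Worked example: input size 5, filter size 3, stride 2, dilation 1 and no
  // padding give a forward output size of 1 + (5 - 3) / 2 = 2; the backward
  // (transposed) direction with output_padding 1 gives (5 - 1) * 2 + 3 + 1 = 12,
  // where 3 is the dilated (effective) filter size.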
- auto checked_output_size = - (base::MakeCheckedNum(input_size) - - checked_effective_filter_size + beginning_padding + ending_padding) / - stride + - 1; + base::CheckedNumeric checked_output_size; + if (is_backward_direction) { + checked_output_size = + (base::MakeCheckedNum(input_size) - 1) * stride + + checked_effective_filter_size - beginning_padding - + ending_padding + output_padding; + } else { + checked_output_size = + (base::MakeCheckedNum(input_size) - + checked_effective_filter_size + beginning_padding + ending_padding) / + stride + + 1; // Note output_padding not used. + } if (checked_output_size.ValueOrDie() < 0) { error_message = "The input size is too small to fill the window."; @@ -240,26 +717,28 @@ absl::optional ValidateAndCalculateConv2dOutputSizes( const uint32_t input_width, const uint32_t filter_height, const uint32_t filter_width, - const Vector& padding, - const Vector& strides, - const Vector& dilations, + const base::span padding, // 4 elements + const base::span strides, // 2 elements + const base::span dilations, // 2 elements + const base::span output_padding, // 2 elements const V8MLAutoPad auto_pad, + const bool is_backward_direction, ExceptionState& exception_state) { // Validate padding and get its values. - if (padding.size() != 4) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - "The length of padding should be 4."); + + if (!ValidateValueMatchesExpected(padding.size(), size_t(4), + "padding length", exception_state)) { return absl::nullopt; } + uint32_t padding_beginning_height = padding[0]; uint32_t padding_ending_height = padding[1]; uint32_t padding_beginning_width = padding[2]; uint32_t padding_ending_width = padding[3]; // Validate strides and get its values. - if (strides.size() != 2) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - "The length of strides should be 2."); + if (!ValidateValueMatchesExpected(strides.size(), size_t(2), + "strides length", exception_state)) { return absl::nullopt; } if (std::any_of(strides.begin(), strides.end(), @@ -271,10 +750,9 @@ absl::optional ValidateAndCalculateConv2dOutputSizes( const uint32_t stride_height = strides[0]; const uint32_t stride_width = strides[1]; - // Validate dilations and get its values. - if (dilations.size() != 2) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - "The length of dilations should be 2."); + // Validate dilations, and get its values. + if (!ValidateValueMatchesExpected(dilations.size(), size_t(2), + "dilations length", exception_state)) { return absl::nullopt; } if (std::any_of(dilations.begin(), dilations.end(), @@ -291,6 +769,7 @@ absl::optional ValidateAndCalculateConv2dOutputSizes( // options.padding array are ignored and the explicit padding values need to // be calculated. if (auto_pad != V8MLAutoPad::Enum::kExplicit) { + // Compute vertical padding before and after. auto padding_sizes_height = MLGraphBuilder::CalculatePaddingForAutoPad( auto_pad.AsEnum(), input_height, filter_height, stride_height, dilation_height); @@ -303,6 +782,8 @@ absl::optional ValidateAndCalculateConv2dOutputSizes( } padding_beginning_height = padding_sizes_height.value().begin; padding_ending_height = padding_sizes_height.value().end; + + // Compute horizontal padding before and after. 
auto padding_sizes_width = MLGraphBuilder::CalculatePaddingForAutoPad( auto_pad.AsEnum(), input_width, filter_width, stride_width, dilation_width); @@ -317,10 +798,23 @@ absl::optional ValidateAndCalculateConv2dOutputSizes( padding_ending_width = padding_sizes_width.value().end; } + uint32_t output_padding_height = 0; + uint32_t output_padding_width = 0; + if (!output_padding.empty()) { + if (!ValidateValueMatchesExpected(output_padding.size(), size_t(2), + "output padding length", + exception_state)) { + return absl::nullopt; + } + output_padding_height = output_padding[0]; + output_padding_width = output_padding[1]; + } + String error_message; auto float_output_height = CalculateConv2dOutputSize( input_height, filter_height, padding_beginning_height, - padding_ending_height, stride_height, dilation_height, error_message); + padding_ending_height, stride_height, dilation_height, + output_padding_height, is_backward_direction, error_message); if (!float_output_height) { exception_state.ThrowDOMException( DOMExceptionCode::kDataError, @@ -330,7 +824,8 @@ absl::optional ValidateAndCalculateConv2dOutputSizes( auto float_output_width = CalculateConv2dOutputSize( input_width, filter_width, padding_beginning_width, padding_ending_width, - stride_width, dilation_width, error_message); + stride_width, dilation_width, output_padding_width, is_backward_direction, + error_message); if (!float_output_width) { exception_state.ThrowDOMException( DOMExceptionCode::kDataError, @@ -412,7 +907,9 @@ MLOperand* BuildPool2d(MLGraphBuilder* builder, // If strides is not present, the values are assumed to be [1,1]. options->getStridesOr({1, 1}), // If dilations is not present, the values are assumed to be [1, 1]. - options->getDilationsOr({1, 1}), options->autoPad(), exception_state); + options->getDilationsOr({1, 1}), {}, options->autoPad(), false, + exception_state); + if (!output_sizes) { return nullptr; } @@ -509,43 +1006,272 @@ MLOperand* BuildPool2d(MLGraphBuilder* builder, return output; } -} // namespace +MLOperand* BuildConv2d(MLGraphBuilder* builder, + MLOperator::OperatorKind operator_kind, + const MLOperand* input, + const MLOperand* filter, + const MLOperand* input_zero_point, + const MLOperand* filter_zero_point, + const MLConvOptionsInternal* options, + ExceptionState& exception_state) { + bool is_backward_direction = + (operator_kind == MLOperator::OperatorKind::kConvTranspose2d); -// static -MLGraphBuilder* MLGraphBuilder::Create(MLContext* context) { - return MakeGarbageCollected(context); -} + // Validate input operand and set its sizes. + const auto input_shape = input->Dimensions(); + if (!ValidateValueMatchesExpected(input_shape.size(), 4u, + "input tensor rank", exception_state)) { + return nullptr; + } -MLGraphBuilder::MLGraphBuilder(MLContext* context) : ml_context_(context) {} + // The input layout option specifies the layout format of the input tensor. 
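+  // For example, an "nchw" input of shape [1, 3, 224, 224] has 1 batch,
+  // 3 channels, and a height and width of 224; the same tensor in "nhwc"
+  // form would have shape [1, 224, 224, 3].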
+ uint32_t input_batches, input_channels, input_height, input_width; + switch (options->inputLayout().AsEnum()) { + case V8MLInputOperandLayout::Enum::kNchw: + // "nchw": [batches, input_channels, height, width] + input_batches = input_shape[0]; + input_channels = input_shape[1]; + input_height = input_shape[2]; + input_width = input_shape[3]; + break; + case V8MLInputOperandLayout::Enum::kNhwc: + // "nhwc": [batches, height, width, input_channels] + input_batches = input_shape[0]; + input_height = input_shape[1]; + input_width = input_shape[2]; + input_channels = input_shape[3]; + break; + } -MLGraphBuilder::~MLGraphBuilder() = default; + // Validate filter operand and set its sizes. + if (filter->Type() != input->Type()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The filter type doesn't match the input type."); + return nullptr; + } -void MLGraphBuilder::Trace(Visitor* visitor) const { - visitor->Trace(ml_context_); - ScriptWrappable::Trace(visitor); -} + const auto filter_shape = filter->Dimensions(); + if (!ValidateValueMatchesExpected(filter_shape.size(), 4u, + "filter tensor rank", exception_state)) { + return nullptr; + } -MLContext* MLGraphBuilder::GetContext() const { - return ml_context_; -} + // Remap filter_shape given V8MLConvFilterOperandLayoutInternal::Enum::*. + + // clang-format off + static_assert(uint32_t(V8MLConvFilterOperandLayoutInternal::kEnumSize) == 6); + static_assert(uint32_t(V8MLConvFilterOperandLayoutInternal::Enum::kOihw) == 0); + static_assert(uint32_t(V8MLConvFilterOperandLayoutInternal::Enum::kIohw) == 1); + static_assert(uint32_t(V8MLConvFilterOperandLayoutInternal::Enum::kHwoi) == 2); + static_assert(uint32_t(V8MLConvFilterOperandLayoutInternal::Enum::kHwio) == 3); + static_assert(uint32_t(V8MLConvFilterOperandLayoutInternal::Enum::kOhwi) == 4); + static_assert(uint32_t(V8MLConvFilterOperandLayoutInternal::Enum::kIhwo) == 5); + + constexpr uint8_t filter_indices_mapping[][4] = { + // O I H W + /* kOihw */ {0, 1, 2, 3}, + /* kIohw */ {1, 0, 2, 3}, + /* kHwoi */ {2, 3, 0, 1}, + /* kHwio */ {3, 2, 0, 1}, + /* kOhwi */ {0, 3, 1, 2}, + /* kIhwo */ {3, 0, 1, 2}, + }; + // clang-format on + + auto filter_layout_index = static_cast(options->filterLayout().AsEnum()); + auto* active_filter_indices_mapping = filter_indices_mapping[filter_layout_index]; + uint32_t filter_output_channels = filter_shape[active_filter_indices_mapping[0]]; + uint32_t filter_input_channels = filter_shape[active_filter_indices_mapping[1]]; + uint32_t filter_height = filter_shape[active_filter_indices_mapping[2]]; + uint32_t filter_width = filter_shape[active_filter_indices_mapping[3]]; + + // Validate group count, and compute the output channel count. 
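+  // For example, a forward conv2d with 8 input channels and groups = 2
+  // expects a filter with 8 / 2 = 4 input channels, and its output channel
+  // count equals the filter's output channel count; for convTranspose2d the
+  // output channel count is the filter's output channel count * groups.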
+ uint32_t group_count = options->groups(); + if (group_count == 0) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + "The groups should be greater than 0."); + return nullptr; + } -// static -absl::optional -MLGraphBuilder::CalculatePaddingForAutoPad(V8MLAutoPad::Enum auto_pad, - const uint32_t input_size, - const uint32_t filter_size, - const uint32_t stride, - const uint32_t dilation) { - auto checked_output_size = - (base::MakeCheckedNum(input_size) + stride - 1) / stride; - auto checked_dilated_filter_size = - (base::MakeCheckedNum(filter_size) - 1) * dilation + 1; - auto checked_needed_input_size = - (checked_output_size - 1) * stride + checked_dilated_filter_size; - if (!checked_needed_input_size.IsValid()) { - return absl::nullopt; + uint32_t output_channels; + if (is_backward_direction) { + if (filter_output_channels > + std::numeric_limits::max() / group_count) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "The group count (%u) times filter output channels (%u) is too " + "large and overflowed.", + group_count, filter_output_channels)); + return nullptr; + } + output_channels = filter_output_channels * group_count; + } else // forward direction + { + if (input_channels % group_count != 0 || + filter_input_channels != input_channels / group_count) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The groups must evenly divide the input " + "channels to filter input channels."); + return nullptr; + } + output_channels = filter_output_channels; } - auto checked_total_padding = + + // Validate bias operand if it is present. + if (options->hasBias()) { + const auto bias_shape = options->bias()->Dimensions(); + if (bias_shape.size() != 1) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + "The bias should be a 1-D tensor."); + return nullptr; + } + if (!ValidateValueMatchesExpected(bias_shape[0], output_channels, + "bias shape", exception_state)) { + return nullptr; + } + if (options->bias()->Type() != input->Type()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The bias type doesn't match input type."); + return nullptr; + } + } + + // Validate and compute output sizes. + absl::optional output_sizes; + if (options->hasOutputSizes()) { + if (!ValidateValueMatchesExpected(options->outputSizes().size(), 2u, + "outputSizes length", exception_state)) { + return nullptr; + } + + const auto& explicit_output_sizes = options->outputSizes(); + output_sizes = FloatSize2D{double(explicit_output_sizes[0]), + double(explicit_output_sizes[1])}; + } else { + if (options->hasPadding() && + !ValidateValueMatchesExpected(options->padding().size(), 4u, + "padding length", exception_state)) { + return nullptr; + } + if (options->hasStrides() && + !ValidateValueMatchesExpected(options->strides().size(), 2u, + "strides length", exception_state)) { + return nullptr; + } + if (options->hasDilations() && + !ValidateValueMatchesExpected(options->dilations().size(), 2u, + "dilations length", exception_state)) { + return nullptr; + } + + output_sizes = ValidateAndCalculateConv2dOutputSizes( + input_height, input_width, filter_height, filter_width, + // If padding is not present, the values are assumed to be [0,0,0,0]. + options->getPaddingOr({0, 0, 0, 0}), + // If strides is not present, the values are assumed to be [1,1]. + options->getStridesOr({1, 1}), + // If dilations is not present, the values are assumed to be [1, 1]. 
+ options->getDilationsOr({1, 1}), options->getOutputPaddingOr({0, 0}), + options->autoPad(), is_backward_direction, exception_state); + } + if (!output_sizes) { + return nullptr; + } + + const uint32_t output_height = + base::ClampFloor(output_sizes.value().height); + const uint32_t output_width = + base::ClampFloor(output_sizes.value().width); + + // The input layout option specifies the layout format of the output tensor. + Vector output_shape; + switch (options->inputLayout().AsEnum()) { + case V8MLInputOperandLayout::Enum::kNchw: + // "nchw": [batches, output_channels, height, width] + output_shape = {input_batches, output_channels, output_height, + output_width}; + break; + case V8MLInputOperandLayout::Enum::kNhwc: + // "nhwc": [batches, height, width, output_channels] + output_shape = {input_batches, output_height, output_width, + output_channels}; + break; + } + + // Create conv2d operator and its output operand. Connect the conv2d + // operator to its input and output operands. + auto* conv2d = MakeGarbageCollected( + builder, operator_kind, options); + + // TODO::: Verify input_zero_point and filter_zero_point shape. + + HeapVector> inputs = {input, filter}; + if (operator_kind == MLOperator::OperatorKind::kConv2dInteger) { + inputs.push_back(input_zero_point); + inputs.push_back(filter_zero_point); + } + if (options->hasBias()) { + inputs.push_back(options->bias()); + } + + V8MLOperandType::Enum output_data_type = + (operator_kind == MLOperator::OperatorKind::kConv2dInteger) + ? V8MLOperandType::Enum::kInt32 + : input->Type(); + + String error_message; + auto* output = MLOperand::ValidateAndCreateOutput( + builder, output_data_type, std::move(output_shape), conv2d, error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return nullptr; + } + conv2d->Connect(std::move(inputs), {output}); + return output; +} + +} // namespace + +// static +MLGraphBuilder* MLGraphBuilder::Create(MLContext* context) { + return MakeGarbageCollected(context); +} + +MLGraphBuilder::MLGraphBuilder(MLContext* context) : ml_context_(context) {} + +MLGraphBuilder::~MLGraphBuilder() = default; + +void MLGraphBuilder::Trace(Visitor* visitor) const { + visitor->Trace(ml_context_); + ScriptWrappable::Trace(visitor); +} + +MLContext* MLGraphBuilder::GetContext() const { + return ml_context_; +} + +// static +absl::optional +MLGraphBuilder::CalculatePaddingForAutoPad(V8MLAutoPad::Enum auto_pad, + const uint32_t input_size, + const uint32_t filter_size, + const uint32_t stride, + const uint32_t dilation) { + auto checked_output_size = + (base::MakeCheckedNum(input_size) + stride - 1) / stride; + auto checked_dilated_filter_size = + (base::MakeCheckedNum(filter_size) - 1) * dilation + 1; + auto checked_needed_input_size = + (checked_output_size - 1) * stride + checked_dilated_filter_size; + if (!checked_needed_input_size.IsValid()) { + return absl::nullopt; + } + auto checked_total_padding = checked_needed_input_size.ValueOrDie() > input_size ? checked_needed_input_size - input_size : base::MakeCheckedNum(0); @@ -641,166 +1367,125 @@ MLOperand* MLGraphBuilder::conv2d(const MLOperand* input, const MLOperand* filter, const MLConv2dOptions* options, ExceptionState& exception_state) { - // Validate input operand and set its sizes. 
- const auto input_shape = input->Dimensions(); - if (input_shape.size() != 4) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - "The input should be a 4-D tensor."); - return nullptr; + // Unify the two convolutions which only differ in direction + // (forward vs backward) with a common internal representation + // that the backends more easily read. + MLConvOptionsInternal* internal_options = MLConvOptionsInternal::Create(); + internal_options->setAutoPad(options->autoPad()); + internal_options->setGroups(options->groups()); + internal_options->setInputLayout(options->inputLayout()); + + if (options->hasPadding()) + { + internal_options->setPadding(options->padding()); } - // The input layout option specifies the layout format of the input tensor. - uint32_t input_batches, input_channels, input_height, input_width; - switch (options->inputLayout().AsEnum()) { - case V8MLInputOperandLayout::Enum::kNchw: - // "nchw": [batches, input_channels, height, width] - input_batches = input_shape[0]; - input_channels = input_shape[1]; - input_height = input_shape[2]; - input_width = input_shape[3]; - break; - case V8MLInputOperandLayout::Enum::kNhwc: - // "nhwc": [batches, height, width, input_channels] - input_batches = input_shape[0]; - input_height = input_shape[1]; - input_width = input_shape[2]; - input_channels = input_shape[3]; - break; + if (options->hasStrides()) + { + internal_options->setStrides(options->strides()); } - - // Validate filter operand and set its sizes. - if (filter->Type() != input->Type()) { - exception_state.ThrowDOMException( - DOMExceptionCode::kDataError, - "The filter type doesn't match the input type."); - return nullptr; + if (options->hasDilations()) + { + internal_options->setDilations(options->dilations()); } - const auto filter_shape = filter->Dimensions(); - if (filter_shape.size() != 4) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - "The filter should be a 4-D tensor."); - return nullptr; + if (options->hasBias()) + { + internal_options->setBias(options->bias()); } - // The filter layout specifies the filter layout format. 
- uint32_t filter_height, filter_width, output_channels, filter_input_channels; + if (options->hasActivation()) + { + internal_options->setActivation(options->activation()); + } + + V8MLConvFilterOperandLayoutInternal::Enum filter_layout; + static_assert(V8MLConv2dFilterOperandLayout::kEnumSize == 4); switch (options->filterLayout().AsEnum()) { + default: + case V8MLConv2dFilterOperandLayout::Enum::kOihw: + filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kOihw; + break; case V8MLConv2dFilterOperandLayout::Enum::kHwio: - // "hwio": [height, width, input_channels/groups, output_channels] - filter_height = filter_shape[0]; - filter_width = filter_shape[1]; - filter_input_channels = filter_shape[2]; - output_channels = filter_shape[3]; + filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kHwio; break; case V8MLConv2dFilterOperandLayout::Enum::kOhwi: - // "ohwi": [output_channels, height, width, input_channels/groups] - output_channels = filter_shape[0]; - filter_height = filter_shape[1]; - filter_width = filter_shape[2]; - filter_input_channels = filter_shape[3]; + filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kOhwi; break; case V8MLConv2dFilterOperandLayout::Enum::kIhwo: - // "ihwo": [input_channels/groups, height, width, output_channels] - filter_input_channels = filter_shape[0]; - filter_height = filter_shape[1]; - filter_width = filter_shape[2]; - output_channels = filter_shape[3]; - break; - case V8MLConv2dFilterOperandLayout::Enum::kOihw: - // "oihw": [output_channels, input_channels/groups, height, width] - output_channels = filter_shape[0]; - filter_input_channels = filter_shape[1]; - filter_height = filter_shape[2]; - filter_width = filter_shape[3]; + filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kIhwo; break; } - // Validate bias operand if it is present. - if (options->hasBias()) { - const auto bias_shape = options->bias()->Dimensions(); - if (bias_shape.size() != 1) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - "The bias should be a 1-D tensor."); - return nullptr; - } - if (bias_shape[0] != output_channels) { - exception_state.ThrowDOMException( - DOMExceptionCode::kDataError, - String::Format("The bias shape should be [%u].", output_channels)); - return nullptr; - } - if (options->bias()->Type() != input->Type()) { - exception_state.ThrowDOMException( - DOMExceptionCode::kDataError, - "The bias type doesn't match input type."); - return nullptr; - } + internal_options->setFilterLayout(filter_layout); + + return BuildConv2d(this, MLOperator::OperatorKind::kConv2d, input, filter, + nullptr, nullptr, internal_options, exception_state); +} + +MLOperand* MLGraphBuilder::convTranspose2d(const MLOperand* input, + const MLOperand* filter, + const MLConvTranspose2dOptions* options, + ExceptionState& exception_state) { + // Unify the two convolutions which only differ in direction + // (forward vs backward) with a common internal representation + // that the backends more easily read. + MLConvOptionsInternal* internal_options = MLConvOptionsInternal::Create(); + internal_options->setAutoPad(options->autoPad()); + internal_options->setGroups(options->groups()); + internal_options->setInputLayout(options->inputLayout()); + + if (options->hasPadding()) + { + internal_options->setPadding(options->padding()); } - // Validate groups. 
- if (options->groups() == 0) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - "The groups should be greater than 0."); - return nullptr; + if (options->hasStrides()) + { + internal_options->setStrides(options->strides()); } - if (input_channels % options->groups() != 0 || - filter_input_channels != input_channels / options->groups()) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - "The groups must evenly divide the input " - "channels to filter input channels."); - return nullptr; + if (options->hasDilations()) + { + internal_options->setDilations(options->dilations()); } - - const auto output_sizes = ValidateAndCalculateConv2dOutputSizes( - input_height, input_width, filter_height, filter_width, - // If padding is not present, the values are assumed to be [0,0,0,0]. - options->getPaddingOr({0, 0, 0, 0}), - // If strides is not present, the values are assumed to be [1,1]. - options->getStridesOr({1, 1}), - // If dilations is not present, the values are assumed to be [1, 1]. - options->getDilationsOr({1, 1}), options->autoPad(), exception_state); - if (!output_sizes) { - return nullptr; + if (options->hasOutputPadding()) + { + internal_options->setOutputPadding(options->outputPadding()); } - const uint32_t output_height = - base::ClampFloor(output_sizes.value().height); - const uint32_t output_width = - base::ClampFloor(output_sizes.value().width); - // The input layout option specifies the layout format of the output tensor. - Vector output_shape; - switch (options->inputLayout().AsEnum()) { - case V8MLInputOperandLayout::Enum::kNchw: - // "nchw": [batches, output_channels, height, width] - output_shape = {input_batches, output_channels, output_height, - output_width}; - break; - case V8MLInputOperandLayout::Enum::kNhwc: - // "nhwc": [batches, height, width, output_channels] - output_shape = {input_batches, output_height, output_width, - output_channels}; - break; + if (options->hasOutputSizes()) + { + internal_options->setOutputSizes(options->outputSizes()); } - // Create conv2d operator and its output operand. Connect the conv2d - // operator to its input and output operands. 
- auto* conv2d = MakeGarbageCollected( - this, MLOperator::OperatorKind::kConv2d, options); - HeapVector> inputs = {input, filter}; - if (options->hasBias()) { - inputs.push_back(options->bias()); + if (options->hasBias()) + { + internal_options->setBias(options->bias()); } - String error_message; - auto* output = MLOperand::ValidateAndCreateOutput( - this, input->Type(), std::move(output_shape), conv2d, error_message); - if (!output) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - error_message); - return nullptr; + if (options->hasActivation()) + { + internal_options->setActivation(options->activation()); } - conv2d->Connect(std::move(inputs), {output}); - return output; + + V8MLConvFilterOperandLayoutInternal::Enum filter_layout; + static_assert(V8MLConvTranspose2dFilterOperandLayout::kEnumSize == 3); + switch (options->filterLayout().AsEnum()) { + default: + case V8MLConvTranspose2dFilterOperandLayout::Enum::kIohw: + filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kIohw; + break; + case V8MLConvTranspose2dFilterOperandLayout::Enum::kHwoi: + filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kHwoi; + break; + case V8MLConvTranspose2dFilterOperandLayout::Enum::kOhwi: + filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kOhwi; + break; + } + internal_options->setFilterLayout(filter_layout); + + return BuildConv2d(this, MLOperator::OperatorKind::kConvTranspose2d, input, + filter, nullptr, nullptr, internal_options, + exception_state); } #define BUILD_ELEMENTWISE_BINARY_OP(op, op_kind) \ MLOperand* MLGraphBuilder::op(const MLOperand* a, const MLOperand* b, \ ExceptionState& exception_state) { \ - return BuildElementWiseBinary(this, MLOperator::OperatorKind::op_kind, a, \ - b, exception_state); \ + return BuildElementwiseBinary(this, MLOperator::OperatorKind::op_kind, a, \ + b, a->Type(), exception_state); \ } BUILD_ELEMENTWISE_BINARY_OP(add, kAdd) @@ -855,7 +1540,8 @@ MLOperand* MLGraphBuilder::gemm(const MLOperand* a, shape_a[1], options->aTranspose() ? "transposed " : "", shape_b[0], options->bTranspose() ? "transposed " : "")); return nullptr; - }; + } + // The output is 2-D tensor of shape [M, N]. Vector output_shape = {shape_a[0], shape_b[1]}; // The third input tensor c is either a scalar, or of the shape that is @@ -944,6 +1630,13 @@ MLOperand* MLGraphBuilder::averagePool2d(const MLOperand* input, options, exception_state); } +MLOperand* MLGraphBuilder::l2Pool2d(const MLOperand* input, + const MLPool2dOptions* options, + ExceptionState& exception_state) { + return BuildPool2d(this, MLOperator::OperatorKind::kL2Pool2d, input, + options, exception_state); +} + MLOperand* MLGraphBuilder::maxPool2d(const MLOperand* input, const MLPool2dOptions* options, ExceptionState& exception_state) { @@ -980,7 +1673,7 @@ MLOperand* MLGraphBuilder::reshape(const MLOperand* input, const Vector& new_shape, ExceptionState& exception_state) { bool has_minus1 = false; - wtf_size_t minus1_dim_index; + wtf_size_t minus1_dim_index = 0; base::CheckedNumeric checked_newshape_number_of_elements = 1; Vector output_shape; if (new_shape.size() == 0) { @@ -993,6 +1686,10 @@ MLOperand* MLGraphBuilder::reshape(const MLOperand* input, // component of new shape can be the special value of -1. for (wtf_size_t i = 0; i < new_shape.size(); ++i) { auto d = new_shape[i]; + // TODO:::DELETE this special -1 behavior per the pending issue + // https://github.com/webmachinelearning/webnn/issues/388 + // to remove magic values like null or -1. 
+#if 1 // But keep it for now because Wanming may need it for the WebNN EP. if (d < -1 || d == 0) { exception_state.ThrowDOMException( DOMExceptionCode::kDataError, @@ -1007,6 +1704,13 @@ MLOperand* MLGraphBuilder::reshape(const MLOperand* input, } has_minus1 = true; minus1_dim_index = i; +#else + if (d <= 0) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The value of new shape may not have 0 in it."); + return nullptr; +#endif } else { checked_newshape_number_of_elements *= d; output_shape[i] = d; @@ -1018,7 +1722,7 @@ MLOperand* MLGraphBuilder::reshape(const MLOperand* input, &newshape_number_of_elements)) { exception_state.ThrowDOMException( DOMExceptionCode::kDataError, - "The number of elements implied by new shape is too large."); + "The number of elements in the new shape is too large."); return nullptr; } DCHECK_NE(newshape_number_of_elements, size_t(0)); @@ -1030,8 +1734,8 @@ MLOperand* MLGraphBuilder::reshape(const MLOperand* input, DOMExceptionCode::kDataError, String::Format( "The number of elements (%zu) in the input tensor can't be " - "divided evenly by the number of elements (%zu) implied by new " - "shape.", + "divided evenly by the number of elements (%zu) in the " + "new shape.", input->NumberOfElements(), newshape_number_of_elements)); return nullptr; } @@ -1051,7 +1755,7 @@ MLOperand* MLGraphBuilder::reshape(const MLOperand* input, exception_state.ThrowDOMException( DOMExceptionCode::kDataError, String::Format( - "The number of elements (%zu) implied by new shape doesn't match " + "The number of elements (%zu) in the new shape doesn't match " "the number of elements (%zu) in the input tensor.", newshape_number_of_elements, input->NumberOfElements())); return nullptr; @@ -1077,7 +1781,7 @@ MLOperand* MLGraphBuilder::resample2d(const MLOperand* input, // According to WebNN spec: // https://www.w3.org/TR/webnn/#api-mlgraphbuilder-resample2d, the input // must be a 4-D tensor. - const auto input_shape = input->Dimensions(); + const auto& input_shape = input->Dimensions(); if (input_shape.size() != 4) { exception_state.ThrowDOMException(DOMExceptionCode::kDataError, "The input must be a 4-D tensor."); @@ -1085,6 +1789,11 @@ MLOperand* MLGraphBuilder::resample2d(const MLOperand* input, } const auto axes = options->getAxesOr({2, 3}); + const wtf_size_t input_rank = input_shape.size(); + if (!ValidateAxes(axes, input_rank, "resample2d", exception_state)) { + return nullptr; + } + if (axes.size() != 2) { exception_state.ThrowDOMException(DOMExceptionCode::kDataError, "The length of axes should be 2."); @@ -1101,6 +1810,8 @@ MLOperand* MLGraphBuilder::resample2d(const MLOperand* input, } Vector output_shape(input_shape); + Vector scales(axes.size(), 1.0f); + if (options->hasSizes()) { if (options->hasScales()) { auto* execution_context = GetContext()->GetML()->GetExecutionContext(); @@ -1125,10 +1836,17 @@ MLOperand* MLGraphBuilder::resample2d(const MLOperand* input, "All sizes should be greater than 0."); return nullptr; } + output_shape[axes[0]] = options->sizes()[0]; output_shape[axes[1]] = options->sizes()[1]; + + // Compute the scales from the new shape. 
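+      // For example, resizing a height of 10 to 15 and a width of 8 to 16
+      // stores scales of 1.5 and 2.0.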
+ scales[0] = + static_cast(output_shape[axes[0]]) / input_shape[axes[0]]; + scales[1] = + static_cast(output_shape[axes[1]]) / input_shape[axes[1]]; } else { - const auto scales = options->getScalesOr({1.0f, 1.0f}); + scales = options->getScalesOr({1.0f, 1.0f}); if (scales.size() != 2) { exception_state.ThrowDOMException(DOMExceptionCode::kDataError, "The length of scales should be 2."); @@ -1155,8 +1873,17 @@ MLOperand* MLGraphBuilder::resample2d(const MLOperand* input, } } + // Pass the normalized options onward, simplifying the lower level's job. + // Then the axes parameter and scales consistently exist. + MLResample2dOptions* normalized_options = MLResample2dOptions::Create(); + normalized_options->setMode(options->mode()); + normalized_options->setAxes(axes); + normalized_options->setScales(scales); + // Do not set sizes, since the output shape is already set, + // and since it would potentially conflict with scales. + auto* resample2d = MakeGarbageCollected( - this, MLOperator::OperatorKind::kResample2d, options); + this, MLOperator::OperatorKind::kResample2d, normalized_options); String error_message; // According to WebNN spec // https://www.w3.org/TR/webnn/#api-mlgraphbuilder-resample2d, the output @@ -1189,25 +1916,18 @@ MLOperand* MLGraphBuilder::softmax(const MLOperand* input, "The input type must be one of the floating point types."); return nullptr; } - auto* softmax = MakeGarbageCollected( - this, MLOperator::OperatorKind::kSoftmax); - // The output tensor has the same shape as the input tensor. - String error_message; - auto* output = MLOperand::ValidateAndCreateOutput( - this, input->Type(), input->Dimensions(), softmax, error_message); - if (!output) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - error_message); - return nullptr; - } - softmax->Connect({input}, {output}); - return output; + + return BuildUnaryOperator(this, MLOperator::OperatorKind::kSoftmax, input, + exception_state); +} + +MLOperator* MLGraphBuilder::softmax(ExceptionState& exception_state) { + return MakeGarbageCollected( + this, MLOperator::OperatorKind::kSoftmax); } MLOperand* MLGraphBuilder::sigmoid(const MLOperand* input, ExceptionState& exception_state) { - auto* sigmoid = MakeGarbageCollected( - this, MLOperator::OperatorKind::kSigmoid); // According to WebNN spec // https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-sigmoid, the // output tensor of sigmoid has the same type and dimensions as its input. @@ -1218,24 +1938,1566 @@ MLOperand* MLGraphBuilder::sigmoid(const MLOperand* input, "The input type must be one of the floating point types."); return nullptr; } + return BuildUnaryOperator(this, MLOperator::OperatorKind::kSigmoid, input, + exception_state); +} + +MLOperator* MLGraphBuilder::sigmoid(ExceptionState& exception_state) { + // Create the sigmoid operator that would be used as an activation function. 
+ return MakeGarbageCollected(this, + MLOperator::OperatorKind::kSigmoid); +} + +MLOperand* MLGraphBuilder::elu(const MLOperand* input, + const MLEluOptions* options, + ExceptionState& exception_state) { + MLFloatParameterOptionsInternal* internal_options = + ConvertEluOptions(options, exception_state); + + return BuildUnaryOperator(this, MLOperator::OperatorKind::kElu, input, + input->Dimensions(), input->Type(), internal_options, + exception_state); +} + +MLOperator* MLGraphBuilder::elu(const MLEluOptions* options, ExceptionState& exception_state) { + MLFloatParameterOptionsInternal* internal_options = + ConvertEluOptions(options, exception_state); + + return MakeGarbageCollected( + this, MLOperator::OperatorKind::kElu, internal_options); +} + +MLOperand* MLGraphBuilder::hardSigmoid(const MLOperand* input, + const MLHardSigmoidOptions* options, + ExceptionState& exception_state) { + MLFloatParameterOptionsInternal* internal_options = + ConvertHardSigmoidOptions(options, exception_state); + + return BuildUnaryOperator(this, MLOperator::OperatorKind::kHardSigmoid, input, + input->Dimensions(), input->Type(), internal_options, + exception_state); +} + +MLOperator* MLGraphBuilder::hardSigmoid(const MLHardSigmoidOptions* options, + ExceptionState& exception_state) { + MLFloatParameterOptionsInternal* internal_options = + ConvertHardSigmoidOptions(options, exception_state); + + return MakeGarbageCollected( + this, MLOperator::OperatorKind::kHardSigmoid, internal_options); +} + +MLOperand* MLGraphBuilder::leakyRelu(const MLOperand* input, + const MLLeakyReluOptions* options, + ExceptionState& exception_state) { + MLFloatParameterOptionsInternal* internal_options = + ConvertLeakyReluOptions(options, exception_state); + + return BuildUnaryOperator(this, MLOperator::OperatorKind::kLeakyRelu, input, + input->Dimensions(), input->Type(), internal_options, + exception_state); +} + +MLOperator* MLGraphBuilder::leakyRelu(const MLLeakyReluOptions* options, + ExceptionState& exception_state) { + MLFloatParameterOptionsInternal* internal_options = + ConvertLeakyReluOptions(options, exception_state); + + return MakeGarbageCollected( + this, MLOperator::OperatorKind::kLeakyRelu, internal_options); +} + +MLOperand* MLGraphBuilder::linear(const MLOperand* input, + const MLLinearOptions* options, + ExceptionState& exception_state) { + MLFloatParameterOptionsInternal* internal_options = + ConvertLinearOptions(options, exception_state); + + return BuildUnaryOperator(this, MLOperator::OperatorKind::kLinear, input, + input->Dimensions(), input->Type(), internal_options, + exception_state); +} + +MLOperator* MLGraphBuilder::linear(const MLLinearOptions* options, + ExceptionState& exception_state) { + MLFloatParameterOptionsInternal* internal_options = + ConvertLinearOptions(options, exception_state); + + return MakeGarbageCollected( + this, MLOperator::OperatorKind::kLinear, internal_options); +} + +MLOperand* MLGraphBuilder::prelu(const MLOperand* input, + const MLOperand* slope, + ExceptionState& exception_state) { + absl::optional> output_dimensions = + BroadcastShapes(slope->Dimensions(), input->Dimensions(), false); + if (!output_dimensions) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The prelu slope tensor is not broadcastable to the input tensor."); + return nullptr; + } + + return BuildElementwiseBinary(this, MLOperator::OperatorKind::kPrelu, input, + slope, input->Type(), exception_state); +} + +MLOperand* MLGraphBuilder::softplus(const MLOperand* input, + const 
MLSoftplusOptions* options, + ExceptionState& exception_state) { + MLFloatParameterOptionsInternal* internal_options = + ConvertSoftplusOptions(options, exception_state); + + return BuildUnaryOperator(this, MLOperator::OperatorKind::kSoftplus, input, + input->Dimensions(), input->Type(), + internal_options, exception_state); +} + +MLOperator* MLGraphBuilder::softplus(const MLSoftplusOptions* options, + ExceptionState& exception_state) { + MLFloatParameterOptionsInternal* internal_options = + ConvertSoftplusOptions(options, exception_state); + + return MakeGarbageCollected( + this, MLOperator::OperatorKind::kSoftplus, internal_options); +} + +MLOperand* MLGraphBuilder::softsign(const MLOperand* input, ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kSoftsign, input, + exception_state); +} + +MLOperator* MLGraphBuilder::softsign(ExceptionState& exception_state) { + return MakeGarbageCollected( + this, MLOperator::OperatorKind::kSoftsign); +} + +MLOperand* MLGraphBuilder::tanh(const MLOperand* input, ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kTanh, input, + exception_state); +} + +MLOperator* MLGraphBuilder::tanh(ExceptionState& exception_state) { + return MakeGarbageCollected( + this, MLOperator::OperatorKind::kTanh); +} + +MLOperand* MLGraphBuilder::elementwiseIf(const MLOperand* condition, + const MLOperand* true_value, + const MLOperand* false_value, + ExceptionState& exception_state) { + if (!IsBooleanType(condition->Type())) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The input condition type must be a Boolean data type."); + return nullptr; + } + + if (true_value->Type() != false_value->Type()) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + "The input types don't match."); + return nullptr; + } + absl::optional> value_dimensions = + BroadcastShapes(true_value->Dimensions(), false_value->Dimensions()); + if (!value_dimensions) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The input shapes are not broadcastable."); + return nullptr; + } + absl::optional> output_dimensions = + BroadcastShapes(condition->Dimensions(), *value_dimensions); + if (!output_dimensions) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The input shapes are not broadcastable."); + return nullptr; + } + + auto* ml_operator = MakeGarbageCollected( + this, MLOperator::OperatorKind::kElementWiseIf); + String error_message; + auto* output = MLOperand::ValidateAndCreateOutput(this, true_value->Type(), + output_dimensions.value(), + ml_operator, error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return nullptr; + } + ml_operator->Connect({condition, true_value, false_value}, {output}); + return output; +} + +MLOperand* MLGraphBuilder::argMax(const MLOperand* input, + const MLArgMinMaxOptions* options, + ExceptionState& exception_state) { + return BuildArgMinMax(this, MLOperator::OperatorKind::kArgMax, "argMax", input, options, + exception_state); +} + +MLOperand* MLGraphBuilder::argMin(const MLOperand* input, + const MLArgMinMaxOptions* options, + ExceptionState& exception_state) { + return BuildArgMinMax(this, MLOperator::OperatorKind::kArgMin, "argMin", input, options, + exception_state); +} + +MLOperand* MLGraphBuilder::cast(const MLOperand* input, + V8MLOperandType data_type, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, 
MLOperator::OperatorKind::kCast, input, + input->Dimensions(), data_type.AsEnum(), /*options*/ nullptr, + exception_state); +} + +MLOperand* MLGraphBuilder::concat(const HeapVector>& inputs, + uint32_t axis, + ExceptionState& exception_state) { + wtf_size_t input_count = inputs.size(); + if (input_count <= 0) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "Concat requires at least one input."); + return nullptr; + } + + // Set the output dimensions initially to the first input, + // concatenating each successive one in the loop below. + auto& first_input = inputs.front(); + Vector output_dimensions = first_input->Dimensions(); + + if (!ValidateAxis(axis, + output_dimensions.size(), + "concat", + exception_state)) { + return nullptr; + } + + base::CheckedNumeric checked_output_axis_length = 0; + + // Validate input dimensions are compatible with each other, and compute the + // total length of the active axis dimension. + for (wtf_size_t i = 0; i < input_count; ++i) { + auto& input = inputs[i]; + auto& input_dimensions = input->Dimensions(); + + if (input_dimensions.size() != output_dimensions.size()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("All input tensors must have the same size. Input %u " + "has a size of %u but input 0 has a size of %u.", + i, input_dimensions.size(), output_dimensions.size())); + return nullptr; + } + + checked_output_axis_length += input_dimensions[axis]; + } + + // Set the length of the active axis. + uint32_t output_axis_length; + if (!checked_output_axis_length.AssignIfValid(&output_axis_length)) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The number of elements in the new shape is too large."); + return nullptr; + } + output_dimensions[axis] = output_axis_length; + + MLConcatOptionsInternal* options = MLConcatOptionsInternal::Create(); + options->setAxis(axis); + + String error_message; + auto* ml_operator = MakeGarbageCollected( + this, MLOperator::OperatorKind::kConcat, options); + + auto* output = MLOperand::ValidateAndCreateOutput( + this, first_input->Type(), std::move(output_dimensions), ml_operator, + /*out*/ error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return nullptr; + } + + HeapVector> copied_inputs(inputs); + ml_operator->Connect(std::move(copied_inputs), {output}); + return output; +} + +MLOperand* MLGraphBuilder::expand(const MLOperand* input, + const Vector& new_shape, + ExceptionState& exception_state) { + const auto& input_dimensions = input->Dimensions(); + const auto new_shape_dimension_count = new_shape.size(); + base::CheckedNumeric checked_new_shape_number_of_elements = 1; + + if (new_shape_dimension_count != input_dimensions.size()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("The new shape's dimension count (%u) must match the " + "input tensor's (%u).", + new_shape_dimension_count, + input->Dimensions().size())); + return nullptr; + } + + for (wtf_size_t i = 0; i < new_shape_dimension_count; ++i) { + auto old_size = input_dimensions[i]; + auto new_size = new_shape[i]; + if (new_size < old_size || (new_size > old_size && old_size != 1)) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "The each value in the new shape (%u) must either equal the old " + "shape (%u) " + "or broadcast a single size dimension input to a greater size.", + new_size, old_size)); + return nullptr; + } + 
checked_new_shape_number_of_elements *= new_size; + } + + // Check for overflow. + size_t new_shape_number_of_elements; + if (!checked_new_shape_number_of_elements.AssignIfValid( + &new_shape_number_of_elements)) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The number of elements in the new shape is too large."); + return nullptr; + } + + auto* ml_operator = MakeGarbageCollected( + this, MLOperator::OperatorKind::kExpand); String error_message; auto* output = MLOperand::ValidateAndCreateOutput( - this, input->Type(), input->Dimensions(), sigmoid, error_message); + this, input->Type(), std::move(new_shape), ml_operator, error_message); if (!output) { exception_state.ThrowDOMException(DOMExceptionCode::kDataError, error_message); return nullptr; } - sigmoid->Connect({input}, {output}); + ml_operator->Connect({input}, {output}); return output; } -MLOperator* MLGraphBuilder::sigmoid(ExceptionState& exception_state) { - // Create the sigmoid operator that would be used as an activation function. - return MakeGarbageCollected(this, - MLOperator::OperatorKind::kSigmoid); +MLOperand* MLGraphBuilder::abs(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kAbs, input, + exception_state); +} + +MLOperand* MLGraphBuilder::neg(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kNeg, input, + exception_state); +} + +MLOperand* MLGraphBuilder::cos(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kCos, input, + exception_state); +} + +MLOperand* MLGraphBuilder::erf(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kErf, input, + exception_state); +} + +MLOperand* MLGraphBuilder::exp(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kExp, input, + exception_state); +} + +MLOperand* MLGraphBuilder::log(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kLog, input, + exception_state); +} + +MLOperand* MLGraphBuilder::floor(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kFloor, input, + exception_state); +} + +MLOperand* MLGraphBuilder::ceil(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kCeil, input, + exception_state); +} + +MLOperand* MLGraphBuilder::reciprocal(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kReciprocal, input, + exception_state); +} + +MLOperand* MLGraphBuilder::logicalNot(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kLogicalNot, input, + exception_state); +} + +MLOperand* MLGraphBuilder::flattenTo2d(const MLOperand* input, + uint32_t axis, + ExceptionState& exception_state) { + const auto& input_dimensions = input->Dimensions(); + const wtf_size_t input_rank = input_dimensions.size(); + + if (axis > input_rank) { + // Cannot call ValidateAxis() here because 'axis' ranges [0, input rank]. 
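+    // For example, a [2, 3, 4, 5] input accepts any axis from 0 through 4;
+    // axis = 2 flattens to [2 * 3, 4 * 5] = [6, 20].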
+ exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("The flattenTo2d axis (%u) must range from 0 up to the " + "dimension count (%u).", + axis, input_rank)); + return nullptr; + } + + // Flatten the leading and trailing portion of the dimensions + // (where axis is the split point) into a 2D tensor. + Vector output_dimensions(2, 0u); + base::span input_dimensions_span = input_dimensions; + base::span leading_dimensions = input_dimensions_span.first(axis); + base::span trailing_dimensions = input_dimensions_span.subspan(axis); + output_dimensions[0] = ComputeElementCount(leading_dimensions); + output_dimensions[1] = ComputeElementCount(trailing_dimensions); + + // Resolve flattenTo2d into a reshape operator. + return BuildUnaryOperator(this, MLOperator::OperatorKind::kReshape, input, + output_dimensions, input->Type(), /*options*/nullptr, + exception_state); +} + +MLOperand* MLGraphBuilder::gather(const MLOperand* input, + const MLOperand* indices, + const MLGatherOptions* options, + ExceptionState& exception_state) { + wtf_size_t input_rank = input->Dimensions().size(); + wtf_size_t indices_rank = indices->Dimensions().size(); // >= 0 + wtf_size_t output_rank = input_rank + indices_rank - 1; + uint32_t axis = options->axis(); + + if (!IsIndexType(indices->Type())) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "Gather's indices element type must be int32/uint32."); + return nullptr; + } + + if (input_rank <= 0) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "Gather's input rank (%u) requires at least 1 dimension.", + input_rank)); + return nullptr; + } + if (axis >= input_rank) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "Gather's axis (%u) must be within the input tensor rank (%u).", + axis, input_rank)); + return nullptr; + } + if (input_rank + indices_rank < 1) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("Gather's input rank (%u) and indices rank (%u) " + "combined must be at least 1.", + input_rank, indices_rank)); + return nullptr; + } + + const Vector& inputDimensions = input->Dimensions(); + const Vector& indicesDimensions = indices->Dimensions(); + Vector output_dimensions(output_rank, 1u); + + // The input dimensions following the gather axis determine the final output + // dimensions. + int32_t output_dimension = output_rank - 1; + int32_t input_dimension = input_rank - 1; + for (; input_dimension > int32_t(axis); + --output_dimension, --input_dimension) { + output_dimensions[output_dimension] = inputDimensions[input_dimension]; + } + + // The shape of the index tensor is reflected in the middle dimensions of the + // output tensor. + int32_t index_dimension = indices_rank - 1; + for (; index_dimension >= 0; --output_dimension, --index_dimension) { + output_dimensions[output_dimension] = indicesDimensions[index_dimension]; + } + + // The gather dimension is skipped for the purposes of sizing because the + // index values choose slices across it. Preceding input dimensions + // determine the shape of the output's leading dimensions. 
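+  // For example, gathering from a [3, 4, 5] input with [2, 2] indices along
+  // axis = 1 produces a [3, 2, 2, 5] output.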
+  input_dimension = axis - 1;
+  for (; output_dimension >= 0 && input_dimension >= 0;
+       --output_dimension, --input_dimension) {
+    output_dimensions[output_dimension] = inputDimensions[input_dimension];
+  }
+
+  String error_message;
+  auto* ml_operator = MakeGarbageCollected<MLOperator>(
+      this, MLOperator::OperatorKind::kGather, options);
+
+  auto* output = MLOperand::ValidateAndCreateOutput(
+      this, input->Type(), std::move(output_dimensions), ml_operator,
+      /*out*/ error_message);
+  if (!output) {
+    exception_state.ThrowDOMException(DOMExceptionCode::kDataError,
+                                      error_message);
+    return nullptr;
+  }
+
+  ml_operator->Connect({input, indices}, {output});
+  return output;
+}
+
+MLOperand* MLGraphBuilder::equal(const MLOperand* a,
+                                 const MLOperand* b,
+                                 ExceptionState& exception_state) {
+  return BuildElementwiseBinary(this, MLOperator::OperatorKind::kEqual, a, b,
+                                V8MLOperandType::Enum::kUint8,
+                                exception_state);
+}
+
+MLOperand* MLGraphBuilder::greater(const MLOperand* a,
+                                   const MLOperand* b,
+                                   ExceptionState& exception_state) {
+  return BuildElementwiseBinary(this, MLOperator::OperatorKind::kGreater, a, b,
+                                V8MLOperandType::Enum::kUint8,
+                                exception_state);
+}
+
+MLOperand* MLGraphBuilder::lesser(const MLOperand* a,
+                                  const MLOperand* b,
+                                  ExceptionState& exception_state) {
+  return BuildElementwiseBinary(this, MLOperator::OperatorKind::kLesser, a, b,
+                                V8MLOperandType::Enum::kUint8,
+                                exception_state);
+}
+
+MLOperand* MLGraphBuilder::identity(const MLOperand* input,
+                                    ExceptionState& exception_state) {
+  return BuildUnaryOperator(this, MLOperator::OperatorKind::kIdentity, input,
+                            exception_state);
+}
+
+MLOperand* MLGraphBuilder::instanceNormalization(
+    const MLOperand* input,
+    const MLInstanceNormalizationOptions* options,
+    ExceptionState& exception_state) {
+  const auto& input_dimensions = input->Dimensions();
+
+  if (input_dimensions.size() != 4) {
+    exception_state.ThrowDOMException(
+        DOMExceptionCode::kDataError,
+        String::Format(
+            "instanceNormalization's input must be a 4-D tensor, but its "
+            "rank is %u.",
+            input_dimensions.size()));
+    return nullptr;
+  }
+
+  // Determine which axis of the input holds the channel (feature) dimension.
+  static_assert(uint32_t(V8MLInputOperandLayout::kEnumSize) == 2,
+                "Update switch for the new layout.");
+  uint32_t scale_bias_dimension = 0;
+  switch (options->layout().AsEnum()) {
+    case V8MLInputOperandLayout::Enum::kNchw:
+      scale_bias_dimension = 1;  // Channel after batch.
+      break;
+    case V8MLInputOperandLayout::Enum::kNhwc:
+      scale_bias_dimension = 3;  // Channel last.
+      break;
+    default:
+      NOTREACHED();
+  }
+
+  const uint32_t expected_scale_bias_length =
+      input_dimensions[scale_bias_dimension];
+
+  // Scale and bias must each be a 1-D tensor whose length equals the number
+  // of input channels.
+  auto verify_scale_or_bias = [&](MLOperand& ml_operand,
+                                  const char* tensor_name) {
+    auto& dimensions = ml_operand.Dimensions();
+    if (dimensions.size() != 1 ||
+        dimensions.front() != expected_scale_bias_length) {
+      exception_state.ThrowDOMException(
+          DOMExceptionCode::kDataError,
+          String::Format(
+              "instanceNormalization's %s must be a 1-D tensor whose length "
+              "equals the input channel count (%u).",
+              tensor_name, expected_scale_bias_length));
+      return false;
+    }
+    return true;
+  };
+
+  // Collect the inputs, with optional scale and bias.
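+  // For example, with an "nchw" input of shape [1, 3, 224, 224], scale and
+  // bias (when present) must each be 1-D tensors of length 3.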
+ HeapVector> inputs = {input}; + + if (options->hasScale()) { + if (!verify_scale_or_bias(*options->scale(), "scale")) + { + return nullptr; + } + inputs.push_back(options->scale()); + } + if (options->hasBias()) { + if (!verify_scale_or_bias(*options->bias(), "bias")) + { + return nullptr; + } + inputs.push_back(options->bias()); + } + + // Create the instance normalization operator, and connect IO. + auto* ml_operator = MakeGarbageCollected(this, MLOperator::OperatorKind::kInstanceNormalization, options); + String error_message; + auto* output = MLOperand::ValidateAndCreateOutput(this, input->Type(), input_dimensions, ml_operator, error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, error_message); + return nullptr; + } + ml_operator->Connect(std::move(inputs), {output}); + return output; +} + +MLOperand* MLGraphBuilder::meanVarianceNormalization( + const MLOperand* input, + const MLMeanVarianceNormalizationOptions* options, + ExceptionState& exception_state) { + auto& input_dimensions = input->Dimensions(); + const wtf_size_t input_rank = input_dimensions.size(); + + // Verify the axes are within the input rank and not duplicated. + Vector axes; + if (options->hasAxes()) + { + axes = options->axes(); + if (!ValidateAxes(axes, input_rank, "meanVarianceNormalization", exception_state)) { + return nullptr; + } + } + else // Reduce all dimensions if permutations are missing, consistent with reduction operators. + { + axes.resize(input_rank); + std::iota(axes.begin(), axes.end(), 0u); + } + + // Expect a 1D array for scale and bias, equal to the feature axis. + auto verify_compatible_input_size = [&](MLOperand& ml_operand, + const char* tensor_name) { + auto& dimensions = ml_operand.Dimensions(); + if (dimensions.size() != input_dimensions.size()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("meanVarianceNormalization's %s tensor must " + "match the input tensor rank (%u), not (%u).", + tensor_name, input_dimensions.size(), + dimensions.size())); + return false; + } + + for (wtf_size_t i = 0, rank = input_dimensions.size(); i < rank; ++i) { + if (dimensions[i] != input_dimensions[i] && dimensions[i] != 1) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "meanVarianceNormalization's %s tensor dimension (%u) " + "must be either 1 or match the corresponding input " + "dimension (%u).", + tensor_name, dimensions[i], input_dimensions[i])); + return false; + } + } + return true; + }; + + // Collect the inputs, with optional scale and bias. + HeapVector> inputs = {input}; + + // Pass the normalized options onward, simplifying the lower level's job. 
+ MLMeanVarianceNormalizationOptions* normalized_options = + MLMeanVarianceNormalizationOptions::Create(); + + if (options->hasMean()) { + if (!verify_compatible_input_size(*options->mean(), "mean")) { + return nullptr; + } + inputs.push_back(options->mean()); + normalized_options->setMean(options->mean()); + } + if (options->hasVariance()) { + if (!verify_compatible_input_size(*options->variance(), "variance")) { + return nullptr; + } + inputs.push_back(options->variance()); + normalized_options->setVariance(options->variance()); + } + if (options->hasScale()) { + if (!verify_compatible_input_size(*options->scale(), "scale")) { + return nullptr; + } + inputs.push_back(options->scale()); + normalized_options->setScale(options->scale()); + } + if (options->hasBias()) { + if (!verify_compatible_input_size(*options->bias(), "bias")) { + return nullptr; + } + inputs.push_back(options->bias()); + normalized_options->setBias(options->bias()); + } + + // Verify all tensors are passed if mean and variance were precomputed. + // Otherwise scale and bias are independently optional. + if (options->hasMean() || options->hasVariance()) { + if (!options->hasScale() || !options->hasBias() || !options->hasMean() || + !options->hasVariance()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "meanVarianceNormalization requires that if either mean or variance " + "are passed that all tensors be passed."); + return nullptr; + } + } + + normalized_options->setEpsilon(options->epsilon()); + normalized_options->setAxes(axes); + + // Create the mean variance normalization operator, and connect IO. + auto* ml_operator = MakeGarbageCollected( + this, MLOperator::OperatorKind::kMeanVarianceNormalization, + normalized_options); + String error_message; + auto* output = MLOperand::ValidateAndCreateOutput( + this, input->Type(), input_dimensions, ml_operator, error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return nullptr; + } + ml_operator->Connect(std::move(inputs), {output}); + return output; +} + +MLOperand* BuildMatMul(MLGraphBuilder* builder, + MLOperator::OperatorKind operator_kind, + const MLOperand* a, + const MLOperand* b, + const MLOperand* a_zero_point, // optional + const MLOperand* b_zero_point, // optional + ExceptionState& exception_state) { + // Massage the two input tensor's rank accordingly: + // - If a is 1-D, it is converted to a 2-D tensor by prepending a 1 to its dimensions. + // - If b is 1-D, it is converted to a 2-D tensor by by appending a 1 to its dimensions. + // - If either a or b have rank N > 2, the higher dimensions are broadcast to each other, + // with the output rank being the greater of the two. + // Then the inputs are treated as a stack of matrices. If both were 1D, it's treated as + // as dot product (which happens naturally as a by-product of expansion). + + if (a->Type() != b->Type()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The types of first two inputs don't match."); + return nullptr; + } + + Vector a_dimensions = a->Dimensions(); + Vector b_dimensions = b->Dimensions(); + // MLGraphBuilder::input should have coerced to 1 already. + DCHECK_GT(a_dimensions.size(), uint32_t(0)); + DCHECK_GT(b_dimensions.size(), uint32_t(0)); + + // Massage the sizes first, before additional broadcastability checks. + // After this point, both arrays are at least the same size, simplifying + // later checks in the code. 
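+  // For example (assuming ExpandDimensions pads the smaller rank with
+  // leading 1's), a of shape [2, 3, 4] times b of shape [4, 5] expands b to
+  // [1, 4, 5], broadcasts the batch dimensions to [2], and produces a
+  // [2, 3, 5] output.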
+  wtf_size_t output_rank = std::max(a_dimensions.size(), b_dimensions.size());
+  if (a_dimensions.size() == 1) {
+    a_dimensions.push_front(1u);
+  }
+  if (b_dimensions.size() == 1) {
+    b_dimensions.push_back(1u);
+  }
+  a_dimensions = ExpandDimensions(a_dimensions, output_rank);
+  b_dimensions = ExpandDimensions(b_dimensions, output_rank);
+
+  // The number of columns in the first matrix must be equal to the number of
+  // rows in the second matrix.
+  const uint32_t a_cols = a_dimensions[output_rank - 1];
+  const uint32_t a_rows = a_dimensions[output_rank - 2];
+  const uint32_t b_cols = b_dimensions[output_rank - 1];
+  const uint32_t b_rows = b_dimensions[output_rank - 2];
+  if (a_cols != b_rows) {
+    exception_state.ThrowDOMException(
+        DOMExceptionCode::kDataError,
+        String::Format(
+            "The number of columns (%u) in the first matrix isn't equal to "
+            "the number of rows (%u) in the second matrix.",
+            a_cols, b_rows));
+    return nullptr;
+  }
+
+  // Figure out the output shape by broadcasting all the dimensions except the
+  // last two, which become [rows of a, columns of b].
+  absl::optional<Vector<uint32_t>> optional_output_dimensions =
+      BroadcastShapes(a_dimensions, b_dimensions, true, 2);
+  if (!optional_output_dimensions) {
+    exception_state.ThrowDOMException(
+        DOMExceptionCode::kDataError,
+        "The matmul input shapes are not broadcastable.");
+    return nullptr;
+  }
+  auto& output_dimensions = *optional_output_dimensions;
+  DCHECK_EQ(output_rank, output_dimensions.size());
+  output_dimensions[output_rank - 2] = a_rows;
+  output_dimensions[output_rank - 1] = b_cols;
+
+  String error_message;
+
+  // Create empty options for gemm, since MatMul uses the defaults.
+  auto* options = MLGemmOptions::Create();
+
+  auto* ml_operator = MakeGarbageCollected<MLOperator>(
+      builder, operator_kind, options);
+  HeapVector<Member<const MLOperand>> inputs = {a, b};
+
+  // Append zero points if present.
+  if (operator_kind == MLOperator::OperatorKind::kMatmulInteger) {
+    inputs.resize(4);
+    inputs[2] = a_zero_point;
+    inputs[3] = b_zero_point;  // TODO::: Restrict size even more, to scalars?
+
+    if (a_zero_point && !ValidateUnidirectionalBroadcastability(
+                            a_zero_point->Dimensions(), a_dimensions, "matmul",
+                            "a_zero_point", "a", exception_state)) {
+      return nullptr;
+    }
+    if (b_zero_point && !ValidateUnidirectionalBroadcastability(
+                            b_zero_point->Dimensions(), b_dimensions, "matmul",
+                            "b_zero_point", "b", exception_state)) {
+      return nullptr;
+    }
+  }
+
+  V8MLOperandType::Enum output_data_type =
+      (operator_kind == MLOperator::OperatorKind::kMatmulInteger)
+          ?
V8MLOperandType::Enum::kInt32 + : a->Type(); + + auto* output = MLOperand::ValidateAndCreateOutput( + builder, output_data_type, std::move(output_dimensions), ml_operator, + error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return nullptr; + } + ml_operator->Connect(std::move(inputs), {output}); + return output; +} + +MLOperand* MLGraphBuilder::matmul(const MLOperand* a, + const MLOperand* b, + ExceptionState& exception_state) { + return BuildMatMul(this, MLOperator::OperatorKind::kMatmul, a, b, + nullptr, // No a zero point + nullptr, // No b zero point + exception_state); +} + +MLOperand* MLGraphBuilder::pad(const MLOperand* input, + const Vector& beginning_padding, + const Vector& ending_padding, + const MLPadOptions* options, + ExceptionState& exception_state) { + const auto& input_dimensions = input->Dimensions(); + const wtf_size_t input_rank = input_dimensions.size(); + + Vector output_dimensions = input_dimensions; + if (beginning_padding.size() != input_rank || + ending_padding.size() != input_rank) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("Pad's beginning padding length (%u) and ending padding " + "length (%u) must match the input rank (%u).", + beginning_padding.size(), ending_padding.size(), + input_rank)); + return nullptr; + } + + // Apply padding to output dimensions. + for (wtf_size_t i = 0; i < input_rank; ++i) { + base::CheckedNumeric checked_output_dimension = + output_dimensions[i]; + checked_output_dimension += beginning_padding[i]; + checked_output_dimension += ending_padding[i]; + + uint32_t output_dimension; + if (!checked_output_dimension.AssignIfValid(&output_dimension)) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "Pad's padding overflowed the maximum size (%u + %u + %u).", + beginning_padding[i], ending_padding[i], output_dimensions[i])); + return nullptr; + } + + output_dimensions[i] = output_dimension; + } + + MLPadOptionsInternal* internal_options = MLPadOptionsInternal::Create(); + internal_options->setMode(options->mode()); + internal_options->setValue(options->value()); + internal_options->setBeginningPadding(beginning_padding); + internal_options->setEndingPadding(ending_padding); + + return BuildUnaryOperator(this, MLOperator::OperatorKind::kPad, input, + output_dimensions, input->Type(), + /*options*/ internal_options, exception_state); +} + +MLOperand* MLGraphBuilder::pow(const MLOperand* a, + const MLOperand* b, + ExceptionState& exception_state) { + return BuildElementwiseBinary(this, MLOperator::OperatorKind::kPow, a, b, a->Type(), + exception_state); +} + +MLOperand* MLGraphBuilder::fillSequence(V8MLOperandType output_data_type, + const Vector& output_shape, + const MLFillSequenceOptions* options, + ExceptionState& exception_state) { + String error_message; + auto* ml_operator = MakeGarbageCollected( + this, MLOperator::OperatorKind::kFillSequence, options); + + auto* output = MLOperand::ValidateAndCreateOutput(this, output_data_type.AsEnum(), + output_shape, ml_operator, + /*out*/ error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return nullptr; + } + + ml_operator->Connect({}, {output}); + return output; +} + +MLOperand* MLGraphBuilder::reduceL1(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + return BuildReductionOperator(this, MLOperator::OperatorKind::kReduceL1, + 
"reduceL1", input, options, exception_state); +} + +MLOperand* MLGraphBuilder::reduceL2(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + return BuildReductionOperator(this, MLOperator::OperatorKind::kReduceL2, + "reduceL2", input, options, exception_state); +} + +MLOperand* MLGraphBuilder::reduceLogSum(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + return BuildReductionOperator(this, MLOperator::OperatorKind::kReduceLogSum, + "reduceLogSum", input, options, + exception_state); +} + +MLOperand* MLGraphBuilder::reduceLogSumExp(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + return BuildReductionOperator( + this, MLOperator::OperatorKind::kReduceLogSumExp, "reduceLogSumExp", + input, options, exception_state); +} + +MLOperand* MLGraphBuilder::reduceMax(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + return BuildReductionOperator(this, MLOperator::OperatorKind::kReduceMax, + "reduceMax", input, options, exception_state); +} + +MLOperand* MLGraphBuilder::reduceMean(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + return BuildReductionOperator(this, MLOperator::OperatorKind::kReduceMean, + "reduceMean", input, options, exception_state); +} + +MLOperand* MLGraphBuilder::reduceMin(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + return BuildReductionOperator(this, MLOperator::OperatorKind::kReduceMin, + "reduceMin", input, options, exception_state); +} + +MLOperand* MLGraphBuilder::reduceProduct(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + return BuildReductionOperator(this, MLOperator::OperatorKind::kReduceProduct, + "reduceProduct", input, options, + exception_state); +} + +MLOperand* MLGraphBuilder::reduceSum(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + return BuildReductionOperator(this, MLOperator::OperatorKind::kReduceSum, + "reduceSum", input, options, exception_state); +} + +MLOperand* MLGraphBuilder::reduceSumSquare(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state) { + return BuildReductionOperator( + this, MLOperator::OperatorKind::kReduceSumSquare, "reduceSumSquare", + input, options, exception_state); +} + +Vector MLGraphBuilder::shape( + const MLOperand* input, + ExceptionState& exception_state) { + return input->Dimensions(); +} + +MLOperand* MLGraphBuilder::sin(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kSin, input, + exception_state); +} + +MLOperand* MLGraphBuilder::slice(const MLOperand* input, + const Vector& starts, + const Vector& sizes, + ExceptionState& exception_state) { + const auto& input_dimensions = input->Dimensions(); + const wtf_size_t input_rank = input_dimensions.size(); + + if (starts.size() != input_rank || sizes.size() != input_rank) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("Slice's starts length (%u) and sizes length (%u) must " + "match the input rank (%u).", + starts.size(), sizes.size(), input_rank)); + return nullptr; + } + + // Ensure starts and sizes are within valid dimensions. 
+ for (wtf_size_t i = 0; i < input_rank; ++i) + { + uint32_t dimension_length = input_dimensions[i]; + if (sizes[i] > dimension_length || + sizes[i] == 0 || + starts[i] > dimension_length - sizes[i]) + { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("Slice's starts (%u) and sizes (%u) must fit within the dimension length (%u) and be non-empty.", + starts[i], sizes[i], dimension_length)); + return nullptr; + } + } + + MLSliceOptionsInternal* options = MLSliceOptionsInternal::Create(); + options->setStarts(starts); + options->setSizes(sizes); + + return BuildUnaryOperator(this, MLOperator::OperatorKind::kSlice, input, + sizes, input->Type(), /*options*/ options, + exception_state); +} + +HeapVector> MLGraphBuilder::split( + const MLOperand* input, + blink::V8UnionUnsignedLongOrUnsignedLongSequence* splits, + const MLSplitOptions* options, + ExceptionState& exception_state) { + const auto& input_dimensions = input->Dimensions(); + const wtf_size_t input_rank = input_dimensions.size(); + const uint32_t axis = options->hasAxis() ? options->axis() : 0; + + if (!ValidateAxis(axis, input_rank, "split", exception_state)) { + return {}; + } + + const auto is_split_single_scalar = splits->IsUnsignedLong(); + const auto input_axis_length = input_dimensions[axis]; + Vector resolved_splits; + + if (is_split_single_scalar) { + // Convert from a single split count to a splits array. + const wtf_size_t split_count = splits->GetAsUnsignedLong(); + if (split_count <= 0 || split_count > input_axis_length) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("The split count (%u) must be greater than 0 and no " + "greater than the axis length (%u).", + split_count, input_axis_length)); + return {}; + } + + uint32_t axis_length_per_output = input_axis_length / split_count; + if (input_axis_length % split_count != 0) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("The split count (%u) must divide evenly into the " + "axis length (%u).", + split_count, input_axis_length)); + return {}; + } + resolved_splits.resize(split_count); + resolved_splits.Fill(axis_length_per_output); + } else { + base::CheckedNumeric checked_output_axis_length = 0; + + resolved_splits = splits->GetAsUnsignedLongSequence(); + if (resolved_splits.empty()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The split count must be greater than 0."); + return {}; + } + + for (wtf_size_t i = 0, split_count = resolved_splits.size(); + i < split_count; ++i) { + const auto split_size = resolved_splits[i]; + if (split_size <= 0) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "All split sizes must be >= 0 but index %u is zero.", i)); + return {}; + } + + checked_output_axis_length += split_size; + } + + // Set the length of the active axis. + uint32_t output_axis_length; + if (!checked_output_axis_length.AssignIfValid(&output_axis_length)) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + "The number of split values is too large."); + return {}; + } + + if (output_axis_length != input_axis_length) { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format("The sum total of split sizes (%u) must equal the " + "input axis length (%u).", output_axis_length, + input_axis_length)); + return {}; + } + } + + // Normalize the options so backends have a simpler time. 
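+  // For example, splitting an input of shape [6, 2] along axis 0 with
+  // splits = 3 resolves to splits = [2, 2, 2] and yields three [2, 2]
+  // outputs, while splits = [1, 2, 3] yields outputs of shape [1, 2],
+  // [2, 2] and [3, 2].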
+ MLSplitOptionsInternal* normalized_options = MLSplitOptionsInternal::Create(); + normalized_options->setAxis(axis); + normalized_options->setSplits(resolved_splits); + + auto* ml_operator = MakeGarbageCollected( + this, MLOperator::OperatorKind::kSplit, normalized_options); + + // Create multiple output tensors. + String error_message; + HeapVector> outputs; + + for (wtf_size_t i = 0, split_count = resolved_splits.size(); i < split_count; ++i) { + Vector output_dimensions = input_dimensions; + output_dimensions[axis] = resolved_splits[i]; + + MLOperand* output = MLOperand::ValidateAndCreateOutput( + this, input->Type(), std::move(output_dimensions), ml_operator, + /*out*/ error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return {}; + } + outputs.push_back(output); + } + + HeapVector> copied_outputs(outputs); + ml_operator->Connect({input}, std::move(copied_outputs)); + return outputs; +} + +MLOperand* MLGraphBuilder::sqrt(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kSqrt, input, + exception_state); +} + +MLOperand* MLGraphBuilder::squeeze( + const MLOperand* input, + const MLSqueezeOptions* options, + ExceptionState& exception_state) { + const auto& input_dimensions = input->Dimensions(); + const wtf_size_t input_rank = input_dimensions.size(); + + uint32_t axes_mask = 0xFFFFFFFF; // Remove all axes by default, if none passed. + if (options->hasAxes()) { + auto& axes = options->axes(); + if (!ValidateAxes(axes, input_rank, "squeeze", exception_state)) { + return nullptr; + } + if (!ValidateAxesMask(options->axes(), + "squeeze", + exception_state, + /*out*/ axes_mask)) { + return nullptr; + } + } + + // Strip any dimensions of size 1 from the output. + Vector output_dimensions = input_dimensions; + for (wtf_size_t i = 0, output_rank = output_dimensions.size(); i < output_rank; ) + { + if (axes_mask & (1 << i) && output_dimensions[i] == 1u) { + output_dimensions.EraseAt(i); + --output_rank; + } + else + { + ++i; // Preserve this dimension. + } + } + + // Resolve squeeze into a reshape operator. + return BuildUnaryOperator(this, MLOperator::OperatorKind::kReshape, input, + output_dimensions, input->Type(), /*options*/nullptr, + exception_state); +} + +MLOperand* MLGraphBuilder::tan(const MLOperand* input, + ExceptionState& exception_state) { + return BuildUnaryOperator(this, MLOperator::OperatorKind::kTan, input, + exception_state); +} + +MLOperand* MLGraphBuilder::transpose( + const MLOperand* input, + const MLTransposeOptions* options, + ExceptionState& exception_state) { + const auto& input_dimensions = input->Dimensions(); + const wtf_size_t input_rank = input_dimensions.size(); + + Vector permutation; + + // Verify the permutations are within the input rank and not duplicated. + if (options->hasPermutation()) + { + permutation = options->permutation(); + if (permutation.size() != input_rank) + { + exception_state.ThrowDOMException( + DOMExceptionCode::kDataError, + String::Format( + "Transposes's permutation rank (%u) must must match the input rank (%u).", + permutation.size(), input_rank)); + return nullptr; + } + + Vector seen_axes(input_rank); + + if (!ValidateAxes(permutation, input_rank, "transpose", exception_state)) + { + return nullptr; + } + } + else // Reverse all dimensions if permutations are missing. 
+  {
+    for (wtf_size_t i = 0; i < input_rank; ++i)
+    {
+      permutation.push_back(input_rank - i - 1);
+    }
+  }
+
+  // Permute the dimensions.
+  Vector<uint32_t> output_dimensions(input_rank);
+  for (wtf_size_t i = 0; i < input_rank; ++i)
+  {
+    output_dimensions[i] = input_dimensions[permutation[i]];
+  }
+
+  // Pass normalized options onward so the lower levels always see an explicit
+  // permutation parameter, simplifying their job.
+  MLTransposeOptions* normalized_options = MLTransposeOptions::Create();
+  normalized_options->setPermutation(permutation);
+
+  return BuildUnaryOperator(this, MLOperator::OperatorKind::kTranspose, input,
+                            output_dimensions, input->Type(), normalized_options,
+                            exception_state);
+}
+
+MLOperand* MLGraphBuilder::triangularMatrix(
+    const MLOperand* input,
+    const MLTriangularMatrixOptions* options,
+    ExceptionState& exception_state) {
+  // TODO:::COMPLETE
+  return nullptr;
+}
+
+MLOperand* MLGraphBuilder::unsqueeze(
+    const MLOperand* input,
+    const MLSqueezeOptions* options,
+    ExceptionState& exception_state) {
+  const auto& input_dimensions = input->Dimensions();
+  const wtf_size_t input_rank = input_dimensions.size();
+
+  // Verify all axes are within bounds and not duplicated.
+  // Axes are allowed in any order, but insertion wants them in ascending
+  // order. So rearrange them.
+  Vector<uint32_t> ordered_axes = options->getAxesOr({});
+  std::sort(ordered_axes.begin(), ordered_axes.end());
+
+  uint32_t axes_mask = 0x00000000;  // Insert no axes by default, if none are passed.
+  if (!ValidateAxes(ordered_axes, input_rank + ordered_axes.size(), "unsqueeze",
+                    exception_state)) {
+    return nullptr;
+  }
+  if (!ValidateAxesMask(ordered_axes,
+                        "unsqueeze",
+                        exception_state,
+                        /*out*/ axes_mask)) {
+    return nullptr;
+  }
+
+  // Insert dimensions of size 1 into the output.
+  Vector<uint32_t> output_dimensions = input_dimensions;
+  for (uint32_t axis : ordered_axes)
+  {
+    output_dimensions.insert(axis, 1u);
+  }
+
+  // Resolve unsqueeze into a reshape operator.
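+  // For example, an input of shape [3, 4] with axes = [0, 2] is reshaped to
+  // an output of shape [1, 3, 1, 4].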
+ return BuildUnaryOperator(this, MLOperator::OperatorKind::kReshape, input, + output_dimensions, input->Type(), /*options*/nullptr, + exception_state); +} + +HeapVector> MLGraphBuilder::gru(const MLOperand* input, + const MLOperand* weight, + const MLOperand* recurrent_weight, + uint32_t steps, + uint32_t hidden_size, + const MLGruOptions* options, + ExceptionState& exception_state) { + // TODO:::Implement + return {}; +} + +MLOperand* MLGraphBuilder::gruCell(const MLOperand* input, + const MLOperand* weight, + const MLOperand* recurrent_weight, + const MLOperand* hidden_state, + uint32_t hidden_size, + const MLGruCellOptions* options, + ExceptionState& exception_state) { + // TODO:::Implement + return nullptr; +} + +HeapVector> MLGraphBuilder::lstm(const MLOperand* input, + const MLOperand* weight, + const MLOperand* recurrent_weight, + uint32_t steps, + uint32_t hidden_size, + const MLLstmOptions* options, + ExceptionState& exception_state) { + // TODO:::Implement + return {}; +} + +MLOperand* MLGraphBuilder::lstmCell(const MLOperand* input, + const MLOperand* weight, + const MLOperand* recurrent_weight, + const MLOperand* hidden_state, + const MLOperand* cell_state, + uint32_t hidden_size, + const MLLstmCellOptions* options, + ExceptionState& exception_state) { + // TODO:::Implement + return nullptr; +} + +MLOperand* MLGraphBuilder::conv2dInteger(const MLOperand* input, + const MLOperand* input_zero_point, + const MLOperand* filter, + const MLOperand* filter_zero_point, + const MLConv2dIntegerOptions* options, + ExceptionState& exception_state) { + // Unify the two convolutions which only differ in direction + // (forward vs backward) with a common internal representation + // that the backends more easily read. + MLConvOptionsInternal* internal_options = MLConvOptionsInternal::Create(); + internal_options->setAutoPad(options->autoPad()); + internal_options->setGroups(options->groups()); + internal_options->setInputLayout(options->inputLayout()); + + if (options->hasPadding()) + { + internal_options->setPadding(options->padding()); + } + if (options->hasStrides()) + { + internal_options->setStrides(options->strides()); + } + if (options->hasDilations()) + { + internal_options->setDilations(options->dilations()); + } + // Integer convolution lacks bias and activation. 
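+  // Presumably (following the common ConvInteger convention, e.g. in ONNX),
+  // the zero points are subtracted from the quantized input and filter before
+  // the convolution is accumulated, analogous to how matmulInteger produces
+  // an int32 result in BuildMatMul.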
+
+  V8MLConvFilterOperandLayoutInternal::Enum filter_layout;
+  static_assert(V8MLConv2dFilterOperandLayout::kEnumSize == 4);
+  switch (options->filterLayout().AsEnum()) {
+    default:
+    case V8MLConv2dFilterOperandLayout::Enum::kOihw:
+      filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kOihw;
+      break;
+    case V8MLConv2dFilterOperandLayout::Enum::kHwio:
+      filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kHwio;
+      break;
+    case V8MLConv2dFilterOperandLayout::Enum::kOhwi:
+      filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kOhwi;
+      break;
+    case V8MLConv2dFilterOperandLayout::Enum::kIhwo:
+      filter_layout = V8MLConvFilterOperandLayoutInternal::Enum::kIhwo;
+      break;
+  }
+  internal_options->setFilterLayout(filter_layout);
+
+  return BuildConv2d(this, MLOperator::OperatorKind::kConv2dInteger, input, filter,
+                     input_zero_point, filter_zero_point, internal_options,
+                     exception_state);
+}
+
+MLOperand* MLGraphBuilder::dequantizeLinear(const MLOperand* input,
+                                            const MLOperand* scale,
+                                            const MLOperand* zero_point,
+                                            ExceptionState& exception_state) {
+  if (!IsAllowedType(input->Type(), MLOperandTypeMask::kUint8) ||
+      !IsAllowedType(zero_point->Type(), MLOperandTypeMask::kUint8)) {
+    exception_state.ThrowDOMException(
+        DOMExceptionCode::kDataError,
+        "The input and zeroPoint data types must be uint8.");
+    return nullptr;
+  }
+  if (!IsFloatingPointType(scale->Type())) {
+    exception_state.ThrowDOMException(
+        DOMExceptionCode::kDataError,
+        "The scale data type must be a floating point type.");
+    return nullptr;
+  }
+
+  if (input->Type() != zero_point->Type()) {
+    exception_state.ThrowDOMException(DOMExceptionCode::kDataError,
+                                      "The input and zero point data types must match.");
+    return nullptr;
+  }
+
+  // Ensure scale and zeroPoint are both broadcastable to the input.
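+  // Following the usual dequantizeLinear definition, the output is computed
+  // elementwise as (input - zeroPoint) * scale, which is why the output below
+  // takes the scale operand's floating point type and the input's shape.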
+ const auto& input_dimensions = input->Dimensions(); + + if (!ValidateUnidirectionalBroadcastability( + scale->Dimensions(), input_dimensions, "dequantizeLinear", "scale", + "input", exception_state)) { + return nullptr; + } + if (!ValidateUnidirectionalBroadcastability( + zero_point->Dimensions(), input_dimensions, "dequantizeLinear", + "zeroPoint", "input", exception_state)) { + return nullptr; + } + + auto* ml_operator = MakeGarbageCollected( + this, MLOperator::OperatorKind::kDequantizeLinear); + String error_message; + auto* output = MLOperand::ValidateAndCreateOutput(this, scale->Type(), + input_dimensions, + ml_operator, error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return nullptr; + } + ml_operator->Connect({input, scale, zero_point}, {output}); + return output; +} + +HeapVector> MLGraphBuilder::dynamicQuantizeLinear( + const MLOperand* input, + ExceptionState& exception_state) { + if (!IsFloatingPointType(input->Type())) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + "The input data type must be a floating point type."); + return {}; + } + + const auto& input_dimensions = input->Dimensions(); + + auto* ml_operator = MakeGarbageCollected( + this, MLOperator::OperatorKind::kDynamicQuantizeLinear); + String error_message; + Vector scalar_dimensions; + + auto* output = MLOperand::ValidateAndCreateOutput( + this, V8MLOperandType::Enum::kUint8, input_dimensions, ml_operator, + error_message); + if (!output) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return {}; + } + auto* output_scale = MLOperand::ValidateAndCreateOutput( + this, V8MLOperandType::Enum::kFloat32, scalar_dimensions, ml_operator, + error_message); + if (!output_scale) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return {}; + } + auto* output_zero_point = MLOperand::ValidateAndCreateOutput( + this, V8MLOperandType::Enum::kUint8, scalar_dimensions, ml_operator, + error_message); + if (!output_zero_point) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + return {}; + } + HeapVector> outputs = {output, output_scale, output_zero_point}; + + ml_operator->Connect({input}, {output, output_scale, output_zero_point}); + return outputs; +} + +MLOperand* MLGraphBuilder::matmulInteger(const MLOperand* a, + const MLOperand* a_zero_point, + const MLOperand* b, + const MLOperand* b_zero_point, + ExceptionState& exception_state) { + return BuildMatMul(this, MLOperator::OperatorKind::kMatmulInteger, a, b, + a_zero_point, b_zero_point, exception_state); +} + + ScriptPromise MLGraphBuilder::build(ScriptState* script_state, const MLNamedOperands& named_outputs, ExceptionState& exception_state) { @@ -1262,9 +3524,10 @@ ScriptPromise MLGraphBuilder::build(ScriptState* script_state, } #endif - // The Context is GPU device or low power preference, the graph is built by - // MojoGraph object. - if (GetContext()->GetDevicePreference() == V8MLDevicePreference::Enum::kGpu) { + // If the device preference is specifically kGpu/kNpu, or if using the + // automatic policy with a power preference set to high or low, the graph is + // built by MojoGraph object. 
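+  // A minimal sketch of what IsDedicatedHardwareDevice() is assumed to cover,
+  // since its definition is not part of this hunk (GetPowerPreference() and
+  // the kAuto/kDefault names below are hypothetical):
+  //
+  //   bool IsDedicatedHardwareDevice() const {
+  //     const auto device = GetDevicePreference();
+  //     return device == V8MLDevicePreference::Enum::kGpu ||
+  //            device == V8MLDevicePreference::Enum::kNpu ||
+  //            (device == V8MLDevicePreference::Enum::kAuto &&
+  //             GetPowerPreference() != V8MLPowerPreference::Enum::kDefault);
+  //   }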
+ if (GetContext()->IsDedicatedHardwareDevice()) { if (ml_context_->IsWebnnMojoContextEnabled()) { MojoGraph::ValidateAndBuildAsync(ml_context_, named_outputs, resolver); } else { @@ -1297,9 +3560,10 @@ MLGraph* MLGraphBuilder::buildSync(ScriptState* script_state, } #endif - // The Context is GPU device or low power preference, the graph is built by - // MojoGraph object. - if (GetContext()->GetDevicePreference() == V8MLDevicePreference::Enum::kGpu) { + // If the device preference is specifically kGpu/kNpu, or if using the + // automatic policy with a power preference set to high or low, the graph is + // built by MojoGraph object. + if (GetContext()->IsDedicatedHardwareDevice()) { if (ml_context_->IsWebnnMojoContextEnabled()) { return MojoGraph::ValidateAndBuildSync(script_state, ml_context_, named_outputs, exception_state); @@ -1332,6 +3596,7 @@ void MLGraphBuilder::SortOperators( HeapVector>& sorted_operators) { HeapDeque> operators_to_do; HeapHashSet> operators_done; + HeapHashSet> visited_inputs; for (const auto& output : named_outputs) { operators_to_do.push_back(output.second->Operator()); } @@ -1353,7 +3618,11 @@ void MLGraphBuilder::SortOperators( // done set. for (const auto& input : op->Inputs()) { if (input->Kind() == MLOperand::kInput) { - inputs.push_back(input.Get()); + // Add the input if it is not visited. + if (!visited_inputs.Contains(input.Get())) { + inputs.push_back(input.Get()); + visited_inputs.insert(input.Get()); + } } else if (input->Kind() == MLOperand::kConstant) { constants.push_back(input.Get()); } diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h b/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h index a71e97c83ed9c2..715c1fdc4c16f0 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h +++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h @@ -8,6 +8,7 @@ #include "third_party/abseil-cpp/absl/types/optional.h" #include "third_party/blink/renderer/bindings/core/v8/script_promise.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_auto_pad.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_type.h" #include "third_party/blink/renderer/core/typed_arrays/array_buffer_view_helpers.h" #include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer_view.h" #include "third_party/blink/renderer/modules/ml/webnn/ml_operator.h" @@ -21,16 +22,41 @@ namespace blink { class ExceptionState; class MLContext; -class MLClampOptions; -class MLConv2dOptions; -class MLGemmOptions; class MLGraph; -class MLPool2dOptions; -class MLResample2dOptions; class MLOperand; class MLOperandDescriptor; class ScriptPromiseResolver; class ScriptPromise; +class V8UnionUnsignedLongOrUnsignedLongSequence; + +class MLClampOptions; +class MLConv2dOptions; +class MLConv2dIntegerOptions; +class MLConvTranspose2dOptions; +class MLGemmOptions; +class MLPool2dOptions; +class MLGatherOptions; +class MLResample2dOptions; +class MLArgMinMaxOptions; +class MLSqueezeOptions; +class MLSliceOptions; +class MLSplitOptions; +class MLTransposeOptions; +class MLPadOptions; +class MLInstanceNormalizationOptions; +class MLMeanVarianceNormalizationOptions; +class MLFillSequenceOptions; +class MLTriangularMatrixOptions; +class MLReduceOptions; +class MLEluOptions; +class MLHardSigmoidOptions; +class MLLeakyReluOptions; +class MLLinearOptions; +class MLSoftplusOptions; +class MLGruOptions; +class MLGruCellOptions; +class MLLstmOptions; +class MLLstmCellOptions; typedef HeapVector>> MLNamedOperands; 
@@ -86,6 +112,10 @@ class MODULES_EXPORT MLGraphBuilder : public ScriptWrappable { const MLOperand* filter, const MLConv2dOptions* options, ExceptionState& exception_state); + MLOperand* convTranspose2d(const MLOperand* input, + const MLOperand* filter, + const MLConvTranspose2dOptions* options, + ExceptionState& exception_state); // Element-wise binary operations MLOperand* add(const MLOperand* a, @@ -119,6 +149,9 @@ class MODULES_EXPORT MLGraphBuilder : public ScriptWrappable { MLOperand* averagePool2d(const MLOperand* input, const MLPool2dOptions* options, ExceptionState& exception_state); + MLOperand* l2Pool2d(const MLOperand* input, + const MLPool2dOptions* options, + ExceptionState& exception_state); MLOperand* maxPool2d(const MLOperand* input, const MLPool2dOptions* options, ExceptionState& exception_state); @@ -134,10 +167,215 @@ class MODULES_EXPORT MLGraphBuilder : public ScriptWrappable { const MLResample2dOptions* options, ExceptionState& exception_state); - MLOperand* softmax(const MLOperand* input, ExceptionState& exception_state); - MLOperand* sigmoid(const MLOperand* input, ExceptionState& exception_state); MLOperator* sigmoid(ExceptionState& exception_state); + MLOperand* elu(const MLOperand* input, + const MLEluOptions* options, + ExceptionState& exception_state); + MLOperator* elu(const MLEluOptions* options, ExceptionState& exception_state); + MLOperand* hardSigmoid(const MLOperand* input, + const MLHardSigmoidOptions* options, + ExceptionState& exception_state); + MLOperator* hardSigmoid(const MLHardSigmoidOptions* options, + ExceptionState& exception_state); + MLOperand* leakyRelu(const MLOperand* input, + const MLLeakyReluOptions* options, + ExceptionState& exception_state); + MLOperator* leakyRelu(const MLLeakyReluOptions* options, + ExceptionState& exception_state); + MLOperand* linear(const MLOperand* input, + const MLLinearOptions* options, + ExceptionState& exception_state); + MLOperator* linear(const MLLinearOptions* options, + ExceptionState& exception_state); + MLOperand* prelu(const MLOperand* input, + const MLOperand* slope, + ExceptionState& exception_state); + MLOperand* softplus(const MLOperand* input, + const MLSoftplusOptions* options, + ExceptionState& exception_state); + MLOperator* softplus(const MLSoftplusOptions* options, + ExceptionState& exception_state); + MLOperand* softsign(const MLOperand* input, ExceptionState& exception_state); + MLOperator* softsign(ExceptionState& exception_state); + MLOperand* softmax(const MLOperand* input, ExceptionState& exception_state); + MLOperator* softmax(ExceptionState& exception_state); + MLOperand* tanh(const MLOperand* input, ExceptionState& exception_state); + MLOperator* tanh(ExceptionState& exception_state); + + MLOperand* argMax(const MLOperand* input, + const MLArgMinMaxOptions* options, + ExceptionState& exception_state); + MLOperand* argMin(const MLOperand* input, + const MLArgMinMaxOptions* options, + ExceptionState& exception_state); + MLOperand* cast(const MLOperand* input, + V8MLOperandType data_type, + ExceptionState& exception_state); + MLOperand* concat(const HeapVector>& inputs, + uint32_t axis, + ExceptionState& exception_state); + MLOperand* expand(const MLOperand* input, + const Vector& new_shape, + ExceptionState& exception_state); + MLOperand* abs(const MLOperand* input, ExceptionState& exception_state); + MLOperand* neg(const MLOperand* input, ExceptionState& exception_state); + MLOperand* cos(const MLOperand* input, ExceptionState& exception_state); + MLOperand* equal(const MLOperand* 
a, + const MLOperand* b, + ExceptionState& exception_state); + MLOperand* erf(const MLOperand* input, ExceptionState& exception_state); + MLOperand* exp(const MLOperand* input, ExceptionState& exception_state); + MLOperand* log(const MLOperand* input, ExceptionState& exception_state); + MLOperand* floor(const MLOperand* input, ExceptionState& exception_state); + MLOperand* ceil(const MLOperand* input, ExceptionState& exception_state); + MLOperand* reciprocal(const MLOperand* input, ExceptionState& exception_state); + MLOperand* logicalNot(const MLOperand* input, ExceptionState& exception_state); + MLOperand* flattenTo2d(const MLOperand* input, + uint32_t axis, + ExceptionState& exception_state); + MLOperand* gather(const MLOperand* input, + const MLOperand* indices, + const MLGatherOptions* options, + ExceptionState& exception_state); + MLOperand* greater(const MLOperand* a, + const MLOperand* b, + ExceptionState& exception_state); + MLOperand* identity(const MLOperand* input, ExceptionState& exception_state); + MLOperand* instanceNormalization( + const MLOperand* input, + const MLInstanceNormalizationOptions* options, + ExceptionState& exception_state); + MLOperand* meanVarianceNormalization( + const MLOperand* input, + const MLMeanVarianceNormalizationOptions* options, + ExceptionState& exception_state); + MLOperand* lesser(const MLOperand* a, + const MLOperand* b, + ExceptionState& exception_state); + MLOperand* matmul(const MLOperand* a, + const MLOperand* b, + ExceptionState& exception_state); + MLOperand* pad(const MLOperand* input, + const Vector& beginningPadding, + const Vector& endingPadding, + const MLPadOptions* options, + ExceptionState& exception_state); + MLOperand* pow(const MLOperand* a, + const MLOperand* b, + ExceptionState& exception_state); + MLOperand* fillSequence(V8MLOperandType output_data_type, + const Vector& output_shape, + const MLFillSequenceOptions* options, + ExceptionState& exception_state); + MLOperand* reduceL1(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state); + MLOperand* reduceL2(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state); + MLOperand* reduceLogSum(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state); + MLOperand* reduceLogSumExp(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state); + MLOperand* reduceMax(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state); + MLOperand* reduceMean(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state); + MLOperand* reduceMin(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state); + MLOperand* reduceProduct(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state); + MLOperand* reduceSum(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state); + MLOperand* reduceSumSquare(const MLOperand* input, + const MLReduceOptions* options, + ExceptionState& exception_state); + Vector shape(const MLOperand* input, + ExceptionState& exception_state); + MLOperand* sin(const MLOperand* input, ExceptionState& exception_state); + MLOperand* slice(const MLOperand* input, + const Vector& starts, + const Vector& sizes, + ExceptionState& exception_state); + HeapVector> split(const MLOperand* input, + blink::V8UnionUnsignedLongOrUnsignedLongSequence* splits, + const MLSplitOptions* 
options, + ExceptionState& exception_state); + MLOperand* sqrt(const MLOperand* input, ExceptionState& exception_state); + MLOperand* tan(const MLOperand* input, ExceptionState& exception_state); + MLOperand* transpose(const MLOperand* input, + const MLTransposeOptions* options, + ExceptionState& exception_state); + MLOperand* triangularMatrix(const MLOperand* input, + const MLTriangularMatrixOptions* options, + ExceptionState& exception_state); + MLOperand* squeeze(const MLOperand* input, + const MLSqueezeOptions* options, + ExceptionState& exception_state); + MLOperand* unsqueeze(const MLOperand* input, + const MLSqueezeOptions* options, + ExceptionState& exception_state); + + MLOperand* elementwiseIf(const MLOperand* condition, + const MLOperand* true_value, + const MLOperand* false_value, + ExceptionState& exception_state); + + HeapVector> gru(const MLOperand* input, + const MLOperand* weight, + const MLOperand* recurrent_weight, + uint32_t steps, + uint32_t hidden_size, + const MLGruOptions* options, + ExceptionState& exception_state); + MLOperand* gruCell(const MLOperand* input, + const MLOperand* weight, + const MLOperand* recurrent_weight, + const MLOperand* hidden_state, + uint32_t hidden_size, + const MLGruCellOptions* options, + ExceptionState& exception_state); + HeapVector> lstm(const MLOperand* input, + const MLOperand* weight, + const MLOperand* recurrent_weight, + uint32_t steps, + uint32_t hidden_size, + const MLLstmOptions* options, + ExceptionState& exception_state); + MLOperand* lstmCell(const MLOperand* input, + const MLOperand* weight, + const MLOperand* recurrent_weight, + const MLOperand* hidden_state, + const MLOperand* cell_state, + uint32_t hidden_size, + const MLLstmCellOptions* options, + ExceptionState& exception_state); + MLOperand* conv2dInteger(const MLOperand* input, + const MLOperand* inputZeroPoint, + const MLOperand* filter, + const MLOperand* filterZeroPoint, + const MLConv2dIntegerOptions* options, + ExceptionState& exception_state); + MLOperand* dequantizeLinear(const MLOperand* input, + const MLOperand* scale, + const MLOperand* zeroPoint, + ExceptionState& exception_state); + HeapVector> dynamicQuantizeLinear( + const MLOperand* input, + ExceptionState& exception_state); + MLOperand* matmulInteger(const MLOperand* a, + const MLOperand* a_zero_point, + const MLOperand* b, + const MLOperand* b_zero_point, + ExceptionState& exception_state); ScriptPromise build(ScriptState* script_state, const MLNamedOperands& named_outputs, diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.idl b/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.idl index 7af04ecb327335..a37fa515d9fba5 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.idl +++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.idl @@ -12,6 +12,8 @@ enum MLConv2dFilterOperandLayout { "oihw", "hwio", "ohwi", "ihwo" }; enum MLAutoPad { "explicit", "same-upper", "same-lower" }; +enum MLTriangularPart { "upper", "lower" }; + dictionary MLConv2dOptions { sequence<[EnforceRange] unsigned long> padding; sequence<[EnforceRange] unsigned long> strides; @@ -24,6 +26,33 @@ dictionary MLConv2dOptions { MLOperator activation; }; +dictionary MLConv2dIntegerOptions { + sequence<[EnforceRange] unsigned long> padding; + sequence<[EnforceRange] unsigned long> strides; + sequence<[EnforceRange] unsigned long> dilations; + MLAutoPad autoPad = "explicit"; + [EnforceRange] unsigned long groups = 1; + MLInputOperandLayout inputLayout = "nchw"; + 
MLConv2dFilterOperandLayout filterLayout = "oihw"; + // No bias or activation function. +}; + +enum MLConvTranspose2dFilterOperandLayout { "iohw", "hwoi", "ohwi" }; + +dictionary MLConvTranspose2dOptions { + sequence<[EnforceRange] unsigned long> padding; + sequence<[EnforceRange] unsigned long> strides; + sequence<[EnforceRange] unsigned long> dilations; + sequence<[EnforceRange] unsigned long> outputPadding; + sequence<[EnforceRange] unsigned long> outputSizes; + MLAutoPad autoPad = "explicit"; + unsigned long groups = 1; + MLInputOperandLayout inputLayout = "nchw"; + MLConvTranspose2dFilterOperandLayout filterLayout = "iohw"; + MLOperand bias; + MLOperator activation; +}; + dictionary MLGemmOptions { MLOperand c; float alpha = 1.0; @@ -59,7 +88,167 @@ dictionary MLResample2dOptions { MLInterpolationMode mode = "nearest-neighbor"; sequence scales; sequence<[EnforceRange] unsigned long> sizes; - sequence axes; + sequence<[EnforceRange] unsigned long> axes; +}; + +dictionary MLInstanceNormalizationOptions { + MLOperand scale; + MLOperand bias; + float epsilon = 1e-5; + MLInputOperandLayout layout = "nchw"; +}; + +dictionary MLMeanVarianceNormalizationOptions { + MLOperand mean; // optional + MLOperand variance; // optional + MLOperand scale; + MLOperand bias; + float epsilon = 1e-5; + sequence<[EnforceRange] unsigned long> axes; +}; + +dictionary MLTransposeOptions { + sequence<[EnforceRange] unsigned long> permutation; +}; + +dictionary MLSqueezeOptions { + sequence<[EnforceRange] unsigned long> axes; +}; + +dictionary MLArgMinMaxOptions { + unsigned long axis = 0; + boolean keepDimensions = false; + boolean selectLastIndex = false; +}; + +dictionary MLConcatOptionsInternal { + unsigned long axis = 0; +}; + +dictionary MLGatherOptions { + unsigned long axis = 0; +}; + +enum MLPaddingMode { + "constant", + "edge", + "reflection", + "symmetric" +}; + +dictionary MLPadOptions { + MLPaddingMode mode = "constant"; + float value = 0; +}; + +dictionary MLPadOptionsInternal { + MLPaddingMode mode = "constant"; + float value = 0; + sequence<[EnforceRange] unsigned long> beginningPadding; + sequence<[EnforceRange] unsigned long> endingPadding; +}; + +// diagonalDelta is a horizontal shift. So positive delta means that for an the upper triangular matrix, +// the value mask is shifted rightward. 
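+// For example, with triangularPart = "upper" on a square matrix, a
+// diagonalDelta of 0 keeps the main diagonal and everything above it, +1
+// keeps only the elements strictly above the main diagonal, and -1 also
+// keeps the first subdiagonal (matching numpy.triu's k parameter).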
+// Related: +// https://numpy.org/doc/stable/reference/generated/numpy.tril.html#numpy.tril +// https://numpy.org/doc/stable/reference/generated/numpy.triu.html#numpy.triu +// https://www.tensorflow.org/probability/api_docs/python/tfp/math/fill_triangular +dictionary MLTriangularMatrixOptions { + MLTriangularPart triangularPart; + long diagonalDelta; +}; + +dictionary MLReduceOptions { + sequence axes; + boolean keepDimensions = false; +}; + +dictionary MLFillSequenceOptions { + float start = 0; + float delta = 1; +}; + +dictionary MLSplitOptions { + unsigned long axis = 0; +}; + +dictionary MLEluOptions { + float alpha = 1; +}; + +dictionary MLHardSigmoidOptions { + float alpha = 0.2; + float beta = 0.5; +}; + +dictionary MLLeakyReluOptions { + float alpha = 0.01; +}; + +dictionary MLLinearOptions { + float alpha = 1; + float beta = 0; +}; + +dictionary MLSoftplusOptions { + float steepness = 1; +}; + +enum MLGruWeightLayout { + "zrn", // update-reset-new gate ordering + "rzn" // reset-update-new gate ordering +}; + +enum MLRecurrentNetworkDirection { + "forward", + "backward", + "both" +}; + +dictionary MLGruOptions { + MLOperand bias; + MLOperand recurrentBias; + MLOperand initialHiddenState; + boolean resetAfter = true; + boolean returnSequence = false; + MLRecurrentNetworkDirection direction = "forward"; + MLGruWeightLayout layout = "zrn"; + sequence activations; +}; + +dictionary MLGruCellOptions { + MLOperand bias; + MLOperand recurrentBias; + boolean resetAfter = true; + MLGruWeightLayout layout = "zrn"; + sequence activations; +}; + +enum MLLstmWeightLayout { + "iofg", // input-output-forget-cell gate ordering + "ifgo" // input-forget-cell-output gate ordering +}; + +dictionary MLLstmOptions { + MLOperand bias; + MLOperand recurrentBias; + MLOperand peepholeWeight; + MLOperand initialHiddenState; + MLOperand initialCellState; + boolean returnSequence = false; + MLRecurrentNetworkDirection direction = "forward"; + MLLstmWeightLayout layout = "iofg"; + sequence activations; +}; + + +dictionary MLLstmCellOptions { + MLOperand bias; + MLOperand recurrentBias; + MLOperand peepholeWeight; + MLLstmWeightLayout layout = "iofg"; + sequence activations; }; [ @@ -72,42 +261,225 @@ dictionary MLResample2dOptions { [RaisesException] MLOperand constant(MLOperandDescriptor desc, MLBufferView bufferView); - [RaisesException] MLOperand clamp(MLOperand input, optional MLClampOptions options = {}); - [RaisesException] MLOperator clamp(optional MLClampOptions options = {}); - + // Dot product operations. 
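+  // Note: matmul is built through the same path as gemm, with an empty
+  // MLGemmOptions (see BuildMatMul in ml_graph_builder.cc).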
[RaisesException] MLOperand conv2d(MLOperand input, MLOperand filter, optional MLConv2dOptions options = {}); + [RaisesException] MLOperand convTranspose2d(MLOperand input, MLOperand filter, optional MLConvTranspose2dOptions options = {}); + [RaisesException] MLOperand gemm(MLOperand a, MLOperand b, optional MLGemmOptions options = {}); + [RaisesException] MLOperand matmul(MLOperand a, MLOperand b); - // Element-wise binary operations + // Elementwise binary operations [RaisesException] MLOperand add(MLOperand a, MLOperand b); [RaisesException] MLOperand sub(MLOperand a, MLOperand b); [RaisesException] MLOperand mul(MLOperand a, MLOperand b); [RaisesException] MLOperand div(MLOperand a, MLOperand b); [RaisesException] MLOperand max(MLOperand a, MLOperand b); [RaisesException] MLOperand min(MLOperand a, MLOperand b); + [RaisesException] MLOperand pow(MLOperand a, MLOperand b); - [RaisesException] MLOperand gemm(MLOperand a, MLOperand b, optional MLGemmOptions options = {}); + // Elementwise binary logical comparison operations + [RaisesException] MLOperand equal(MLOperand a, MLOperand b); + [RaisesException] MLOperand greater(MLOperand a, MLOperand b); + [RaisesException] MLOperand lesser(MLOperand a, MLOperand b); - [RaisesException] MLOperand hardSwish(MLOperand x); + // Elementwise unary activation operations. + // TODO: Rename MLOperator -> MLActivation per spec update. + [RaisesException] MLOperand relu(MLOperand input); + [RaisesException] MLOperator relu(); + + [RaisesException] MLOperand elu(MLOperand input, optional MLEluOptions options = {}); + [RaisesException] MLOperator elu(optional MLEluOptions options = {}); + + [RaisesException] MLOperand prelu(MLOperand input, MLOperand slope); + + [RaisesException] MLOperand leakyRelu(MLOperand input, optional MLLeakyReluOptions options = {}); + [RaisesException] MLOperator leakyRelu(optional MLLeakyReluOptions options = {}); + + [RaisesException] MLOperand clamp(MLOperand input, optional MLClampOptions options = {}); + [RaisesException] MLOperator clamp(optional MLClampOptions options = {}); + + [RaisesException] MLOperand sigmoid(MLOperand input); + [RaisesException] MLOperator sigmoid(); + + [RaisesException] MLOperand hardSigmoid(MLOperand input, optional MLHardSigmoidOptions options = {}); + [RaisesException] MLOperator hardSigmoid(optional MLHardSigmoidOptions options = {}); + + [RaisesException] MLOperand hardSwish(MLOperand input); [RaisesException] MLOperator hardSwish(); + [RaisesException] MLOperand linear(MLOperand input, optional MLLinearOptions options = {}); + [RaisesException] MLOperator linear(optional MLLinearOptions options = {}); + + [RaisesException] MLOperand softplus(MLOperand input, optional MLSoftplusOptions options = {}); + [RaisesException] MLOperator softplus(optional MLSoftplusOptions options = {}); + + [RaisesException] MLOperand softsign(MLOperand input); + [RaisesException] MLOperator softsign(); + + [RaisesException] MLOperand softmax(MLOperand input); + [RaisesException] MLOperator softmax(); + + [RaisesException] MLOperand tanh(MLOperand input); + [RaisesException] MLOperator tanh(); + // Pooling operations [RaisesException] MLOperand averagePool2d(MLOperand input, optional MLPool2dOptions options = {}); + [RaisesException] MLOperand l2Pool2d(MLOperand input, optional MLPool2dOptions options = {}); [RaisesException] MLOperand maxPool2d(MLOperand input, optional MLPool2dOptions options = {}); - [RaisesException] MLOperand relu(MLOperand input); - [RaisesException] MLOperator relu(); + // Elementwise unary 
operations + [RaisesException] MLOperand identity(MLOperand input); + [RaisesException] MLOperand abs(MLOperand input); + [RaisesException] MLOperand neg(MLOperand input); + [RaisesException] MLOperand exp(MLOperand input); + [RaisesException] MLOperand log(MLOperand input); + [RaisesException] MLOperand sqrt(MLOperand input); + [RaisesException] MLOperand sin(MLOperand input); + [RaisesException] MLOperand cos(MLOperand input); + [RaisesException] MLOperand tan(MLOperand input); + [RaisesException] MLOperand erf(MLOperand input); + [RaisesException] MLOperand floor(MLOperand input); + [RaisesException] MLOperand ceil(MLOperand input); + [RaisesException] MLOperand reciprocal(MLOperand input); + [RaisesException] MLOperand logicalNot(MLOperand input); + + // Trinary elementwise operations + [RaisesException] MLOperand elementwiseIf(MLOperand condition, MLOperand trueValue, MLOperand falseValue); + // Shape dimension reinterpreting operations. + // TODO:::UPDATE reshape to unsigned long, and remove any null/-1 behavior after + // https://www.w3.org/TR/webnn/#api-mlgraphbuilder-reshape. [RaisesException] MLOperand reshape(MLOperand input, sequence newShape); + [RaisesException] MLOperand squeeze(MLOperand input, optional MLSqueezeOptions options = {}); + [RaisesException] MLOperand unsqueeze(MLOperand input, MLSqueezeOptions options); + [RaisesException] MLOperand flattenTo2d(MLOperand input, unsigned long axis); + // Shape modification operations. + [RaisesException] MLOperand concat(sequence inputs, unsigned long axis); + [RaisesException] MLOperand slice(MLOperand input, sequence starts, sequence sizes); + [RaisesException] sequence split(MLOperand input, (unsigned long or sequence) splits, optional MLSplitOptions options = {}); + [RaisesException] MLOperand transpose(MLOperand input, optional MLTransposeOptions options = {}); + [RaisesException] MLOperand pad(MLOperand input, sequence beginningPadding, sequence endingPadding, optional MLPadOptions options = {}); + [RaisesException] MLOperand expand(MLOperand input, sequence newShape); + [RaisesException] MLOperand gather(MLOperand input, MLOperand indices, optional MLGatherOptions options = {}); [RaisesException] MLOperand resample2d(MLOperand input, optional MLResample2dOptions options = {}); - [RaisesException] MLOperand softmax(MLOperand input); + // Reduction operations + [RaisesException] MLOperand reduceL1(MLOperand input, optional MLReduceOptions options = {}); + [RaisesException] MLOperand reduceL2(MLOperand input, optional MLReduceOptions options = {}); + [RaisesException] MLOperand reduceLogSum(MLOperand input, optional MLReduceOptions options = {}); + [RaisesException] MLOperand reduceLogSumExp(MLOperand input, optional MLReduceOptions options = {}); + [RaisesException] MLOperand reduceMax(MLOperand input, optional MLReduceOptions options = {}); + [RaisesException] MLOperand reduceMean(MLOperand input, optional MLReduceOptions options = {}); + [RaisesException] MLOperand reduceMin(MLOperand input, optional MLReduceOptions options = {}); + [RaisesException] MLOperand reduceProduct(MLOperand input, optional MLReduceOptions options = {}); + [RaisesException] MLOperand reduceSum(MLOperand input, optional MLReduceOptions options = {}); + [RaisesException] MLOperand reduceSumSquare(MLOperand input, optional MLReduceOptions options = {}); - [RaisesException] MLOperand sigmoid(MLOperand input); - [RaisesException] MLOperator sigmoid(); + // Normalization operations + [RaisesException] MLOperand instanceNormalization(MLOperand input, 
optional MLInstanceNormalizationOptions options = {}); + [RaisesException] MLOperand meanVarianceNormalization(MLOperand input, optional MLMeanVarianceNormalizationOptions options = {}); + + // Miscellaneous operations + [RaisesException] MLOperand argMax(MLOperand input, optional MLArgMinMaxOptions options = {}); + [RaisesException] MLOperand argMin(MLOperand input, optional MLArgMinMaxOptions options = {}); + [RaisesException] MLOperand cast(MLOperand input, MLOperandType operandType); + [RaisesException] MLOperand fillSequence(MLOperandType operandType, sequence outputShape, optional MLFillSequenceOptions options = {}); + [RaisesException] MLOperand triangularMatrix(MLOperand input, optional MLTriangularMatrixOptions options = {}); + + // Iterative GEMM operations + [RaisesException] sequence gru( + MLOperand input, + MLOperand weight, + MLOperand recurrentWeight, + unsigned long steps, + unsigned long hiddenSize, + optional MLGruOptions options = {} + ); + + [RaisesException] MLOperand gruCell( + MLOperand input, + MLOperand weight, + MLOperand recurrentWeight, + MLOperand hiddenState, + unsigned long hiddenSize, + optional MLGruCellOptions options = {} + ); + + [RaisesException] sequence lstm( + MLOperand input, + MLOperand weight, + MLOperand recurrentWeight, + unsigned long steps, + unsigned long hiddenSize, + optional MLLstmOptions options = {} + ); + + [RaisesException] MLOperand lstmCell( + MLOperand input, + MLOperand weight, + MLOperand recurrentWeight, + MLOperand hiddenState, + MLOperand cellState, + unsigned long hiddenSize, + optional MLLstmCellOptions options = {} + ); + + [RaisesException] MLOperand conv2dInteger(MLOperand input, MLOperand inputZeroPoint, MLOperand filter, MLOperand filterZeroPoint, optional MLConv2dIntegerOptions options = {}); + + [RaisesException] MLOperand matmulInteger(MLOperand a, MLOperand aZeroPoint, MLOperand b, MLOperand bZeroPoint); + + [RaisesException] MLOperand dequantizeLinear(MLOperand input, MLOperand scale, MLOperand zeroPoint); + + [RaisesException] sequence dynamicQuantizeLinear(MLOperand input); + + // Shape retrieval. + [RaisesException] sequence shape(MLOperand input); [CallWith=ScriptState, RaisesException] Promise build(MLNamedOperands outputs); [RuntimeEnabled=MachineLearningNeuralNetwork, CallWith=ScriptState, Exposed=DedicatedWorker, RaisesException] MLGraph buildSync(MLNamedOperands outputs); }; + +// Complete representation for ferrying from IDL to Mojo. +// The parameters have been folded into the options. +dictionary MLFloatParameterOptionsInternal { + float firstParameter = 0.0; + float secondParameter = 0.0; +}; + +// Complete representation for ferrying from IDL to Mojo. +// The parameters have been folded into the options. +dictionary MLSliceOptionsInternal { + sequence<[EnforceRange] unsigned long> starts; + sequence<[EnforceRange] unsigned long> sizes; +}; + +// Complete representation for ferrying from IDL to Mojo. +// The parameters have been folded into the options. +dictionary MLSplitOptionsInternal { + sequence<[EnforceRange] unsigned long> splits; + unsigned long axis = 0; +}; + +enum MLConvFilterOperandLayoutInternal { "oihw", "iohw", "hwoi", "hwio", "ohwi", "ihwo" }; + +// Complete representation for ferrying from IDL to Mojo. +// The differences between MLConvTranspose2dOptions and +// MLConvTranspose2dOptions and the WebNN API level complicate +// internal processing, along with MLConv2dFilterOperandLayout +// vs MLConvTranspose2dFilterOperandLayout which are nearly +// identical. So unify them. 
+dictionary MLConvOptionsInternal { + sequence<[EnforceRange] unsigned long> padding; + sequence<[EnforceRange] unsigned long> strides; + sequence<[EnforceRange] unsigned long> dilations; + sequence<[EnforceRange] unsigned long> outputPadding; + sequence<[EnforceRange] unsigned long> outputSizes; + MLAutoPad autoPad = "explicit"; + unsigned long groups = 1; + MLInputOperandLayout inputLayout = "nchw"; + MLConvFilterOperandLayoutInternal filterLayout = "oihw"; + MLOperand bias; + MLOperator activation; +}; diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.cc b/third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.cc index 304914fbf17c29..9ff1dc5083a6e7 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.cc +++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.cc @@ -17,6 +17,7 @@ #include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_options_internal.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.h" #include "third_party/blink/renderer/core/dom/dom_exception.h" #include "third_party/blink/renderer/modules/ml/ml.h" @@ -207,6 +208,8 @@ xnn_datatype GetXnnDataType(V8MLOperandType::Enum operand_type) { case V8MLOperandType::Enum::kUint32: case V8MLOperandType::Enum::kInt8: case V8MLOperandType::Enum::kUint8: + case V8MLOperandType::Enum::kInt64: + case V8MLOperandType::Enum::kUint64: // TODO(crbug.com/1273291): Support the quantized integer types that is a // WebNN v2 feature tracked by: // https://github.com/webmachinelearning/webnn/issues/128. @@ -467,8 +470,8 @@ xnn_status DefineXnnNodeForConv2d(xnn_subgraph_t subgraph, const uint32_t output_id = GetOperatorOutputValueId(conv2d, operand_value_id_map); - const MLConv2dOptions* options = - static_cast(conv2d->Options()); + const MLConvOptionsInternal* options = + static_cast(conv2d->Options()); // Set strides of XNNPACK conv2d, default to 1. const Vector default_strides({1, 1}); @@ -511,7 +514,7 @@ xnn_status DefineXnnNodeForConv2d(xnn_subgraph_t subgraph, // TODO(crbug.com/1273291): support other layouts by transposing the // filter operand. if (options->filterLayout().AsEnum() != - V8MLConv2dFilterOperandLayout::Enum::kOhwi) { + V8MLConvFilterOperandLayoutInternal::Enum::kOhwi) { error_message = String::Format("The filter layout %s is not supported.", options->filterLayout().AsCStr()); return xnn_status_unsupported_parameter; @@ -523,7 +526,7 @@ xnn_status DefineXnnNodeForConv2d(xnn_subgraph_t subgraph, // TODO(crbug.com/1273291): support other layouts by transposing the // filter operand. 
if (options->filterLayout().AsEnum() != - V8MLConv2dFilterOperandLayout::Enum::kIhwo) { + V8MLConvFilterOperandLayoutInternal::Enum::kIhwo) { error_message = String::Format("The filter layout %s is not supported.", options->filterLayout().AsCStr()); return xnn_status_unsupported_parameter; diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_operand.cc b/third_party/blink/renderer/modules/ml/webnn/ml_operand.cc index 856014af7b628d..849b9e2186c95b 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_operand.cc +++ b/third_party/blink/renderer/modules/ml/webnn/ml_operand.cc @@ -28,18 +28,21 @@ size_t GetBytesPerElement(V8MLOperandType::Enum operand_type) { return sizeof(int8_t); case V8MLOperandType::Enum::kUint8: return sizeof(uint8_t); + case V8MLOperandType::Enum::kInt64: + return sizeof(int64_t); + case V8MLOperandType::Enum::kUint64: + return sizeof(uint64_t); } } absl::optional ValidateAndCalculateElementsNumber( const Vector& dimensions, String& error_message) { - if (dimensions.empty()) { - error_message = "The dimensions is empty."; - return absl::nullopt; - } + // Note that empty dimensions are completely legal and indicate a scalar + // value of 1 element. base::CheckedNumeric checked_number_of_elements = 1; for (auto& d : dimensions) { + // Note that zero-sized dimensions are legal and should be treated as nops. if (d == 0) { error_message = "All dimensions should be positive."; return absl::nullopt; @@ -87,6 +90,10 @@ DOMArrayBufferView::ViewType GetArrayBufferViewType( return DOMArrayBufferView::ViewType::kTypeInt32; case V8MLOperandType::Enum::kUint32: return DOMArrayBufferView::ViewType::kTypeUint32; + case V8MLOperandType::Enum::kInt64: + return DOMArrayBufferView::ViewType::kTypeBigInt64; + case V8MLOperandType::Enum::kUint64: + return DOMArrayBufferView::ViewType::kTypeBigUint64; case V8MLOperandType::Enum::kInt8: return DOMArrayBufferView::ViewType::kTypeInt8; case V8MLOperandType::Enum::kUint8: diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_operand_descriptor.idl b/third_party/blink/renderer/modules/ml/webnn/ml_operand_descriptor.idl index 124304800ebec5..7a56afb0d5b514 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_operand_descriptor.idl +++ b/third_party/blink/renderer/modules/ml/webnn/ml_operand_descriptor.idl @@ -10,7 +10,9 @@ enum MLOperandType { "int32", "uint32", "int8", - "uint8" + "uint8", + "int64", + "uint64", }; dictionary MLOperandDescriptor { diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_operator.cc b/third_party/blink/renderer/modules/ml/webnn/ml_operator.cc index b78695407edd38..32f79d33723a1e 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_operator.cc +++ b/third_party/blink/renderer/modules/ml/webnn/ml_operator.cc @@ -7,15 +7,22 @@ #include "third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h" #include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h" +// TODO::: +#pragma optimize("", off) + namespace blink { // static String MLOperator::OperatorKindToString(MLOperator::OperatorKind kind) { + static_assert(MLOperator::OperatorKind::kTotal == MLOperator::OperatorKind(83)); + switch (kind) { case MLOperator::OperatorKind::kClamp: return "clamp"; case MLOperator::OperatorKind::kConv2d: return "conv2d"; + case MLOperator::OperatorKind::kConvTranspose2d: + return "convTranspose2d"; case MLOperator::OperatorKind::kAdd: return "add"; case MLOperator::OperatorKind::kSub: @@ -28,6 +35,14 @@ String MLOperator::OperatorKindToString(MLOperator::OperatorKind kind) { return "max"; 
case MLOperator::OperatorKind::kMin: return "min"; + case MLOperator::OperatorKind::kFloor: + return "floor"; + case MLOperator::OperatorKind::kCeil: + return "ceil"; + case MLOperator::OperatorKind::kReciprocal: + return "reciprocal"; + case MLOperator::OperatorKind::kLogicalNot: + return "logicalNot"; case MLOperator::OperatorKind::kGemm: return "gemm"; case MLOperator::OperatorKind::kHardSwish: @@ -46,6 +61,128 @@ String MLOperator::OperatorKindToString(MLOperator::OperatorKind kind) { return "softmax"; case MLOperator::OperatorKind::kSigmoid: return "sigmoid"; + case MLOperator::OperatorKind::kArgMax: + return "argMax"; + case MLOperator::OperatorKind::kArgMin: + return "argMin"; + case MLOperator::OperatorKind::kCast: + return "cast"; + case MLOperator::OperatorKind::kConcat: + return "concat"; + case MLOperator::OperatorKind::kExpand: + return "expand"; + case MLOperator::OperatorKind::kCos: + return "cos"; + case MLOperator::OperatorKind::kEqual: + return "equal"; + case MLOperator::OperatorKind::kErf: + return "erf"; + case MLOperator::OperatorKind::kExp: + return "exp"; + case MLOperator::OperatorKind::kFlattenTo2d: + return "flattenTo2d"; + case MLOperator::OperatorKind::kGather: + return "gather"; + case MLOperator::OperatorKind::kGreater: + return "greater"; + case MLOperator::OperatorKind::kLesser: + return "lesser"; + case MLOperator::OperatorKind::kIdentity: + return "identity"; + case MLOperator::OperatorKind::kInstanceNormalization: + return "instanceNormalization"; + case MLOperator::OperatorKind::kMeanVarianceNormalization: + return "meanVarianceNormalization"; + case MLOperator::OperatorKind::kMatmul: + return "matmul"; + case MLOperator::OperatorKind::kPad: + return "pad"; + case MLOperator::OperatorKind::kPow: + return "pow"; + case MLOperator::OperatorKind::kFillSequence: + return "fillSequence"; + case MLOperator::OperatorKind::kReduceL1: + return "reduceL1"; + case MLOperator::OperatorKind::kReduceL2: + return "reduceL2"; + case MLOperator::OperatorKind::kReduceLogSum: + return "reduceLogSum"; + case MLOperator::OperatorKind::kReduceLogSumExp: + return "reduceLogSumExp"; + case MLOperator::OperatorKind::kReduceMax: + return "reduceMax"; + case MLOperator::OperatorKind::kReduceMean: + return "reduceMean"; + case MLOperator::OperatorKind::kReduceMin: + return "reduceMin"; + case MLOperator::OperatorKind::kReduceProduct: + return "reduceProduct"; + case MLOperator::OperatorKind::kReduceSum: + return "reduceSum"; + case MLOperator::OperatorKind::kReduceSumSquare: + return "reduceSumSquare"; + case MLOperator::OperatorKind::kSin: + return "sin"; + case MLOperator::OperatorKind::kSlice: + return "slice"; + case MLOperator::OperatorKind::kSplit: + return "split"; + case MLOperator::OperatorKind::kSqrt: + return "sqrt"; + case MLOperator::OperatorKind::kTranspose: + return "transpose"; + case MLOperator::OperatorKind::kTriangularMatrix: + return "triangularMatrix"; + case MLOperator::OperatorKind::kTan: + return "tan"; + case MLOperator::OperatorKind::kSqueeze: + return "squeeze"; + case MLOperator::OperatorKind::kUnsqueeze: + return "unsqueeze"; + case MLOperator::OperatorKind::kElementWiseIf: + return "elementwiseIf"; + case MLOperator::OperatorKind::kElu: + return "elu"; + case MLOperator::OperatorKind::kPrelu: + return "prelu"; + case MLOperator::OperatorKind::kLeakyRelu: + return "leakyRelu"; + case MLOperator::OperatorKind::kHardSigmoid: + return "hardSigmoid"; + case MLOperator::OperatorKind::kLinear: + return "linear"; + case MLOperator::OperatorKind::kSoftplus: + 
return "softplus"; + case MLOperator::OperatorKind::kSoftsign: + return "softsign"; + case MLOperator::OperatorKind::kTanh: + return "tanh"; + case MLOperator::OperatorKind::kL2Pool2d: + return "l2Pool"; + case MLOperator::OperatorKind::kAbs: + return "abs"; + case MLOperator::OperatorKind::kNeg: + return "neg"; + case MLOperator::OperatorKind::kLog: + return "log"; + case MLOperator::OperatorKind::kGru: + return "gru"; + case MLOperator::OperatorKind::kGruCell: + return "grulCell"; + case MLOperator::OperatorKind::kLstm: + return "lstm"; + case MLOperator::OperatorKind::kLstmCell: + return "lstmCell"; + case MLOperator::OperatorKind::kConv2dInteger: + return "conv2dInteger"; + case MLOperator::OperatorKind::kDequantizeLinear: + return "dequantizeLinear"; + case MLOperator::OperatorKind::kDynamicQuantizeLinear: + return "dynamicQuantizeLinear"; + case MLOperator::OperatorKind::kMatmulInteger: + return "matmulInteger"; + default: + return "unknown"; } } @@ -87,7 +224,6 @@ const HeapVector>& MLOperator::Outputs() const { void MLOperator::Connect(HeapVector> inputs, HeapVector> outputs) { DCHECK(!is_connected_); - DCHECK(!inputs.empty()); DCHECK(!outputs.empty()); inputs_ = std::move(inputs); outputs_ = std::move(outputs); diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_operator.h b/third_party/blink/renderer/modules/ml/webnn/ml_operator.h index bb49f27f564886..fcaf59d12a16c1 100644 --- a/third_party/blink/renderer/modules/ml/webnn/ml_operator.h +++ b/third_party/blink/renderer/modules/ml/webnn/ml_operator.h @@ -21,25 +21,122 @@ class MODULES_EXPORT MLOperator final : public ScriptWrappable { DEFINE_WRAPPERTYPEINFO(); public: + // TODO: Delete this enum and with the ml::webnn::mojom::blink::OperatorType + // generated one - one less enum to keep in sync. enum class OperatorKind { // Keep the order as the same as build methods of MLGraphBuilder. - kClamp, + kUnknown, + kConv2d, + kConvTranspose2d, + kGemm, + kMatmul, + + // Elementwise binary operations kAdd, kSub, kMul, kDiv, kMax, kMin, - kGemm, + kPow, + + // Elementwise binary logical comparison operations + kEqual, + kGreater, + kLesser, + + // Elementwise unary activation operations. + kRelu, + kElu, + kPrelu, + kLeakyRelu, + kClamp, + kSigmoid, + kHardSigmoid, kHardSwish, + kLinear, + kSoftplus, + kSoftsign, + kSoftmax, + kTanh, + + // Pooling operations kAveragePool2d, + kL2Pool2d, kMaxPool2d, - kRelu, + + // Elementwise unary operations + kIdentity, + kAbs, + kNeg, + kExp, + kLog, + kSqrt, + kSin, + kCos, + kTan, + kErf, + kFloor, + kCeil, + kReciprocal, + kLogicalNot, + + // Trinary elementwise operations + kElementWiseIf, + + // Shape reinterpretation operations. kReshape, + kSqueeze, + kUnsqueeze, + kFlattenTo2d, + + // Shape modification operations. + kConcat, + kSlice, + kSplit, + kTranspose, + kPad, + kExpand, + kGather, kResample2d, - kSoftmax, - kSigmoid + + // Reduction operations + kReduceL1, + kReduceL2, + kReduceLogSum, + kReduceLogSumExp, + kReduceMax, + kReduceMean, + kReduceMin, + kReduceProduct, + kReduceSum, + kReduceSumSquare, + + // Normalization operations + kInstanceNormalization, + kMeanVarianceNormalization, + + // Miscellaneous + kArgMax, + kArgMin, + kCast, + kFillSequence, + kTriangularMatrix, + + // Iterative GEMM operations + kGru, + kGruCell, + kLstm, + kLstmCell, + + // Quantized operators + kConv2dInteger, + kMatmulInteger, + kDequantizeLinear, + kDynamicQuantizeLinear, + + kTotal, // Total number of enumerants for static assertions. 
}; static String OperatorKindToString(MLOperator::OperatorKind kind); @@ -48,7 +145,7 @@ class MODULES_EXPORT MLOperator final : public ScriptWrappable { // that passes the reference of the options dictionary argument received from // Blink to MLOperator constructor and stores it in this object. This is // because that WebIDL spec (https://webidl.spec.whatwg.org/#idl-dictionaries) - // mentiones that "an operation that accepts a dictionary as an argument will + // mentions that "an operation that accepts a dictionary as an argument will // perform a one-time conversion from the given ECMAScript value into the // dictionary, based on the current properties of the ECMAScript object. // Modifications to the dictionary will not be reflected in the corresponding diff --git a/third_party/blink/renderer/modules/ml/webnn/mojo_client.cc b/third_party/blink/renderer/modules/ml/webnn/mojo_client.cc index 4488e1292565e2..8c2f03b5ce180d 100644 --- a/third_party/blink/renderer/modules/ml/webnn/mojo_client.cc +++ b/third_party/blink/renderer/modules/ml/webnn/mojo_client.cc @@ -8,6 +8,9 @@ #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h" #include "third_party/blink/renderer/core/dom/dom_exception.h" +// TODO::: +#pragma optimize("", off) + namespace blink { MojoClient::MojoClient(ExecutionContext* execution_context) diff --git a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc index 0af6e52041a69c..422355db9a8415 100644 --- a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc +++ b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.cc @@ -4,6 +4,8 @@ #include "third_party/blink/renderer/modules/ml/webnn/mojo_graph.h" +#include "base/trace_event/trace_event.h" +#include "base/trace_event/typed_macros.h" #include "mojo/public/cpp/bindings/pending_remote.h" #include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_tensor.h" @@ -21,6 +23,9 @@ #include +// TODO::: +#pragma optimize("", off) + namespace blink { namespace { @@ -30,11 +35,15 @@ using ml::webnn::mojom::blink::ComputeResult; using ml::webnn::mojom::blink::MemoryInfoPtr; void AddOperation(MojoModelInfo* model_info, const MLOperator* op) { + static_assert(int32_t(MLOperator::OperatorKind::kTotal) == 83); + switch (op->Kind()) { case MLOperator::OperatorKind::kClamp: model_info->AddClamp(op); break; case MLOperator::OperatorKind::kConv2d: + case MLOperator::OperatorKind::kConv2dInteger: + case MLOperator::OperatorKind::kConvTranspose2d: model_info->AddConv2d(op); break; case MLOperator::OperatorKind::kAdd: @@ -43,24 +52,143 @@ void AddOperation(MojoModelInfo* model_info, const MLOperator* op) { case MLOperator::OperatorKind::kDiv: case MLOperator::OperatorKind::kMin: case MLOperator::OperatorKind::kMax: + case MLOperator::OperatorKind::kEqual: + case MLOperator::OperatorKind::kGreater: + case MLOperator::OperatorKind::kLesser: + case MLOperator::OperatorKind::kPow: + case MLOperator::OperatorKind::kPrelu: model_info->AddElementWiseBinary(op); break; case MLOperator::OperatorKind::kGemm: + case MLOperator::OperatorKind::kMatmul: + case MLOperator::OperatorKind::kMatmulInteger: model_info->AddGemm(op); break; case MLOperator::OperatorKind::kAveragePool2d: + case MLOperator::OperatorKind::kL2Pool2d: case MLOperator::OperatorKind::kMaxPool2d: model_info->AddPool2d(op); break; case MLOperator::OperatorKind::kRelu: model_info->AddRelu(op); break; + case 
MLOperator::OperatorKind::kResample2d: + model_info->AddResample2d(op); + break; case MLOperator::OperatorKind::kSoftmax: model_info->AddSoftmax(op); break; case MLOperator::OperatorKind::kReshape: + case MLOperator::OperatorKind::kSqueeze: + case MLOperator::OperatorKind::kUnsqueeze: + case MLOperator::OperatorKind::kFlattenTo2d: model_info->AddReshape(op); break; + case MLOperator::OperatorKind::kArgMax: + case MLOperator::OperatorKind::kArgMin: + model_info->AddArgMinMax(op); + break; + case MLOperator::OperatorKind::kCast: + model_info->AddCast(op); + break; + case MLOperator::OperatorKind::kConcat: + model_info->AddConcat(op); + break; + case MLOperator::OperatorKind::kSlice: + model_info->AddSlice(op); + break; + case MLOperator::OperatorKind::kSplit: + model_info->AddSplit(op); + break; + case MLOperator::OperatorKind::kExpand: + model_info->AddExpand(op); + break; + case MLOperator::OperatorKind::kIdentity: + case MLOperator::OperatorKind::kAbs: + case MLOperator::OperatorKind::kNeg: + case MLOperator::OperatorKind::kCos: + case MLOperator::OperatorKind::kExp: + case MLOperator::OperatorKind::kLog: + case MLOperator::OperatorKind::kSqrt: + case MLOperator::OperatorKind::kSin: + case MLOperator::OperatorKind::kTan: + case MLOperator::OperatorKind::kTanh: + case MLOperator::OperatorKind::kErf: + case MLOperator::OperatorKind::kFloor: + case MLOperator::OperatorKind::kCeil: + case MLOperator::OperatorKind::kReciprocal: + case MLOperator::OperatorKind::kLogicalNot: + case MLOperator::OperatorKind::kSigmoid: + case MLOperator::OperatorKind::kHardSwish: + case MLOperator::OperatorKind::kSoftsign: + model_info->AddElementWiseUnary(op); + break; + case MLOperator::OperatorKind::kElementWiseIf: + model_info->AddElementWiseIf(op); + break; + case MLOperator::OperatorKind::kGather: + model_info->AddGather(op); + break; + case MLOperator::OperatorKind::kInstanceNormalization: + model_info->AddInstanceNormalization(op); + break; + case MLOperator::OperatorKind::kMeanVarianceNormalization: + model_info->AddMeanVarianceNormalization(op); + break; + case MLOperator::OperatorKind::kPad: + model_info->AddPad(op); + break; + case MLOperator::OperatorKind::kFillSequence: + model_info->AddFillSequence(op); + break; + case MLOperator::OperatorKind::kReduceL1: + case MLOperator::OperatorKind::kReduceL2: + case MLOperator::OperatorKind::kReduceLogSum: + case MLOperator::OperatorKind::kReduceLogSumExp: + case MLOperator::OperatorKind::kReduceMax: + case MLOperator::OperatorKind::kReduceMean: + case MLOperator::OperatorKind::kReduceMin: + case MLOperator::OperatorKind::kReduceProduct: + case MLOperator::OperatorKind::kReduceSum: + case MLOperator::OperatorKind::kReduceSumSquare: + model_info->AddReduce(op); + break; + case MLOperator::OperatorKind::kTranspose: + model_info->AddTranspose(op); + break; + case MLOperator::OperatorKind::kTriangularMatrix: + model_info->AddTriangularMatrix(op); + break; + + case MLOperator::OperatorKind::kElu: + case MLOperator::OperatorKind::kLeakyRelu: + case MLOperator::OperatorKind::kLinear: + case MLOperator::OperatorKind::kHardSigmoid: + case MLOperator::OperatorKind::kSoftplus: + model_info->AddElementWiseUnaryTwoParameter(op); + break; + case MLOperator::OperatorKind::kDequantizeLinear: + model_info->AddDequantizeLinear(op); + break; + case MLOperator::OperatorKind::kDynamicQuantizeLinear: + model_info->AddDynamicQuantizeLinear(op); + break; + +#if 0 // TODO:::Implement + case MLOperator::OperatorKind::kLstm: + model_info->AddLstm(op); + break; + case 
MLOperator::OperatorKind::kLstmCell: + model_info->AddLstmCell(op); + break; + case MLOperator::OperatorKind::kGru: + model_info->AddGru(op); + break; + case MLOperator::OperatorKind::kGruCell: + model_info->AddGruCell(op); + break; +#endif + default: NOTIMPLEMENTED(); break; @@ -192,6 +320,7 @@ void MojoGraph::ComputeAsyncImpl(const MLNamedArrayBufferViews& inputs, void MojoGraph::ComputeSyncImpl(const MLNamedArrayBufferViews& inputs, const MLNamedArrayBufferViews& outputs, ExceptionState& exception_state) { + TRACE_EVENT0("blink", "MojoGraph::ComputeSyncImpl"); if (inputs.size() != input_resources_info_.size()) { exception_state.ThrowDOMException(DOMExceptionCode::kDataError, "The number of inputs is invalid."); @@ -199,24 +328,28 @@ void MojoGraph::ComputeSyncImpl(const MLNamedArrayBufferViews& inputs, } auto named_inputs = ml::webnn::mojom::blink::NamedResources::New(), named_outputs = ml::webnn::mojom::blink::NamedResources::New(); - for (const auto& input : inputs) { - String error_message; - auto* input_array_buffer_view = input.second.Get(); - if (input_array_buffer_view == nullptr) { - exception_state.ThrowDOMException(DOMExceptionCode::kDataError, - error_message); + { + TRACE_EVENT0("blink", "MojoGraph::ComputeSyncImpl::CopyInputs"); + for (const auto& input : inputs) { + String error_message; + auto* input_array_buffer_view = input.second.Get(); + if (input_array_buffer_view == nullptr) { + exception_state.ThrowDOMException(DOMExceptionCode::kDataError, + error_message); + } + const String& input_name = input.first; + auto memory_info = ml::webnn::mojom::blink::MemoryInfo::New(); + memory_info->byte_offset = inputs_byte_offset_.at(input_name); + memory_info->byte_length = + input_resources_info_.at(input_name).byte_length; + uint8_t* address = inputs_shm_region_.mapping.GetMemoryAs() + + memory_info->byte_offset; + memcpy(address, input_array_buffer_view->BaseAddressMaybeShared(), + input_array_buffer_view->byteLength()); + named_inputs->resources.insert(input_name, std::move(memory_info)); } - const String& input_name = input.first; - auto memory_info = ml::webnn::mojom::blink::MemoryInfo::New(); - memory_info->byte_offset = inputs_byte_offset_.at(input_name); - memory_info->byte_length = input_resources_info_.at(input_name).byte_length; - uint8_t* address = inputs_shm_region_.mapping.GetMemoryAs() + - memory_info->byte_offset; - memcpy(address, input_array_buffer_view->BaseAddressMaybeShared(), - input_array_buffer_view->byteLength()); - named_inputs->resources.insert(input_name, std::move(memory_info)); + named_inputs->shared_memory = inputs_shm_region_.region.Duplicate(); } - named_inputs->shared_memory = inputs_shm_region_.region.Duplicate(); ComputeResult result; if (!remote_graph_->Compute(std::move(named_inputs), &result, &named_outputs)) { @@ -224,29 +357,33 @@ void MojoGraph::ComputeSyncImpl(const MLNamedArrayBufferViews& inputs, "Failed to compute the graph."); return; }; - for (const auto& output : outputs) { - String error_message; - void* output_buffer_address = output.second->BaseAddressMaybeShared(); - if (output_buffer_address == nullptr) { - exception_state.ThrowDOMException(DOMExceptionCode::kOperationError, - error_message); - return; + { + TRACE_EVENT0("blink", "MojoGraph::ComputeSyncImpl::CopyOutputs"); + if (!outputs_shm_mapping_.IsValid()) { + outputs_shm_mapping_ = named_outputs->shared_memory.Map(); } - auto iter = named_outputs->resources.find(output.first); - if (iter == named_outputs->resources.end()) { - 
exception_state.ThrowDOMException(DOMExceptionCode::kOperationError, - "Failed to get result for the output."); - return; + for (const auto& output : outputs) { + String error_message; + void* output_buffer_address = output.second->BaseAddressMaybeShared(); + if (output_buffer_address == nullptr) { + exception_state.ThrowDOMException(DOMExceptionCode::kOperationError, + error_message); + return; + } + auto iter = named_outputs->resources.find(output.first); + if (iter == named_outputs->resources.end()) { + exception_state.ThrowDOMException( + DOMExceptionCode::kOperationError, + "Failed to get result for the output."); + return; + } + MemoryInfoPtr memory_info = std::move(iter->value); + size_t byte_offset = base::checked_cast(memory_info->byte_offset); + size_t byte_length = base::checked_cast(memory_info->byte_length); + memcpy(output_buffer_address, + outputs_shm_mapping_.GetMemoryAs() + byte_offset, + byte_length); } - MemoryInfoPtr memory_info = std::move(iter->value); - base::ReadOnlySharedMemoryRegion& shared_memory_region = - named_outputs->shared_memory; - DCHECK(shared_memory_region.IsValid()); - size_t byte_length = base::checked_cast(memory_info->byte_length); - base::ReadOnlySharedMemoryMapping shared_memory_mapping = - shared_memory_region.MapAt(memory_info->byte_offset, byte_length); - memcpy(output_buffer_address, shared_memory_mapping.GetMemoryAs(), - byte_length); } } @@ -280,9 +417,15 @@ void MojoGraph::OnGraphCreated( inputs_byte_offset_.insert(input->Name(), aligned_offset.ValueOrDie()); aligned_offset += Align(input_byte_length, kBufferAlignment).ValueOrDie(); } + size_t inputs_buffer_length = aligned_offset.ValueOrDie(); inputs_shm_region_ = base::ReadOnlySharedMemoryRegion::Create(inputs_buffer_length); + if (!inputs_shm_region_.IsValid()) { + resolver->Reject(MakeGarbageCollected( + DOMExceptionCode::kUnknownError, + "Failed to create graph input shared memory.")); + } for (const auto& constant : constants) { model_info->AddConstant(constant.Get()); diff --git a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.h b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.h index a0ffdda736e3b1..ebe132df1338ea 100644 --- a/third_party/blink/renderer/modules/ml/webnn/mojo_graph.h +++ b/third_party/blink/renderer/modules/ml/webnn/mojo_graph.h @@ -66,6 +66,7 @@ class MojoGraph : public MLGraph { // The map of input name and input data offset. 
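For context on the output-copy rework above (map the shared-memory region once, then copy each named output out at its recorded offset and length), a minimal sketch of the same pattern over a plain byte buffer; OutputInfo is an illustrative stand-in for the mojom MemoryInfo, not a type defined by this change.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <map>
#include <string>
#include <vector>

struct OutputInfo {  // Illustrative stand-in for per-output (offset, length).
  std::size_t byte_offset;
  std::size_t byte_length;
};

// Copies every named output out of one mapped region into caller-owned buffers.
void CopyOutputs(const std::uint8_t* mapping_base,
                 const std::map<std::string, OutputInfo>& named_outputs,
                 std::map<std::string, std::vector<std::uint8_t>>& destinations) {
  for (const auto& [name, info] : named_outputs) {
    std::vector<std::uint8_t>& dest = destinations[name];
    dest.resize(info.byte_length);
    std::memcpy(dest.data(), mapping_base + info.byte_offset, info.byte_length);
  }
}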
HashMap inputs_byte_offset_; base::MappedReadOnlyRegion inputs_shm_region_; + base::ReadOnlySharedMemoryMapping outputs_shm_mapping_; HeapMojoRemote remote_graph_; }; diff --git a/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.cc b/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.cc index 9e9a150ad0ff22..d4198299edd22a 100644 --- a/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.cc +++ b/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.cc @@ -6,17 +6,43 @@ #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_options_internal.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_descriptor.h" #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_concat_options_internal.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_transpose_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_reduce_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_arg_min_max_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_slice_options_internal.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options_internal.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_resample_2d_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_instance_normalization_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_mean_variance_normalization_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_fill_sequence_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_float_parameter_options_internal.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options.h" +#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options_internal.h" #include "third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h" #include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h" #include "third_party/blink/renderer/modules/ml/webnn/mojo_graph.h" #include "third_party/blink/renderer/platform/bindings/exception_code.h" #include "third_party/blink/renderer/platform/bindings/exception_state.h" +// TODO::: +#pragma optimize("", off) + namespace blink { +#define DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(/*MLOperator* */ ml_operator, \ + /*uint32_t*/ expected_input_count, \ + /*uint32_t*/ expected_output_count) \ + DCHECK_EQ(ml_operator->Inputs().size(), expected_input_count); \ + DCHECK_EQ(ml_operator->Outputs().size(), expected_output_count); + base::CheckedNumeric Align(size_t value, uint32_t aligment) { size_t remainder = value % aligment; if (remainder != 0) { @@ -30,15 +56,18 @@ namespace { using ml::webnn::mojom::blink::AutoPad; using ml::webnn::mojom::blink::Conv2dFilterOperandLayout; -using ml::webnn::mojom::blink::ElementWiseBinaryType; +using ml::webnn::mojom::blink::OperatorType; using ml::webnn::mojom::blink::InputOperandLayout; using 
ml::webnn::mojom::blink::OperandType; using ml::webnn::mojom::blink::OperationInfo; using ml::webnn::mojom::blink::OperationInfoPtr; -using ml::webnn::mojom::blink::Pool2dType; using ml::webnn::mojom::blink::RoundingType; +using ml::webnn::mojom::blink::PaddingMode; OperandType BlinkOperandTypeToMojo(V8MLOperandType::Enum type) { + static_assert(V8MLOperandType::kEnumSize == 8); + static_assert(int32_t(ml::webnn::mojom::blink::OperandType::kMaxValue) + 1 == 8); + switch (type) { case V8MLOperandType::Enum::kFloat32: return OperandType::kFloat32; @@ -48,6 +77,10 @@ OperandType BlinkOperandTypeToMojo(V8MLOperandType::Enum type) { return OperandType::kInt32; case V8MLOperandType::Enum::kUint32: return OperandType::kUint32; + case V8MLOperandType::Enum::kInt64: + return OperandType::kInt64; + case V8MLOperandType::Enum::kUint64: + return OperandType::kUint64; case V8MLOperandType::Enum::kInt8: return OperandType::kInt8; case V8MLOperandType::Enum::kUint8: @@ -57,6 +90,7 @@ InputOperandLayout BlinkInputOperandLayoutToMojo( V8MLInputOperandLayout::Enum type) { + static_assert(V8MLInputOperandLayout::kEnumSize == 2); switch (type) { case V8MLInputOperandLayout::Enum::kNchw: return InputOperandLayout::kNchw; @@ -66,20 +100,28 @@ InputOperandLayout BlinkInputOperandLayoutToMojo( } Conv2dFilterOperandLayout BlinkConv2dFilterOperandLayoutToMojo( - V8MLConv2dFilterOperandLayout::Enum type) { + V8MLConvFilterOperandLayoutInternal::Enum type) { + static_assert(V8MLConvFilterOperandLayoutInternal::kEnumSize == 6); + static_assert(int32_t(Conv2dFilterOperandLayout::kMaxValue) + 1 == 6); + switch (type) { - case V8MLConv2dFilterOperandLayout::Enum::kOihw: + case V8MLConvFilterOperandLayoutInternal::Enum::kOihw: return Conv2dFilterOperandLayout::kOihw; - case V8MLConv2dFilterOperandLayout::Enum::kHwio: + case V8MLConvFilterOperandLayoutInternal::Enum::kIohw: + return Conv2dFilterOperandLayout::kIohw; + case V8MLConvFilterOperandLayoutInternal::Enum::kHwoi: + return Conv2dFilterOperandLayout::kHwoi; + case V8MLConvFilterOperandLayoutInternal::Enum::kHwio: return Conv2dFilterOperandLayout::kHwio; - case V8MLConv2dFilterOperandLayout::Enum::kOhwi: + case V8MLConvFilterOperandLayoutInternal::Enum::kOhwi: return Conv2dFilterOperandLayout::kOhwi; - case V8MLConv2dFilterOperandLayout::Enum::kIhwo: + case V8MLConvFilterOperandLayoutInternal::Enum::kIhwo: return Conv2dFilterOperandLayout::kIhwo; } } AutoPad BlinkAutoPadToMojo(V8MLAutoPad::Enum type) { + static_assert(V8MLAutoPad::kEnumSize == 3); switch (type) { case V8MLAutoPad::Enum::kExplicit: return AutoPad::kExplicit; @@ -91,6 +133,7 @@ AutoPad BlinkAutoPadToMojo(V8MLAutoPad::Enum type) { } RoundingType BlinkRoundingTypeToMojo(V8MLRoundingType::Enum type) { + static_assert(V8MLRoundingType::kEnumSize == 2); switch (type) { case V8MLRoundingType::Enum::kFloor: return RoundingType::kFloor; @@ -99,28 +142,115 @@ RoundingType BlinkRoundingTypeToMojo(V8MLRoundingType::Enum type) { } } -Pool2dType BlinkPool2dTypeToMojo(MLOperator::OperatorKind type) { - switch (type) { - case MLOperator::OperatorKind::kAveragePool2d: - return Pool2dType::kAveragePool2d; - default: - NOTREACHED(); - return Pool2dType::kUnknown; - } +OperatorType BlinkOperatorKindToMojoType( MLOperator::OperatorKind type) { + static_assert(int32_t(MLOperator::OperatorKind::kTotal) == 83); + static_assert(int32_t(OperatorType::kMaxValue) + 1 == 83); + + // Keep these two enumerations in sync. + // Favor readability over convention here.
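The conversion below leans on the two enumerations sharing numeric values, so the cast is guarded by static_asserts; a minimal sketch of that pattern with two hypothetical enums (FooKind and FooType are made-up names, not types from this change):
// FooKind and FooType are hypothetical stand-ins for a Blink-side and a
// mojom-side enum; the real ones are much larger but follow the same idea.
enum class FooKind : int { kA, kB, kC, kTotal };
enum class FooType : int { kA, kB, kC, kMaxValue = kC };

static_assert(static_cast<int>(FooKind::kTotal) ==
              static_cast<int>(FooType::kMaxValue) + 1);
static_assert(static_cast<int>(FooKind::kB) == static_cast<int>(FooType::kB));

// With matching numeric values asserted at compile time, conversion is a cast.
constexpr FooType ToFooType(FooKind kind) {
  return static_cast<FooType>(kind);
}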
+ // clang-format off + static_assert(uint32_t(MLOperator::OperatorKind::kConv2d) == uint32_t(OperatorType::kConv2d)); + static_assert(uint32_t(MLOperator::OperatorKind::kConvTranspose2d) == uint32_t(OperatorType::kConvTranspose2d)); + static_assert(uint32_t(MLOperator::OperatorKind::kGemm) == uint32_t(OperatorType::kGemm)); + static_assert(uint32_t(MLOperator::OperatorKind::kMatmul) == uint32_t(OperatorType::kMatmul)); + static_assert(uint32_t(MLOperator::OperatorKind::kAdd) == uint32_t(OperatorType::kAdd)); + static_assert(uint32_t(MLOperator::OperatorKind::kSub) == uint32_t(OperatorType::kSub)); + static_assert(uint32_t(MLOperator::OperatorKind::kMul) == uint32_t(OperatorType::kMul)); + static_assert(uint32_t(MLOperator::OperatorKind::kDiv) == uint32_t(OperatorType::kDiv)); + static_assert(uint32_t(MLOperator::OperatorKind::kMax) == uint32_t(OperatorType::kMax)); + static_assert(uint32_t(MLOperator::OperatorKind::kMin) == uint32_t(OperatorType::kMin)); + static_assert(uint32_t(MLOperator::OperatorKind::kPow) == uint32_t(OperatorType::kPow)); + static_assert(uint32_t(MLOperator::OperatorKind::kEqual) == uint32_t(OperatorType::kEqual)); + static_assert(uint32_t(MLOperator::OperatorKind::kGreater) == uint32_t(OperatorType::kGreater)); + static_assert(uint32_t(MLOperator::OperatorKind::kLesser) == uint32_t(OperatorType::kLesser)); + static_assert(uint32_t(MLOperator::OperatorKind::kRelu) == uint32_t(OperatorType::kRelu)); + static_assert(uint32_t(MLOperator::OperatorKind::kElu) == uint32_t(OperatorType::kElu)); + static_assert(uint32_t(MLOperator::OperatorKind::kPrelu) == uint32_t(OperatorType::kPrelu)); + static_assert(uint32_t(MLOperator::OperatorKind::kLeakyRelu) == uint32_t(OperatorType::kLeakyRelu)); + static_assert(uint32_t(MLOperator::OperatorKind::kClamp) == uint32_t(OperatorType::kClamp)); + static_assert(uint32_t(MLOperator::OperatorKind::kSigmoid) == uint32_t(OperatorType::kSigmoid)); + static_assert(uint32_t(MLOperator::OperatorKind::kHardSigmoid) == uint32_t(OperatorType::kHardSigmoid)); + static_assert(uint32_t(MLOperator::OperatorKind::kHardSwish) == uint32_t(OperatorType::kHardSwish)); + static_assert(uint32_t(MLOperator::OperatorKind::kLinear) == uint32_t(OperatorType::kLinear)); + static_assert(uint32_t(MLOperator::OperatorKind::kSoftplus) == uint32_t(OperatorType::kSoftplus)); + static_assert(uint32_t(MLOperator::OperatorKind::kSoftsign) == uint32_t(OperatorType::kSoftsign)); + static_assert(uint32_t(MLOperator::OperatorKind::kSoftmax) == uint32_t(OperatorType::kSoftmax)); + static_assert(uint32_t(MLOperator::OperatorKind::kAveragePool2d) == uint32_t(OperatorType::kAveragePool2d)); + static_assert(uint32_t(MLOperator::OperatorKind::kMaxPool2d) == uint32_t(OperatorType::kMaxPool2d)); + static_assert(uint32_t(MLOperator::OperatorKind::kIdentity) == uint32_t(OperatorType::kIdentity)); + static_assert(uint32_t(MLOperator::OperatorKind::kAbs) == uint32_t(OperatorType::kAbs)); + static_assert(uint32_t(MLOperator::OperatorKind::kNeg) == uint32_t(OperatorType::kNeg)); + static_assert(uint32_t(MLOperator::OperatorKind::kExp) == uint32_t(OperatorType::kExp)); + static_assert(uint32_t(MLOperator::OperatorKind::kLog) == uint32_t(OperatorType::kLog)); + static_assert(uint32_t(MLOperator::OperatorKind::kSqrt) == uint32_t(OperatorType::kSqrt)); + static_assert(uint32_t(MLOperator::OperatorKind::kSin) == uint32_t(OperatorType::kSin)); + static_assert(uint32_t(MLOperator::OperatorKind::kCos) == uint32_t(OperatorType::kCos)); + static_assert(uint32_t(MLOperator::OperatorKind::kTan) == 
uint32_t(OperatorType::kTan)); + static_assert(uint32_t(MLOperator::OperatorKind::kTanh) == uint32_t(OperatorType::kTanh)); + static_assert(uint32_t(MLOperator::OperatorKind::kErf) == uint32_t(OperatorType::kErf)); + static_assert(uint32_t(MLOperator::OperatorKind::kFloor) == uint32_t(OperatorType::kFloor)); + static_assert(uint32_t(MLOperator::OperatorKind::kCeil) == uint32_t(OperatorType::kCeil)); + static_assert(uint32_t(MLOperator::OperatorKind::kReciprocal) == uint32_t(OperatorType::kReciprocal)); + static_assert(uint32_t(MLOperator::OperatorKind::kLogicalNot) == uint32_t(OperatorType::kLogicalNot)); + static_assert(uint32_t(MLOperator::OperatorKind::kElementWiseIf) == uint32_t(OperatorType::kElementWiseIf)); + static_assert(uint32_t(MLOperator::OperatorKind::kReshape) == uint32_t(OperatorType::kReshape)); + static_assert(uint32_t(MLOperator::OperatorKind::kSqueeze) == uint32_t(OperatorType::kSqueeze)); + static_assert(uint32_t(MLOperator::OperatorKind::kUnsqueeze) == uint32_t(OperatorType::kUnsqueeze)); + static_assert(uint32_t(MLOperator::OperatorKind::kFlattenTo2d) == uint32_t(OperatorType::kFlattenTo2d)); + static_assert(uint32_t(MLOperator::OperatorKind::kConcat) == uint32_t(OperatorType::kConcat)); + static_assert(uint32_t(MLOperator::OperatorKind::kSlice) == uint32_t(OperatorType::kSlice)); + static_assert(uint32_t(MLOperator::OperatorKind::kSplit) == uint32_t(OperatorType::kSplit)); + static_assert(uint32_t(MLOperator::OperatorKind::kTranspose) == uint32_t(OperatorType::kTranspose)); + static_assert(uint32_t(MLOperator::OperatorKind::kPad) == uint32_t(OperatorType::kPad)); + static_assert(uint32_t(MLOperator::OperatorKind::kExpand) == uint32_t(OperatorType::kExpand)); + static_assert(uint32_t(MLOperator::OperatorKind::kGather) == uint32_t(OperatorType::kGather)); + static_assert(uint32_t(MLOperator::OperatorKind::kResample2d) == uint32_t(OperatorType::kResample2d)); + static_assert(uint32_t(MLOperator::OperatorKind::kReduceL1) == uint32_t(OperatorType::kReduceL1)); + static_assert(uint32_t(MLOperator::OperatorKind::kReduceL2) == uint32_t(OperatorType::kReduceL2)); + static_assert(uint32_t(MLOperator::OperatorKind::kReduceLogSum) == uint32_t(OperatorType::kReduceLogSum)); + static_assert(uint32_t(MLOperator::OperatorKind::kReduceLogSumExp) == uint32_t(OperatorType::kReduceLogSumExp)); + static_assert(uint32_t(MLOperator::OperatorKind::kReduceMax) == uint32_t(OperatorType::kReduceMax)); + static_assert(uint32_t(MLOperator::OperatorKind::kReduceMean) == uint32_t(OperatorType::kReduceMean)); + static_assert(uint32_t(MLOperator::OperatorKind::kReduceMin) == uint32_t(OperatorType::kReduceMin)); + static_assert(uint32_t(MLOperator::OperatorKind::kReduceProduct) == uint32_t(OperatorType::kReduceProduct)); + static_assert(uint32_t(MLOperator::OperatorKind::kReduceSum) == uint32_t(OperatorType::kReduceSum)); + static_assert(uint32_t(MLOperator::OperatorKind::kReduceSumSquare) == uint32_t(OperatorType::kReduceSumSquare)); + static_assert(uint32_t(MLOperator::OperatorKind::kArgMax) == uint32_t(OperatorType::kArgMax)); + static_assert(uint32_t(MLOperator::OperatorKind::kArgMin) == uint32_t(OperatorType::kArgMin)); + static_assert(uint32_t(MLOperator::OperatorKind::kCast) == uint32_t(OperatorType::kCast)); + static_assert(uint32_t(MLOperator::OperatorKind::kInstanceNormalization) == uint32_t(OperatorType::kInstanceNormalization)); + static_assert(uint32_t(MLOperator::OperatorKind::kMeanVarianceNormalization) == uint32_t(OperatorType::kMeanVarianceNormalization)); + 
static_assert(uint32_t(MLOperator::OperatorKind::kFillSequence) == uint32_t(OperatorType::kFillSequence)); + static_assert(uint32_t(MLOperator::OperatorKind::kTriangularMatrix) == uint32_t(OperatorType::kTriangularMatrix)); + static_assert(uint32_t(MLOperator::OperatorKind::kGru) == uint32_t(OperatorType::kGru)); + static_assert(uint32_t(MLOperator::OperatorKind::kGruCell) == uint32_t(OperatorType::kGruCell)); + static_assert(uint32_t(MLOperator::OperatorKind::kLstm) == uint32_t(OperatorType::kLstm)); + static_assert(uint32_t(MLOperator::OperatorKind::kLstmCell) == uint32_t(OperatorType::kLstmCell)); + static_assert(uint32_t(MLOperator::OperatorKind::kConv2dInteger) == uint32_t(OperatorType::kConv2dInteger)); + static_assert(uint32_t(MLOperator::OperatorKind::kMatmulInteger) == uint32_t(OperatorType::kMatmulInteger)); + static_assert(uint32_t(MLOperator::OperatorKind::kDequantizeLinear) == uint32_t(OperatorType::kDequantizeLinear)); + static_assert(uint32_t(MLOperator::OperatorKind::kDynamicQuantizeLinear) == uint32_t(OperatorType::kDynamicQuantizeLinear)); + // clang-format on + + return static_cast(type); } -ElementWiseBinaryType BlinkElementWiseBinaryTypeToMojo( - MLOperator::OperatorKind type) { +PaddingMode BlinkPaddingModeToMojo(V8MLPaddingMode::Enum type) { + static_assert(V8MLPaddingMode::kEnumSize == 4); switch (type) { - case MLOperator::OperatorKind::kAdd: - return ElementWiseBinaryType::kAdd; - default: - NOTREACHED(); - return ElementWiseBinaryType::kUnknown; + case V8MLPaddingMode::Enum::kConstant: + return PaddingMode::kConstant; + case V8MLPaddingMode::Enum::kEdge: + return PaddingMode::kEdge; + case V8MLPaddingMode::Enum::kReflection: + return PaddingMode::kReflection; + case V8MLPaddingMode::Enum::kSymmetric: + return PaddingMode::kSymmetric; } } -ml::webnn::mojom::blink::ClampOptionsPtr BlinkClampOptioinToMojo( +ml::webnn::mojom::blink::ClampOptionsPtr BlinkClampOptionsToMojo( const MLClampOptions* ml_options) { const float min = ml_options->hasMinValue() ? 
ml_options->minValue() @@ -140,7 +270,7 @@ OperationInfoPtr FusionOperation(const MLOperator* activation) { case MLOperator::OperatorKind::kClamp: { auto clamp = ml::webnn::mojom::blink::Clamp::New(); clamp->input_index = std::numeric_limits::max(); - clamp->options = BlinkClampOptioinToMojo( + clamp->options = BlinkClampOptionsToMojo( static_cast(activation->Options())); clamp->output_index = std::numeric_limits::max(); auto operation = OperationInfo::NewClamp(std::move(clamp)); @@ -160,8 +290,8 @@ OperationInfoPtr FusionOperation(const MLOperator* activation) { } } -ml::webnn::mojom::blink::Conv2dOptionsPtr BlinkConv2dOptioinToMojo( - const MLConv2dOptions* ml_options, +ml::webnn::mojom::blink::Conv2dOptionsPtr BlinkConv2dOptionsToMojo( + const MLConvOptionsInternal* ml_options, const HeapHashMap, size_t>& operand_index_map) { auto options = ml::webnn::mojom::blink::Conv2dOptions::New(); options->padding = @@ -188,7 +318,7 @@ ml::webnn::mojom::blink::Conv2dOptionsPtr BlinkConv2dOptioinToMojo( return options; } -ml::webnn::mojom::blink::Pool2dOptionsPtr BlinkPool2dOptioinToMojo( +ml::webnn::mojom::blink::Pool2dOptionsPtr BlinkPool2dOptionsToMojo( const MLPool2dOptions* ml_options) { auto options = ml::webnn::mojom::blink::Pool2dOptions::New(); options->window_dimensions = ml_options->hasWindowDimensions() @@ -213,7 +343,7 @@ ml::webnn::mojom::blink::Pool2dOptionsPtr BlinkPool2dOptioinToMojo( return options; } -ml::webnn::mojom::blink::GemmOptionsPtr BlinkGemmOptioinToMojo( +ml::webnn::mojom::blink::GemmOptionsPtr BlinkGemmOptionsToMojo( const MLGemmOptions* ml_options, const HeapHashMap, size_t>& operand_index_map) { auto options = ml::webnn::mojom::blink::GemmOptions::New(); @@ -269,6 +399,26 @@ void MojoModelInfo::AddOutput(String name, const MLOperand* output) { model_info_->outputs.push_back(std::move(named_output)); } +bool MojoModelInfo::AreOperandsInIndexMap(const Member* operands, + size_t operand_count) const { + for (size_t i = 0; i < operand_count; ++i) { + if (!operand_index_map_.Contains(operands[i])) { + return false; + } + } + return true; +} + +uint64_t MojoModelInfo::GetOperandIndex(/*nullable*/ const MLOperand* operand) const +{ + if (operand == nullptr) + { + // Sentinel value represents no tensor. + return std::numeric_limits::max(); + } + return operand_index_map_.at(operand); +} + void MojoModelInfo::AddClamp(const MLOperator* ml_clamp) { DCHECK_EQ(ml_clamp->Inputs().size(), static_cast(1)); auto* input = ml_clamp->Inputs()[0].Get(); @@ -283,7 +433,7 @@ void MojoModelInfo::AddClamp(const MLOperator* ml_clamp) { // Add clamp operation to the model. auto clamp = ml::webnn::mojom::blink::Clamp::New(); clamp->input_index = operand_index_map_.at(input); - clamp->options = BlinkClampOptioinToMojo( + clamp->options = BlinkClampOptionsToMojo( static_cast(ml_clamp->Options())); clamp->output_index = output_index; auto operation = OperationInfo::NewClamp(std::move(clamp)); @@ -291,44 +441,104 @@ void MojoModelInfo::AddClamp(const MLOperator* ml_clamp) { } void MojoModelInfo::AddConv2d(const MLOperator* ml_conv2d) { + auto input_count = ml_conv2d->Inputs().size(); DCHECK_GE(ml_conv2d->Inputs().size(), static_cast(2)); auto* input = ml_conv2d->Inputs()[0].Get(); auto* filter = ml_conv2d->Inputs()[1].Get(); - if (operand_index_map_.find(input) == operand_index_map_.end() || - operand_index_map_.find(filter) == operand_index_map_.end()) { + auto* input_zero_point = (input_count >= 4) ? ml_conv2d->Inputs()[2].Get() : nullptr; + auto* filter_zero_point = (input_count >= 4) ? 
ml_conv2d->Inputs()[3].Get() : nullptr; + if (!operand_index_map_.Contains(input) || + !operand_index_map_.Contains(filter)) { return; } + // Add operand descriptor to the model. DCHECK_EQ(ml_conv2d->Outputs().size(), static_cast(1)); auto* output = ml_conv2d->Outputs()[0].Get(); DCHECK(operand_index_map_.find(output) == operand_index_map_.end()); size_t output_index = AddOperandToModel(output); - // Add clamp operation to the model. + + // Add conv2d or convTranspose2d operation to the model. auto conv2d = ml::webnn::mojom::blink::Conv2d::New(); + conv2d->operator_type = BlinkOperatorKindToMojoType(ml_conv2d->Kind()); conv2d->input_index = operand_index_map_.at(input); conv2d->filter_index = operand_index_map_.at(filter); - const MLConv2dOptions* ml_options = - static_cast(ml_conv2d->Options()); - conv2d->options = BlinkConv2dOptioinToMojo(ml_options, operand_index_map_); + conv2d->input_zero_point_index = GetOperandIndex(input_zero_point); + conv2d->filter_zero_point_index = GetOperandIndex(filter_zero_point); + const MLConvOptionsInternal* ml_options = + static_cast(ml_conv2d->Options()); + conv2d->options = BlinkConv2dOptionsToMojo(ml_options, operand_index_map_); conv2d->output_index = output_index; auto operation = OperationInfo::NewConv2d(std::move(conv2d)); model_info_->operations.push_back(std::move(operation)); } +void MojoModelInfo::AddElementWiseUnary(const MLOperator* ml_operator) { + DCHECK_EQ(ml_operator->Inputs().size(), static_cast(1)); + DCHECK_EQ(ml_operator->Outputs().size(), static_cast(1)); + + // Verify inputs exist and output does not yet exist. + auto* input = ml_operator->Inputs()[0].Get(); + auto* output = ml_operator->Outputs()[0].Get(); + if (!operand_index_map_.Contains(input)) { + return; + } + DCHECK(!operand_index_map_.Contains(output)); + size_t output_index = AddOperandToModel(output); + + // Create mojom operator from JS blink type. + auto mojom_operator = ml::webnn::mojom::blink::ElementWiseUnary::New(); + mojom_operator->operator_type = BlinkOperatorKindToMojoType(ml_operator->Kind()); + mojom_operator->input_index = operand_index_map_.at(input); + mojom_operator->output_index = output_index; + auto operation = OperationInfo::NewElementWiseUnary(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation)); +} + +void MojoModelInfo::AddElementWiseUnaryTwoParameter(const MLOperator* ml_operator) { + DCHECK_EQ(ml_operator->Inputs().size(), static_cast(1)); + DCHECK_EQ(ml_operator->Outputs().size(), static_cast(1)); + + // Verify inputs exist and output does not yet exist. + auto* input = ml_operator->Inputs()[0].Get(); + auto* output = ml_operator->Outputs()[0].Get(); + if (!operand_index_map_.Contains(input)) { + return; + } + DCHECK(!operand_index_map_.Contains(output)); + size_t output_index = AddOperandToModel(output); + + // Create mojom operator from JS blink type. 
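Several of the activations routed through the two-parameter message take at most two scalars; a hedged sketch of what such parameterized activations typically compute follows. The mapping of firstParameter/secondParameter to (alpha, beta) here is an assumption for illustration, not something this patch defines.
#include <algorithm>
#include <cmath>

// Illustrative only: a few WebNN-style activations that need at most two
// scalar parameters. The (alpha, beta) meanings are assumptions, not taken
// from the messages above.
float LeakyRelu(float x, float alpha) { return x >= 0.f ? x : alpha * x; }
float Elu(float x, float alpha) {
  return x >= 0.f ? x : alpha * (std::exp(x) - 1.f);
}
float Linear(float x, float alpha, float beta) { return alpha * x + beta; }
float HardSigmoid(float x, float alpha, float beta) {
  return std::min(1.f, std::max(0.f, alpha * x + beta));
}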
+ auto mojom_operator = ml::webnn::mojom::blink::ElementWiseUnaryTwoParameter::New(); + mojom_operator->operator_type = BlinkOperatorKindToMojoType(ml_operator->Kind()); + mojom_operator->input_index = operand_index_map_.at(input); + mojom_operator->output_index = output_index; + const MLFloatParameterOptionsInternal* ml_options = + static_cast(ml_operator->Options()); + mojom_operator->first_parameter = ml_options->firstParameter(); + mojom_operator->second_parameter = ml_options->secondParameter(); + auto operation = OperationInfo::NewElementWiseUnaryTwoParameter(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation)); +} + void MojoModelInfo::AddElementWiseBinary(const MLOperator* ml_binary) { DCHECK_EQ(ml_binary->Inputs().size(), static_cast(2)); + DCHECK_EQ(ml_binary->Outputs().size(), static_cast(1)); + + // Verify inputs exist and output does not yet exist. auto* a = ml_binary->Inputs()[0].Get(); auto* b = ml_binary->Inputs()[1].Get(); - if (operand_index_map_.find(a) == operand_index_map_.end() || - operand_index_map_.find(b) == operand_index_map_.end()) { + auto* output = ml_binary->Outputs()[0].Get(); + if (!operand_index_map_.Contains(a) || + !operand_index_map_.Contains(b)) { return; } - DCHECK_EQ(ml_binary->Outputs().size(), static_cast(1)); - auto* output = ml_binary->Outputs()[0].Get(); - DCHECK(operand_index_map_.find(output) == operand_index_map_.end()); + DCHECK(!operand_index_map_.Contains(output)); size_t output_index = AddOperandToModel(output); + + // Create mojom operator from JS blink type. auto binary = ml::webnn::mojom::blink::ElementWiseBinary::New(); - binary->type = BlinkElementWiseBinaryTypeToMojo(ml_binary->Kind()); + binary->operator_type = BlinkOperatorKindToMojoType(ml_binary->Kind()); binary->a_index = operand_index_map_.at(a); binary->b_index = operand_index_map_.at(b); binary->output_index = output_index; @@ -337,25 +547,35 @@ void MojoModelInfo::AddElementWiseBinary(const MLOperator* ml_binary) { } void MojoModelInfo::AddGemm(const MLOperator* ml_gemm) { - DCHECK_GE(ml_gemm->Inputs().size(), static_cast(2)); - auto* a = ml_gemm->Inputs()[0].Get(); - auto* b = ml_gemm->Inputs()[1].Get(); + auto input_count = ml_gemm->Inputs().size(); + auto& inputs = ml_gemm->Inputs(); + DCHECK_GE(input_count, static_cast(2)); + auto* a = inputs[0].Get(); + auto* b = inputs[1].Get(); + auto* a_zero_point = (input_count >= 4) ? inputs[2].Get() : nullptr; + auto* b_zero_point = (input_count >= 4) ? inputs[3].Get() : nullptr; if (operand_index_map_.find(a) == operand_index_map_.end() || operand_index_map_.find(b) == operand_index_map_.end()) { return; - } + } // TODO::: Add assert for a_zero_point and b_zero_point. + + // Add operand descriptor to the model. DCHECK_EQ(ml_gemm->Outputs().size(), static_cast(1)); auto* output = ml_gemm->Outputs()[0].Get(); DCHECK(operand_index_map_.find(output) == operand_index_map_.end()); size_t output_index = AddOperandToModel(output); - // Add clamp operation to the model. + + // Add GEMM operation to the model. 
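For the zero-point plumbing added around GEMM, a small sketch of what an integer matmul with zero points conventionally computes (each side's zero point subtracted before accumulating); that convention is an assumption for illustration, not defined by this patch.
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative quantized matmul for row-major A (m x k) and B (k x n):
// C[i][j] = sum_p (A[i][p] - a_zero_point) * (B[p][j] - b_zero_point).
std::vector<std::int32_t> MatmulInteger(const std::vector<std::uint8_t>& a,
                                        const std::vector<std::uint8_t>& b,
                                        int m, int k, int n,
                                        std::uint8_t a_zero_point,
                                        std::uint8_t b_zero_point) {
  std::vector<std::int32_t> c(static_cast<std::size_t>(m) * n, 0);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int p = 0; p < k; ++p) {
        c[i * n + j] +=
            (static_cast<std::int32_t>(a[i * k + p]) - a_zero_point) *
            (static_cast<std::int32_t>(b[p * n + j]) - b_zero_point);
      }
    }
  }
  return c;
}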
auto gemm = ml::webnn::mojom::blink::Gemm::New(); + gemm->operator_type = BlinkOperatorKindToMojoType(ml_gemm->Kind()); gemm->a_index = operand_index_map_.at(a); gemm->b_index = operand_index_map_.at(b); + gemm->a_zero_point_index = GetOperandIndex(a_zero_point); + gemm->b_zero_point_index = GetOperandIndex(b_zero_point); const MLGemmOptions* ml_options = static_cast(ml_gemm->Options()); - gemm->options = BlinkGemmOptioinToMojo(ml_options, operand_index_map_); + gemm->options = BlinkGemmOptionsToMojo(ml_options, operand_index_map_); gemm->output_index = output_index; auto operation = OperationInfo::NewGemm(std::move(gemm)); model_info_->operations.push_back(std::move(operation)); @@ -372,13 +592,14 @@ void MojoModelInfo::AddPool2d(const MLOperator* ml_pool2d) { auto* output = ml_pool2d->Outputs()[0].Get(); DCHECK(operand_index_map_.find(output) == operand_index_map_.end()); size_t output_index = AddOperandToModel(output); - // Add averagePool2d operation to the model. + + // Add pooling operation to the model. auto pool2d = ml::webnn::mojom::blink::Pool2d::New(); - pool2d->type = BlinkPool2dTypeToMojo(ml_pool2d->Kind()); + pool2d->operator_type = BlinkOperatorKindToMojoType(ml_pool2d->Kind()); pool2d->input_index = operand_index_map_.at(input); const MLPool2dOptions* ml_options = static_cast(ml_pool2d->Options()); - pool2d->options = BlinkPool2dOptioinToMojo(ml_options); + pool2d->options = BlinkPool2dOptionsToMojo(ml_options); pool2d->output_index = output_index; auto operation = OperationInfo::NewPool2d(std::move(pool2d)); model_info_->operations.push_back(std::move(operation)); @@ -441,6 +662,467 @@ void MojoModelInfo::AddSoftmax(const MLOperator* ml_softmax) { model_info_->operations.push_back(std::move(operation)); } +void MojoModelInfo::AddArgMinMax(const MLOperator* ml_operator) { + DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 1u, 1u); + + const MLOperand* input = ml_operator->Inputs().front().Get(); + const MLOperand* output = ml_operator->Outputs().front().Get(); + if (!operand_index_map_.Contains(input)) { + return; + } + DCHECK(!operand_index_map_.Contains(output)); + + size_t output_index = AddOperandToModel(output); + + const MLArgMinMaxOptions* ml_options = + static_cast(ml_operator->Options()); + + auto mojom_operator = ml::webnn::mojom::blink::ArgMinMax::New(); + mojom_operator->operator_type = BlinkOperatorKindToMojoType(ml_operator->Kind()); + mojom_operator->input_index = GetOperandIndex(input); + mojom_operator->axis = ml_options->axis(); + mojom_operator->keep_dimensions = ml_options->keepDimensions(); + mojom_operator->select_last_index = ml_options->selectLastIndex(); + mojom_operator->output_index = output_index; + + auto operation_info = OperationInfo::NewArgMinMax(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation_info)); +} + +void MojoModelInfo::AddCast(const MLOperator* ml_operator) { + DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 1u, 1u); + + // Verify inputs exist and output does not yet exist. + const MLOperand* input = ml_operator->Inputs().front().Get(); + const MLOperand* output = ml_operator->Outputs().front().Get(); + if (!operand_index_map_.Contains(input)) { + return; + } + DCHECK(!operand_index_map_.Contains(output)); + size_t output_index = AddOperandToModel(output); + + // Create mojom operator from JS blink type. 
+ auto mojom_operator = ml::webnn::mojom::blink::Cast::New(); + mojom_operator->input_index = GetOperandIndex(input); + mojom_operator->output_index = output_index; + mojom_operator->data_type = BlinkOperandTypeToMojo(output->Type()); + auto operation_info = OperationInfo::NewCast(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation_info)); +} + +void MojoModelInfo::AddConcat(const MLOperator* ml_operator) { + DCHECK_GE(ml_operator->Inputs().size(), 0u); + DCHECK_EQ(ml_operator->Outputs().size(), 1u); + + // Verify inputs exist and output does not yet exist. + auto& inputs = ml_operator->Inputs(); + const MLOperand* output = ml_operator->Outputs().front().Get(); + if (!AreOperandsInIndexMap(inputs.data(), inputs.size())) + { + return; + } + DCHECK(!operand_index_map_.Contains(output)); + size_t output_index = AddOperandToModel(output); + + const MLConcatOptionsInternal* ml_options = + static_cast(ml_operator->Options()); + + // Create mojom operator from JS blink type. + auto mojom_operator = ml::webnn::mojom::blink::Concat::New(); + for (const Member& input : inputs) + { + mojom_operator->input_indices.push_back(GetOperandIndex(input)); + } + mojom_operator->output_index = output_index; + mojom_operator->axis = ml_options->axis(); + auto operation_info = OperationInfo::NewConcat(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation_info)); +} + +void MojoModelInfo::AddSlice(const MLOperator* ml_operator) { + DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 1u, 1u); + + // Verify inputs exist and output does not yet exist. + const MLOperand* input = ml_operator->Inputs().front().Get(); + const MLOperand* output = ml_operator->Outputs().front().Get(); + if (!operand_index_map_.Contains(input)) { + return; + } + DCHECK(!operand_index_map_.Contains(output)); + size_t output_index = AddOperandToModel(output); + + const MLSliceOptionsInternal* ml_options = + static_cast(ml_operator->Options()); + + // Create mojom operator from JS blink type. + auto mojom_operator = ml::webnn::mojom::blink::Slice::New(); + mojom_operator->input_index = GetOperandIndex(input); + mojom_operator->starts = ml_options->starts(); + mojom_operator->sizes = ml_options->sizes(); + mojom_operator->output_index = output_index; + auto operation_info = OperationInfo::NewSlice(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation_info)); +} + +void MojoModelInfo::AddSplit(const MLOperator* ml_operator) { + DCHECK_EQ(ml_operator->Inputs().size(), 1u); + DCHECK_GE(ml_operator->Outputs().size(), 0u); + + // Verify inputs exist and output does not yet exist. + auto& outputs = ml_operator->Outputs(); + const MLOperand* input = ml_operator->Inputs().front().Get(); + if (!operand_index_map_.Contains(input)) { + return; + } + DCHECK(!AreOperandsInIndexMap(outputs.data(), outputs.size())); + + const MLSplitOptionsInternal* ml_options = + static_cast(ml_operator->Options()); + + // Create mojom operator from JS blink type. 
+ auto mojom_operator = ml::webnn::mojom::blink::Split::New(); + for (const Member& output : outputs) + { + size_t output_index = AddOperandToModel(output); + mojom_operator->output_indices.push_back(output_index); + } + mojom_operator->input_index = GetOperandIndex(input); + mojom_operator->axis = ml_options->axis(); + auto operation_info = OperationInfo::NewSplit(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation_info)); +} + +void MojoModelInfo::AddExpand(const MLOperator* ml_operator) { + DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 1u, 1u); + + const MLOperand* input = ml_operator->Inputs().front().Get(); + const MLOperand* output = ml_operator->Outputs().front().Get(); + if (!operand_index_map_.Contains(input)) { + return; + } + DCHECK(!operand_index_map_.Contains(output)); + size_t output_index = AddOperandToModel(output); + + auto mojom_operator = ml::webnn::mojom::blink::Expand::New(); + mojom_operator->input_index = GetOperandIndex(input); + mojom_operator->output_index = output_index; + + auto operation_info = OperationInfo::NewExpand(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation_info)); +} + +void MojoModelInfo::AddGather(const MLOperator* ml_operator) { + DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 2u, 1u); + + auto& inputs = ml_operator->Inputs(); + const MLOperand* input = inputs[0].Get(); + const MLOperand* indices = inputs[1].Get(); + const MLOperand* output = ml_operator->Outputs().front().Get(); + if (!AreOperandsInIndexMap(inputs.data(), inputs.size())) + { + return; + } + DCHECK(!operand_index_map_.Contains(output)); + size_t output_index = AddOperandToModel(output); + + const MLGatherOptions* ml_options = + static_cast(ml_operator->Options()); + + auto mojom_operator = ml::webnn::mojom::blink::Gather::New(); + mojom_operator->input_index = GetOperandIndex(input); + mojom_operator->indices_index = GetOperandIndex(indices); + mojom_operator->axis = ml_options->axis(); + mojom_operator->output_index = output_index; + + auto operation_info = OperationInfo::NewGather(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation_info)); +} + +void MojoModelInfo::AddInstanceNormalization(const MLOperator* ml_operator) { + DCHECK_GE(ml_operator->Inputs().size(), 1u); + DCHECK_LE(ml_operator->Inputs().size(), 3u); + DCHECK_EQ(ml_operator->Outputs().size(), 1u); + + // Verify inputs exist and output does not yet exist. + auto& inputs = ml_operator->Inputs(); + const MLOperand* input = ml_operator->Inputs().front().Get(); + const MLOperand* output = ml_operator->Outputs().front().Get(); + if (!AreOperandsInIndexMap(inputs.data(), inputs.size())) { + return; + } + DCHECK(!operand_index_map_.Contains(output)); + size_t output_index = AddOperandToModel(output); + + const MLInstanceNormalizationOptions* ml_options = + static_cast( + ml_operator->Options()); + + // Add another enum if extending this to ensure the two enums are castable below. + static_assert(InputOperandLayout::kMaxValue == InputOperandLayout::kNhwc); + static_assert(uint32_t(InputOperandLayout::kNchw) == + uint32_t(V8MLInputOperandLayout::Enum::kNchw)); + static_assert(uint32_t(InputOperandLayout::kNhwc) == + uint32_t(V8MLInputOperandLayout::Enum::kNhwc)); + + // Create mojom operator from JS blink type. 
+ auto mojom_operator = ml::webnn::mojom::blink::InstanceNormalization::New(); + mojom_operator->input_index = GetOperandIndex(input); + mojom_operator->scale_index = GetOperandIndex(ml_options->getScaleOr(nullptr)); + mojom_operator->bias_index = GetOperandIndex(ml_options->getBiasOr(nullptr)); + mojom_operator->epsilon = ml_options->epsilon(); + mojom_operator->layout = static_cast(ml_options->layout().AsEnum()); + mojom_operator->output_index = output_index; + auto operation_info = OperationInfo::NewInstanceNormalization(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation_info)); +} + +void MojoModelInfo::AddMeanVarianceNormalization(const MLOperator* ml_operator) { + DCHECK_GE(ml_operator->Inputs().size(), 1u); + DCHECK_LE(ml_operator->Inputs().size(), 5u); + DCHECK_EQ(ml_operator->Outputs().size(), 1u); + + // Verify inputs exist and output does not yet exist. + auto& inputs = ml_operator->Inputs(); + const MLOperand* input = ml_operator->Inputs().front().Get(); + const MLOperand* output = ml_operator->Outputs().front().Get(); + if (!AreOperandsInIndexMap(inputs.data(), inputs.size())) + { + return; + } + DCHECK(!operand_index_map_.Contains(output)); + size_t output_index = AddOperandToModel(output); + + const MLMeanVarianceNormalizationOptions* ml_options = + static_cast( + ml_operator->Options()); + + // Create mojom operator from JS blink type. + auto mojom_operator = ml::webnn::mojom::blink::MeanVarianceNormalization::New(); + mojom_operator->input_index = GetOperandIndex(input); + mojom_operator->mean_index = GetOperandIndex(ml_options->getMeanOr(nullptr)); + mojom_operator->variance_index = GetOperandIndex(ml_options->getVarianceOr(nullptr)); + mojom_operator->scale_index = GetOperandIndex(ml_options->getScaleOr(nullptr)); + mojom_operator->bias_index = GetOperandIndex(ml_options->getBiasOr(nullptr)); + mojom_operator->epsilon = ml_options->epsilon(); + mojom_operator->axes = ml_options->axes(); + mojom_operator->output_index = output_index; + auto operation_info = OperationInfo::NewMeanVarianceNormalization(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation_info)); +} + +void MojoModelInfo::AddPad(const MLOperator* ml_operator) { + DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 1u, 1u); + + // Verify inputs exist and outputs do not yet exist. + const MLOperand* input = ml_operator->Inputs().front().Get(); + const MLOperand* output = ml_operator->Outputs().front().Get(); + if (!operand_index_map_.Contains(input)) { + return; + } + + DCHECK(!operand_index_map_.Contains(output)); + size_t output_index = AddOperandToModel(output); + + const MLPadOptionsInternal* ml_options = + static_cast(ml_operator->Options()); + + // Create mojom operator from JS blink type. + auto mojom_operator = ml::webnn::mojom::blink::Pad::New(); + mojom_operator->input_index = GetOperandIndex(input); + mojom_operator->beginningPadding = ml_options->beginningPadding(); + mojom_operator->endingPadding = ml_options->endingPadding(); + mojom_operator->mode = BlinkPaddingModeToMojo(ml_options->mode().AsEnum()); + mojom_operator->value = ml_options->value(); + mojom_operator->output_index = output_index; + auto operation = OperationInfo::NewPad(std::move(mojom_operator)); + model_info_->operations.push_back(std::move(operation)); +} + +void MojoModelInfo::AddFillSequence(const MLOperator* ml_operator) { + DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 0u, 1u); + + // Verify inputs exist and output does not yet exist. 
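A quick sketch of the arithmetic sequence the start/delta pair below suggests (output[i] = start + i * delta); this is only the implied math, not the backend implementation.
#include <cstddef>
#include <vector>

// Fills `count` elements with the arithmetic sequence start, start + delta, ...
std::vector<float> FillSequence(float start, float delta, std::size_t count) {
  std::vector<float> out(count);
  for (std::size_t i = 0; i < count; ++i)
    out[i] = start + static_cast<float>(i) * delta;
  return out;
}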
+  const MLOperand* output = ml_operator->Outputs().front().Get();
+  DCHECK(!operand_index_map_.Contains(output));
+  size_t output_index = AddOperandToModel(output);
+
+  const MLFillSequenceOptions* ml_options =
+      static_cast<const MLFillSequenceOptions*>(ml_operator->Options());
+
+  // Create mojom operator from JS blink type.
+  auto mojom_operator = ml::webnn::mojom::blink::FillSequence::New();
+  mojom_operator->start = ml_options->start();
+  mojom_operator->delta = ml_options->delta();
+  mojom_operator->output_index = output_index;
+  auto operation = OperationInfo::NewFillSequence(std::move(mojom_operator));
+  model_info_->operations.push_back(std::move(operation));
+}
+
+void MojoModelInfo::AddReduce(const MLOperator* ml_operator) {
+  DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 1u, 1u);
+
+  // Verify inputs exist and output does not yet exist.
+  const MLOperand* input = ml_operator->Inputs().front().Get();
+  const MLOperand* output = ml_operator->Outputs().front().Get();
+  if (!operand_index_map_.Contains(input)) {
+    return;
+  }
+
+  DCHECK(!operand_index_map_.Contains(output));
+  size_t output_index = AddOperandToModel(output);
+
+  const MLReduceOptions* ml_options =
+      static_cast<const MLReduceOptions*>(ml_operator->Options());
+
+  // Create mojom operator from JS blink type.
+  auto mojom_operator = ml::webnn::mojom::blink::Reduce::New();
+  mojom_operator->operator_type =
+      BlinkOperatorKindToMojoType(ml_operator->Kind());
+  mojom_operator->axes = ml_options->axes();
+  mojom_operator->keep_dimensions = ml_options->keepDimensions();
+  mojom_operator->input_index = operand_index_map_.at(input);
+  mojom_operator->output_index = output_index;
+  auto operation = OperationInfo::NewReduce(std::move(mojom_operator));
+  model_info_->operations.push_back(std::move(operation));
+}
+
+void MojoModelInfo::AddResample2d(const MLOperator* ml_operator) {
+  DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 1u, 1u);
+
+  // Verify inputs exist and output does not yet exist.
+  const MLOperand* input = ml_operator->Inputs().front().Get();
+  const MLOperand* output = ml_operator->Outputs().front().Get();
+  if (!operand_index_map_.Contains(input)) {
+    return;
+  }
+  DCHECK(!operand_index_map_.Contains(output));
+  size_t output_index = AddOperandToModel(output);
+
+  const MLResample2dOptions* ml_options =
+      static_cast<const MLResample2dOptions*>(ml_operator->Options());
+
+  DCHECK(ml_options->hasScales());
+  DCHECK(ml_options->hasAxes());
+
+  // Create mojom operator from JS blink type.
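+  // The V8 and mojom interpolation enums are declared with matching values so
+  // the blink value can be static_cast directly; the asserts below catch any
+  // future divergence at compile time.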
+
+  static_assert(
+      V8MLInterpolationMode::kEnumSize == 2,
+      "A new enum value was added; verify these mappings are still correct.");
+  static_assert(
+      uint32_t(V8MLInterpolationMode::Enum::kNearestNeighbor) ==
+      uint32_t(ml::webnn::mojom::InterpolationMode::kNearestNeighbor));
+  static_assert(uint32_t(V8MLInterpolationMode::Enum::kLinear) ==
+                uint32_t(ml::webnn::mojom::InterpolationMode::kLinear));
+
+  auto mojom_operator = ml::webnn::mojom::blink::Resample2d::New();
+  mojom_operator->input_index = GetOperandIndex(input);
+  mojom_operator->scales = ml_options->scales();
+  mojom_operator->axes = ml_options->axes();
+  mojom_operator->interpolation_mode =
+      static_cast<ml::webnn::mojom::blink::InterpolationMode>(
+          ml_options->mode().AsEnum());
+  mojom_operator->output_index = output_index;
+
+  auto operation_info = OperationInfo::NewResample2d(std::move(mojom_operator));
+  model_info_->operations.push_back(std::move(operation_info));
+}
+
+void MojoModelInfo::AddTranspose(const MLOperator* ml_operator) {
+  DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 1u, 1u);
+
+  const MLOperand* input = ml_operator->Inputs().front().Get();
+  const MLOperand* output = ml_operator->Outputs().front().Get();
+  if (!operand_index_map_.Contains(input)) {
+    return;
+  }
+  DCHECK(!operand_index_map_.Contains(output));
+  size_t output_index = AddOperandToModel(output);
+
+  const MLTransposeOptions* ml_options =
+      static_cast<const MLTransposeOptions*>(ml_operator->Options());
+
+  auto mojom_operator = ml::webnn::mojom::blink::Transpose::New();
+  mojom_operator->input_index = operand_index_map_.at(input);
+  mojom_operator->output_index = output_index;
+  mojom_operator->permutation = ml_options->permutation();
+
+  auto operation_info = OperationInfo::NewTranspose(std::move(mojom_operator));
+  model_info_->operations.push_back(std::move(operation_info));
+}
+
+void MojoModelInfo::AddTriangularMatrix(const MLOperator* ml_operator) {
+  // TODO:
+}
+
+void MojoModelInfo::AddElementWiseIf(const MLOperator* ml_operator) {
+  DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 3u, 1u);
+
+  auto& inputs = ml_operator->Inputs();
+  if (!AreOperandsInIndexMap(inputs.data(), inputs.size())) {
+    return;
+  }
+  const MLOperand* condition = inputs[0].Get();
+  const MLOperand* true_value = inputs[1].Get();
+  const MLOperand* false_value = inputs[2].Get();
+  const MLOperand* output = ml_operator->Outputs()[0].Get();
+  DCHECK(!operand_index_map_.Contains(output));
+  size_t output_index = AddOperandToModel(output);
+
+  auto mojom_operator = ml::webnn::mojom::blink::ElementWiseIf::New();
+  mojom_operator->condition_index = operand_index_map_.at(condition);
+  mojom_operator->true_value_index = operand_index_map_.at(true_value);
+  mojom_operator->false_value_index = operand_index_map_.at(false_value);
+  mojom_operator->output_index = output_index;
+  auto operation_info =
+      OperationInfo::NewElementWiseIf(std::move(mojom_operator));
+  model_info_->operations.push_back(std::move(operation_info));
+}
+
+void MojoModelInfo::AddDequantizeLinear(const MLOperator* ml_operator) {
+  DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 3u, 1u);
+
+  auto& inputs = ml_operator->Inputs();
+  if (!AreOperandsInIndexMap(inputs.data(), inputs.size())) {
+    return;
+  }
+  const MLOperand* input = inputs[0].Get();
+  const MLOperand* scale = inputs[1].Get();
+  const MLOperand* zero_point = inputs[2].Get();
+  const MLOperand* output = ml_operator->Outputs()[0].Get();
+  DCHECK(!operand_index_map_.Contains(output));
+  size_t output_index = AddOperandToModel(output);
+
+  auto mojom_operator = ml::webnn::mojom::blink::DequantizeLinear::New();
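+  // dequantizeLinear converts a quantized tensor back to floating point as
+  // output = (input - zero_point) * scale (the usual ONNX/WebNN definition),
+  // which is why scale and zero_point arrive as additional input operands.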
+  mojom_operator->input_index = operand_index_map_.at(input);
+  mojom_operator->scale_index = operand_index_map_.at(scale);
+  mojom_operator->zero_point_index = operand_index_map_.at(zero_point);
+  mojom_operator->output_index = output_index;
+  auto operation_info =
+      OperationInfo::NewDequantizeLinear(std::move(mojom_operator));
+  model_info_->operations.push_back(std::move(operation_info));
+}
+
+void MojoModelInfo::AddDynamicQuantizeLinear(const MLOperator* ml_operator) {
+  DCHECK_OPERATOR_INPUT_OUTPUT_COUNT(ml_operator, 1u, 3u);
+
+  // Verify the input exists and the outputs do not yet exist.
+  auto& outputs = ml_operator->Outputs();
+  const MLOperand* input = ml_operator->Inputs().front().Get();
+  if (!operand_index_map_.Contains(input)) {
+    return;
+  }
+  DCHECK(!AreOperandsInIndexMap(outputs.data(), outputs.size()));
+
+  // Create mojom operator from JS blink type.
+  auto mojom_operator = ml::webnn::mojom::blink::DynamicQuantizeLinear::New();
+  mojom_operator->input_index = GetOperandIndex(input);
+  mojom_operator->output_index = AddOperandToModel(outputs[0]);
+  mojom_operator->output_scale_index = AddOperandToModel(outputs[1]);
+  mojom_operator->output_zero_point_index = AddOperandToModel(outputs[2]);
+  auto operation_info =
+      OperationInfo::NewDynamicQuantizeLinear(std::move(mojom_operator));
+  model_info_->operations.push_back(std::move(operation_info));
+}
+
 void MojoModelInfo::FillConstantsWithArrayBuffer() {
   // Copy constant data to shared memory.
   base::CheckedNumeric<size_t> constants_buffer_length(0);
@@ -482,12 +1164,14 @@ size_t MojoModelInfo::AddOperandToModel(const MLOperand* output) {
   auto desc = ml::webnn::mojom::blink::OperandDescriptor::New();
   desc->data_type = BlinkOperandTypeToMojo(output->Type());
   desc->dimensions = output->Dimensions();
-  // Add operand descriptor into model.
-  model_info_->operands.push_back(std::move(desc));
-  // The index used to identify operand on the server side, each operation
-  // generate a output operand that will be inserted in a hash map with the
-  // MLOperand and index, the index is incremented by one.
-  size_t output_index = model_info_->operands.size() - 1;
+  // Add the operand descriptor to the model. The index identifies the operand
+  // on the server side: each operation generates an output operand that is
+  // inserted into the hash map together with its MLOperand, and the index
+  // advances by one for each operand added.
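+  // For example, the first operand added to the model receives index 1, the
+  // second index 2, and so on.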
+  // Note that WTF::HashMap does not allow 0 as a key, so the index is counted
+  // from 1.
+  size_t output_index = model_info_->operands.size() + 1;
+  model_info_->operands.insert(output_index, std::move(desc));
   operand_index_map_.insert(output, output_index);
   return output_index;
 }
diff --git a/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.h b/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.h
index 2ec135188d39df..dfe0ab17ce34be 100644
--- a/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.h
+++ b/third_party/blink/renderer/modules/ml/webnn/mojo_model_info.h
@@ -39,10 +39,44 @@ class MojoModelInfo final : public GarbageCollected<MojoModelInfo> {
   void AddClamp(const MLOperator* clamp);
   void AddConv2d(const MLOperator* conv2d);
-
-  // Element-wise binary operations
+  void AddConvTranspose2d(const MLOperator* conv2d);
+
+  // Element-wise unary operations:
+  // - abs
+  // - neg
+  // - sin
+  // - cos
+  // - tan
+  // - erf
+  // - exp
+  // - log
+  // - sqrt
+  // - reciprocal
+  // - logicalNot
+  void AddElementWiseUnary(const MLOperator* ml_operator);
+
+  // Element-wise unary operations with up to two parameters:
+  // - elu          MLEluOptions
+  // - leakyRelu    MLLeakyReluOptions
+  // - linear       MLLinearOptions
+  // - hardSigmoid  MLHardSigmoidOptions
+  // - softplus     MLSoftplusOptions
+  void AddElementWiseUnaryTwoParameter(const MLOperator* ml_operator);
+
+  // Elementwise binary operations:
+  // - add
+  // - sub
+  // - mul
+  // - div
+  // - pow
+  // - equal
+  // - greater
+  // - lesser
   void AddElementWiseBinary(const MLOperator* binary);
+  // Dot product operators:
+  // - gemm
+  // - matMul
   void AddGemm(const MLOperator* gemm);
 
   // Pooling operations
@@ -50,10 +84,49 @@ class MojoModelInfo final : public GarbageCollected<MojoModelInfo> {
   void AddRelu(const MLOperator* relu);
 
+  // Reshaping operators (do not change the data, just interpretation):
+  // - reshape
+  // - squeeze
+  // - unsqueeze
+  // - flattenTo2d
+  // - identity (included here because it's a no-op)
   void AddReshape(const MLOperator* reshape);
 
   void AddSoftmax(const MLOperator* softmax);
 
+  void AddArgMinMax(const MLOperator* ml_operator);
+  void AddCast(const MLOperator* ml_operator);
+  void AddConcat(const MLOperator* ml_operator);
+  void AddSlice(const MLOperator* ml_operator);
+  void AddSplit(const MLOperator* ml_operator);
+  void AddExpand(const MLOperator* ml_operator);
+  void AddGather(const MLOperator* ml_operator);
+  void AddInstanceNormalization(const MLOperator* ml_operator);
+  void AddMeanVarianceNormalization(const MLOperator* ml_operator);
+  void AddPad(const MLOperator* ml_operator);
+  void AddFillSequence(const MLOperator* ml_operator);
+
+  // Reduction operators:
+  // - reduceL1
+  // - reduceL2
+  // - reduceLogSum
+  // - reduceLogSumExp
+  // - reduceMax
+  // - reduceMean
+  // - reduceMin
+  // - reduceProduct
+  // - reduceSum
+  // - reduceSumSquare
+  void AddReduce(const MLOperator* ml_operator);
+
+  void AddResample2d(const MLOperator* ml_operator);
+  void AddTranspose(const MLOperator* ml_operator);
+  void AddTriangularMatrix(const MLOperator* ml_operator);
+  void AddElementWiseIf(const MLOperator* ml_operator);
+
+  void AddDequantizeLinear(const MLOperator* ml_operator);
+  void AddDynamicQuantizeLinear(const MLOperator* ml_operator);
+
   void FillConstantsWithArrayBuffer();
 
   ModelInfoPtr GetModelInfo();
@@ -61,6 +134,11 @@
  private:
   // Add a operand to model which is output of the operation.
   size_t AddOperandToModel(const MLOperand* output);
+
+  bool AreOperandsInIndexMap(const Member<const MLOperand>* operands,
+                             size_t operand_count) const;
+
+  uint64_t GetOperandIndex(/*nullable*/ const MLOperand* operand) const;
+
   // Hold all operands of model to index the operand.
   HeapHashMap<Member<const MLOperand>, size_t> operand_index_map_;
   // All constant data will share a big shared memory, so hold the index of