From 5726318ec0c4af1b08cb3de7481dbe5d6d50556d Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 23 Aug 2024 10:19:53 -0700 Subject: [PATCH 1/6] [CoreML EP] Fix ArgMaxOpBuilder::AddToModelBuilderImpl() nullptr Node access. (#21797) --- include/onnxruntime/core/graph/graph_nodes.h | 3 +- .../coreml/builders/impl/argmax_op_builder.cc | 9 ++-- .../coreml/builders/impl/cast_op_builder.cc | 5 -- .../providers/coreml/coreml_basic_test.cc | 51 +++++++++++++++++-- .../testdata/coreml_argmax_cast_test.onnx | 5 +- .../test/testdata/coreml_argmax_cast_test.py | 19 ++++--- .../coreml_argmax_unsupported_cast_test.onnx | 19 +++++++ 7 files changed, 87 insertions(+), 24 deletions(-) create mode 100644 onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx diff --git a/include/onnxruntime/core/graph/graph_nodes.h b/include/onnxruntime/core/graph/graph_nodes.h index 4fa2848a1d09e..aab5f2699d234 100644 --- a/include/onnxruntime/core/graph/graph_nodes.h +++ b/include/onnxruntime/core/graph/graph_nodes.h @@ -117,13 +117,14 @@ class ValidNodes { return (current_ != other.current_); } - void operator++() { + NodeIterator& operator++() { if (current_ < end_) { while (++current_ != end_) { if (*current_ != nullptr && (!apply_filter_ || (*filter_func_)((*current_)->Index()) == false)) break; } } + return *this; } NodeIterator operator++(int) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index e9a8176c8349b..bc8b2d1a3505d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -38,13 +38,14 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // 2. Otherwise, we add Argmax layer normally if (node.GetOutputEdgesCount() == 1) { auto it = node.OutputEdgesBegin(); - const auto* succ_node(graph_viewer.GetNode(it->GetNode().Index())); + const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index()); // If Argmax's successive node is a Cast from int64 to int32 output - // The 'cast to' type is checked in operator supported related, omit the check here - if (succ_node->OpType() == "Cast") { + // The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()) + // so we omit the check here + if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") { // Skip the cast's input/argmax's output *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = succ_node->OutputDefs()[0]->Name(); + *layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name(); model_builder.AddLayer(std::move(layer)); return Status::OK(); } diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index 70053c2c606a0..fc8879abbefb0 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -36,11 +36,6 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } - if (node.GetInputEdgesCount() > 1) { - LOGS(logger, VERBOSE) << "Multiple nodes producing Cast's input."; - return false; - } - const auto& prec_node = node.InputEdgesBegin()->GetNode(); /*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 0f068ba48d3d8..daa24db134114 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include "core/common/logging/logging.h" +#include "core/graph/graph.h" +#include "core/graph/graph_viewer.h" #include "core/providers/coreml/coreml_execution_provider.h" #include "core/providers/coreml/coreml_provider_factory.h" #include "core/session/inference_session.h" @@ -92,7 +94,7 @@ TEST(CoreMLExecutionProviderTest, FunctionTest) { feeds.insert(std::make_pair("Y", ml_value_y)); feeds.insert(std::make_pair("Z", ml_value_z)); - RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.FunctionTest", + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), MakeCoreMLExecutionProvider(), feeds); #else @@ -118,9 +120,50 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) { NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); - RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.ArgMaxCastTest", + EPVerificationParams verification_params{}; + verification_params.ep_node_assignment = ExpectedEPNodeAssignment::All; + + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), MakeCoreMLExecutionProvider(), - feeds); + feeds, + verification_params); +#else + TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) { + const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/coreml_argmax_unsupported_cast_test.onnx"); + +#if defined(__APPLE__) + std::vector dims_mul_x = {3, 2, 2}; + std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = std::make_shared(); + CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + const std::function graph_verifier = [](const Graph& graph) { + GraphViewer graph_viewer{graph}; + const auto& node_indices_in_order = graph_viewer.GetNodesInTopologicalOrder(); + ASSERT_EQ(node_indices_in_order.size(), size_t{2}); + // second node should be an unsupported Cast + const auto* cast_node = graph.GetNode(node_indices_in_order[1]); + ASSERT_NE(cast_node, nullptr); + ASSERT_EQ(cast_node->OpType(), "Cast"); + ASSERT_EQ(cast_node->GetExecutionProviderType(), kCpuExecutionProvider); + }; + + EPVerificationParams verification_params{}; + verification_params.ep_node_assignment = ExpectedEPNodeAssignment::Some; + verification_params.graph_verifier = &graph_verifier; + + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), + MakeCoreMLExecutionProvider(), + feeds, + verification_params); #else TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::Some); #endif @@ -184,7 +227,7 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) { NameMLValMap feeds; feeds.insert(std::make_pair("Input3", ml_value)); - RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.TestOrtFormatModel", + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), MakeCoreMLExecutionProvider(), feeds); #else diff --git a/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx b/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx index db806f296aff3..931bd30dbe62f 100644 --- a/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx +++ b/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx @@ -1,4 +1,5 @@ -:� + +:� F Xargmax_output_int64argmax"ArgMax* axis�* @@ -15,4 +16,4 @@ F    -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/coreml_argmax_cast_test.py b/onnxruntime/test/testdata/coreml_argmax_cast_test.py index acf24ac379065..6cc25311131a0 100644 --- a/onnxruntime/test/testdata/coreml_argmax_cast_test.py +++ b/onnxruntime/test/testdata/coreml_argmax_cast_test.py @@ -1,16 +1,18 @@ import onnx from onnx import TensorProto, helper -# CoreML EP currently handles a special case for supporting ArgMax op -# Please see in /onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and -# /onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc -# We have this separated test script to generate graph for the case: An ArgMax followed by a Cast to int32 type +# CoreML EP currently handles a special case for supporting ArgMax followed by a Cast to int32. +# Please see /onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and +# /onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc. +# This script generates graphs for these cases: +# - An ArgMax followed by a supported Cast to int32 type +# - An ArgMax followed by an unsupported Cast to a type other than int32 -def GenerateModel(model_name): # noqa: N802 +def GenerateModel(model_name, cast_to_dtype): # noqa: N802 nodes = [ helper.make_node("ArgMax", ["X"], ["argmax_output_int64"], "argmax", axis=1, keepdims=1), - helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=6), # cast to int32 type + helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=cast_to_dtype), ] graph = helper.make_graph( @@ -20,7 +22,7 @@ def GenerateModel(model_name): # noqa: N802 helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2, 2]), ], [ # output - helper.make_tensor_value_info("Y", TensorProto.INT32, [3, 1, 2]), + helper.make_tensor_value_info("Y", cast_to_dtype, [3, 1, 2]), ], ) @@ -29,4 +31,5 @@ def GenerateModel(model_name): # noqa: N802 if __name__ == "__main__": - GenerateModel("coreml_argmax_cast_test.onnx") + GenerateModel("coreml_argmax_cast_test.onnx", TensorProto.INT32) + GenerateModel("coreml_argmax_unsupported_cast_test.onnx", TensorProto.UINT32) diff --git a/onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx b/onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx new file mode 100644 index 0000000000000..d5aea9110cbfa --- /dev/null +++ b/onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx @@ -0,0 +1,19 @@ + +:� +F +Xargmax_output_int64argmax"ArgMax* +axis�* +keepdims� +/ +argmax_output_int64Ycast"Cast* +to �CoreML_ArgMax_Cast_TestZ +X + + + +b +Y +  + + +B \ No newline at end of file From 44dcc3aafd4e308a1847552335ad85db1a1ec5e7 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Fri, 23 Aug 2024 10:35:57 -0700 Subject: [PATCH 2/6] Replace "DML CPU" Allocator with onnxruntime::CpuAllocator (#21818) ### Description Replace "DML CPU" Allocator with onnxruntime::CpuAllocator ### Motivation and Context This allocator is being ignored by ORTExtensions and causes CPU memory to be treated as non-CPU memory and crash in SentencepieceTokenizer. In general it seems like this allocator is not used and can be handled just fine by the default allocator. --------- Co-authored-by: Sheil Kumar --- .../src/BucketizedBufferAllocator.cpp | 24 ------------------- .../src/ExecutionProvider.cpp | 4 +++- .../src/ExecutionProvider.h | 3 +-- 3 files changed, 4 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index db45908a2dda4..b1714a8220cd1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -223,28 +223,4 @@ namespace Dml { m_defaultRoundingMode = roundingMode; } - - CPUAllocator::CPUAllocator(OrtMemType memType) - : onnxruntime::IAllocator( - OrtMemoryInfo( - "DML CPU", - OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0), - 0, - memType - ) - ) - { - } - - void* CPUAllocator::Alloc(size_t size) - { - return onnxruntime::AllocatorDefaultAlloc(size); - } - - void CPUAllocator::Free(void* p) - { - return onnxruntime::AllocatorDefaultFree(p); - } - } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 043853ccae336..cb6fc165a932f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -239,7 +239,9 @@ namespace Dml std::make_unique(m_d3d12Device.Get())); m_context->SetAllocator(m_allocator); // CPU Allocator used to create buffers for the MemcpyFromHost, Shape and Size operators. - m_cpuInputAllocator = std::make_shared(OrtMemType::OrtMemTypeCPUInput); + OrtMemoryInfo memoryInfo(onnxruntime::CPU, OrtAllocatorType::OrtDeviceAllocator); + memoryInfo.mem_type = ::OrtMemType::OrtMemTypeCPUInput; + m_cpuInputAllocator = std::make_shared(memoryInfo); } return std::vector{m_allocator, m_cpuInputAllocator,}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 29961288a51c5..c20969250fe84 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -25,7 +25,6 @@ namespace Dml class ReadbackHeap; class ExecutionContext; class BucketizedBufferAllocator; - class CPUAllocator; class ExecutionProvider; class ExecutionProviderImpl : public WRL::Base m_uploadHeap; std::unique_ptr m_readbackHeap; std::shared_ptr m_allocator; - std::shared_ptr m_cpuInputAllocator; + std::shared_ptr m_cpuInputAllocator; std::shared_ptr m_kernelRegistry; std::shared_ptr m_internalRegInfoMap; mutable uint64_t m_partitionKernelPrefixVal = 0; From 4af62918418a04bcd219471dfb8f8b27a1e8a8b5 Mon Sep 17 00:00:00 2001 From: duanshengliu <44742794+duanshengliu@users.noreply.github.com> Date: Sat, 24 Aug 2024 04:45:06 +0800 Subject: [PATCH 3/6] Refine `op_types_to_quantize` argument handling in matmul_4bits_quantizer.py (#21815) ### Description Refine `op_types_to_quantize` argument handling in matmul_4bits_quantizer.py ### Motivation and Context The default `op_types_to_quantize "MatMul"` will cause `tuple(args.op_types_to_quantize)` to become `('M', 'a', 't', 'M', 'u', 'l')`, which is not expected. --- .../python/tools/quantization/matmul_4bits_quantizer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 975f82439c160..16ad36c48cc74 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -1062,7 +1062,6 @@ def parse_args(): ) parser.add_argument( "--op_types_to_quantize", - default="MatMul", type=str, nargs="+", choices=["MatMul", "Gather"], @@ -1089,7 +1088,7 @@ def parse_args(): input_model_path = args.input_model output_model_path = args.output_model quant_format = QuantFormat[args.quant_format] - op_types_to_quantize = tuple(args.op_types_to_quantize) if args.op_types_to_quantize else None + op_types_to_quantize = tuple(args.op_types_to_quantize) if args.op_types_to_quantize else ("MatMul",) quant_axes = tuple(args.quant_axes) if args.quant_axes else None if os.path.exists(output_model_path): From 87165b92e94a9f8568d3fb1453715978111fcbec Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Sat, 24 Aug 2024 07:36:00 +0800 Subject: [PATCH 4/6] [js/webgpu] optimize MatmulNBits (#21747) ### Description See 2x speedup for phi3 on the integrated intel gpu with this optimization. The optimization is mainly to store input A's data into local variable instead of loading them from global memory each time when calculate them with B data. ### Motivation and Context --- .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 311 +++++++----------- 1 file changed, 124 insertions(+), 187 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index b63d253ebbb29..3f4617014e798 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { calculateTensorSizeInBytes, DataType } from '../../../wasm-common'; +import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; import { ShapeUtil } from '../../util'; import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; @@ -14,7 +14,6 @@ import { outputVariable, ShaderHelper, tensorTypeToWsglStorageType, - UniformsArrayType, } from './common'; // TODO support quantization bits not equal to 4 @@ -60,41 +59,27 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt export const createMatMulNBitsProgramInfo = ( inputs: readonly TensorView[], attributes: MatMulNBitsAttributes, - maxComputeWorkgroupSizes: [number, number, number], - maxComputeWorkgroupStorageSize: number, ): ProgramInfo => { const inputShape = inputs[0].dims; const aRank = inputShape.length; - const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); const dimAOuter = inputShape[aRank - 2]; const dimInner = attributes.k; const dimBOuter = attributes.n; const batchDims = inputShape.slice(0, aRank - 2); const batchSize = ShapeUtil.size(batchDims); - const blobSize = (attributes.blockSize / 8) * attributes.bits; + const blobSize = inputs[1].dims[2]; const blobSizeInWords = blobSize / 4; const dataType = inputs[0].dataType; - const outputNumber = getMaxComponents(dimAOuter); const aComponents = getMaxComponents(attributes.k); const bComponents = getMaxComponents(blobSizeInWords); - const workgroupOutputSize = calculateTensorSizeInBytes(dataType, dimAOuter * nBlocksPerCol)!; - const maxNumberOfComponents = Math.floor(maxComputeWorkgroupStorageSize / workgroupOutputSize); - const useBlockwiseMatMulNBits = nBlocksPerCol <= maxComputeWorkgroupSizes[0] && maxNumberOfComponents > 0; - const components = - !useBlockwiseMatMulNBits || maxNumberOfComponents >= 4 - ? getMaxComponents(dimBOuter) - : maxNumberOfComponents >= 2 && getMaxComponents(dimBOuter) >= 2 - ? 2 - : 1; + const components = getMaxComponents(dimBOuter); const outputShape = batchDims.concat([dimAOuter, dimBOuter]); - const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const outputNumber = dimAOuter > 1 && (dimBOuter / components) % 2 === 0 ? 2 : 1; + const dispatchSize = ShapeUtil.size(outputShape) / components / outputNumber; + + const workgroupSize = 64; - const programUniforms: ProgramUniform[] = useBlockwiseMatMulNBits - ? [] - : [ - { type: DataType.uint32, data: outputSize }, - { type: DataType.uint32, data: attributes.blockSize }, - ]; + const programUniforms: ProgramUniform[] = []; const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents]; const bShape = ShapeUtil.convertShape(inputs[1].dims).slice(); bShape.splice(-1, 1, blobSizeInWords / bComponents); @@ -106,6 +91,7 @@ export const createMatMulNBitsProgramInfo = ( } const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + const getShaderSource = (shaderHelper: ShaderHelper) => { const inputRank = inputShapeTemp.length; const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents); @@ -119,10 +105,6 @@ export const createMatMulNBitsProgramInfo = ( } const outputRank = outputShapeTemp.length; const output = outputVariable('output', inputs[0].dataType, outputRank, components); - const uniforms: UniformsArrayType = [ - { name: 'output_size', type: 'u32' }, - { name: 'block_size', type: 'u32' }, - ]; const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const qDqDataType = (() => { @@ -138,187 +120,146 @@ export const createMatMulNBitsProgramInfo = ( } })(); - const processOneBlock = ` - for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) { - ${b.indicesSet('b_indices', '2', 'word')}; - let b_data = ${b.getByIndices('b_indices')}; - for (var i: u32 = 0; i < ${bComponents}; i++) { - let b_value: u32 = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'}; - let b_mask: u32 = 0x0F0F0F0Fu; - let b_value_lower: vec4 = unpack4xU8(b_value & b_mask); - let b_value_upper: vec4 = unpack4xU8((b_value >> 4) & b_mask); - let b_quantized_values = ${qDqDataType}(${Array.from( + const processOneWord = (): string => { + let calcStr = ` + // reuse a data + var input_offset = ${a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`)}; + var a_data: ${qDqDataType}; + for (var j: u32 = 0; j < ${8 / aComponents}; j++) { + a_data[j] = ${a.getByOffset('input_offset')}; + input_offset++; + } + `; + for (let c = 0; c < components * outputNumber; c++) { + calcStr += ` + b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`}; + b_value_lower = unpack4xU8(b_value & b_mask); + b_value_upper = unpack4xU8((b_value >> 4) & b_mask); + b_quantized_values = ${qDqDataType}(${Array.from( { length: 4 }, (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`, ).join(', ')}); - let b_dequantized_values = ${(() => { + b_dequantized_values = ${(() => { if (aComponents === 1) { return `${qDqDataType}(${Array.from( { length: 8 }, - (_, i) => `(b_quantized_values[${i}] - zero_point) * scale`, + (_, i) => `(b_quantized_values[${i}] - ${zeroPoints ? `zero_point${c}` : 'zero_point'}) * scale${c}`, ).join(', ')});`; } else { - return `(b_quantized_values - ${qDqDataType}(${Array(8).fill('zero_point').join(',')})) * scale;`; + return `(b_quantized_values - ${qDqDataType}(${Array(8) + .fill(`${zeroPoints ? `zero_point${c}` : 'zero_point'}`) + .join(',')})) * scale${c};`; } })()}; - // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 - for (var m: u32 = 0; m < ${useBlockwiseMatMulNBits ? dimAOuter : outputNumber}u; m++) { - ${a.indicesSet('a_indices', inputRank - 2, useBlockwiseMatMulNBits ? 'm' : `row * ${outputNumber} + m`)}; - ${a.indicesSet('a_indices', inputRank - 1, 'word_offset')}; - var input_offset = ${a.indicesToOffset('a_indices')}; - var a_data: ${qDqDataType}; - for (var j: u32 = 0; j < ${8 / aComponents}; j++) { - a_data[j] = ${a.getByOffset('input_offset')}; - input_offset++; - } - ${useBlockwiseMatMulNBits ? 'workgroup_shared[workgroup_shared_offset + m]' : 'output_values[m]'}${ - components > 1 ? '[c]' : '' - } += ${Array.from( - { length: 8 / aComponents }, - (_, i) => - `${ - aComponents === 1 - ? `a_data[${i}] * b_dequantized_values[${i}]` - : `dot(a_data[${i}], b_dequantized_values[${i}])` - }`, - ).join(' + ')}; - } - word_offset += ${8 / aComponents}; - } - }`; - const updateZeroPointIndex = zeroPoints - ? ` - zero_point_offset += 4; - if (zero_point_offset == 32) { - zero_point_offset = 0; - zero_point_index++; - zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; - }` - : ''; - - return useBlockwiseMatMulNBits - ? ` - var workgroup_shared: array<${output.type.value}, ${dimAOuter * nBlocksPerCol}>; - ${shaderHelper.declareVariables(...inputVariables, output)} - ${shaderHelper.mainStart([nBlocksPerCol, 1, 1])} - var a_indices: ${a.type.indices}; - var block = local_id.x; - var col = workgroup_id.y; - var batch = workgroup_id.z; - ${a.indicesSet('a_indices', '0', 'batch')}; - // Two zero points are packed into one byte when uniforms.bits is 4. - for (var c: u32 = 0; c < ${components}; c++) { - let col_times_components_plus_c = col * ${components} + c; - ${ - zeroPoints - ? ` - var zero_point_bytes_per_col: u32 = (${nBlocksPerCol} + 1) / 2; - var zero_point_byte_count: u32 = col_times_components_plus_c * zero_point_bytes_per_col + (block >> 0x1u); - var zero_point_word_index: u32 = zero_point_byte_count >> 0x2u; - var zero_point_byte_offset: u32 = zero_point_byte_count & 0x3u; - var zero_point_nibble_offset: u32 = block & 0x1u; - var zero_point_bits_offset: u32 = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2); - var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;` - : '' - } - var b_indices: ${b.type.indices}; - ${b.indicesSet('b_indices', '0', 'col_times_components_plus_c')}; - // The scale and zero points are computed per block. - var scales_index = col_times_components_plus_c * ${nBlocksPerCol} + block; - let scale = ${scales.getByOffset('scales_index')}; + workgroup_shared[local_id.x * ${outputNumber} + ${Math.floor(c / components)}]${components > 1 ? `[${c % components}]` : ''} += ${Array.from( + { length: 8 / aComponents }, + (_, i) => + `${ + aComponents === 1 + ? `a_data[${i}] * b_dequantized_values[${i}]` + : `dot(a_data[${i}], b_dequantized_values[${i}])` + }`, + ).join(' + ')}; + `; + } + return calcStr; + }; + const prepareScaleAndZeroPoint = (): string => { + let calcStr = ` + var col_index = col * ${components}; + ${ + zeroPoints + ? ` + let zero_point_bytes_per_col = (nBlocksPerCol + 1) / 2; + var zero_point_byte_count: u32; + var zero_point_word_index: u32; + var zero_point_byte_offset: u32; + let zero_point_nibble_offset: u32 = block & 0x1u; + var zero_point_bits_offset: u32; + var zero_point_word: u32;` + : ` // The default zero point is 8 for unsigned 4-bit quantization. - let zero_point = ${dataType}(${zeroPoints ? '(zero_point_word) & 0xFu' : 8.0}); - ${b.indicesSet('b_indices', '1', 'block')}; - var word_offset: u32 = block * ${attributes.blockSize / aComponents}; - var workgroup_shared_offset: u32 = block * ${dimAOuter}; - ${processOneBlock} - } - workgroupBarrier(); - var output_indices: ${output.type.indices}; - var elements_per_thread: u32 = ${Math.ceil(dimAOuter / nBlocksPerCol)}; - ${output.indicesSet('output_indices', '0', 'batch')}; - ${output.indicesSet('output_indices', outputRank - 1, 'col')}; - ${output.indicesSet('output_indices', outputRank - 2, 'local_id.x * elements_per_thread')}; - var output_offset = ${output.indicesToOffset('output_indices')}; - for (var m: u32 = 0u; m < elements_per_thread; m++) { - var row = m + local_id.x * elements_per_thread; - if (row < ${dimAOuter}) { - var output_value: ${output.type.value} = ${output.type.value}(0); - var workgroup_shared_offset: u32 = row; - for (var b: u32 = 0u; b < ${nBlocksPerCol}u; b++) { - output_value += workgroup_shared[workgroup_shared_offset]; - workgroup_shared_offset += ${dimAOuter}; - } - ${output.setByOffset('output_offset', 'output_value')}; - output_offset += ${dimBOuter / components}; + let zero_point = ${dataType}(${8.0});` } - } - }` - : ` - ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - var output_values: array<${output.type.value}, ${outputNumber}>; - var output_indices = ${output.offsetToIndices('global_idx')}; - var col = ${output.indicesGet('output_indices', outputRank - 1)}; - var row = ${output.indicesGet('output_indices', outputRank - 2)}; - var a_indices: ${a.type.indices} = output_indices; - // Two zero points are packed into one byte because uniforms.bits <= 4. - // zero_point_offset is either 0 or 4. It is bit offset within one byte. - // TODO support zero_point_offset for bits > 4 - ${ - zeroPoints - ? ` - var zero_point_abs_offset = col * ${components} * ((${nBlocksPerCol} + 1) / 2); - var zero_point_index: u32 = zero_point_abs_offset / 4; - var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; - var zero_point_offset: u32 = (zero_point_abs_offset % 4) * 8;` - : '' - } - var scale_index = col * ${nBlocksPerCol * components}; - var b_indices: ${b.type.indices}; - for (var c: u32 = 0; c < ${components}; c++) { - ${b.indicesSet('b_indices', '0', `col * ${components} + c`)}; - var block_offset: u32 = 0; - for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { - // The scale and zero points are computed per block. - let scale = ${scales.getByOffset('scale_index')}; - // The default zero point is 8 for unsigned 4-bit quantization. - let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0}); - ${b.indicesSet('b_indices', '1', 'block')}; - var word_offset: u32 = block_offset; - ${processOneBlock} - scale_index++; - ${updateZeroPointIndex} - block_offset += uniforms.block_size / ${aComponents}; - } - // Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte. + `; + for (let c = 0; c < components * outputNumber; c++) { + calcStr += ` + let scale${c} = ${scales.getByOffset(`col_index * nBlocksPerCol + block`)}; ${ zeroPoints - ? `if (zero_point_offset % 8 > 0) { - ${updateZeroPointIndex} - }` + ? ` + zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u); + zero_point_word_index = zero_point_byte_count >> 0x2u; + zero_point_byte_offset = zero_point_byte_count & 0x3u; + zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2); + zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset; + let zero_point${c} = ${dataType}((zero_point_word) & 0xFu);` : '' } + col_index += 1;`; + } + return calcStr; + }; + const prepareBData = (): string => { + let calcStr = `col_index = col * ${components};`; + for (let c = 0; c < components * outputNumber; c++) { + calcStr += ` + let b${c}_data = ${b.getByIndices(`${b.type.indices}(col_index, block, word)`)}; + col_index += 1;`; + } + calcStr += ` + var b_value: u32; + let b_mask: u32 = 0x0F0F0F0Fu; + var b_value_lower: vec4; + var b_value_upper: vec4; + var b_quantized_values: ${qDqDataType}; + var b_dequantized_values: ${qDqDataType};`; + return calcStr; + }; + return ` + var workgroup_shared: array<${output.type.value}, ${outputNumber * workgroupSize}>; + ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart([workgroupSize, 1, 1])} + let output_indices = ${output.offsetToIndices(`(global_idx / ${workgroupSize}) * ${outputNumber}`)}; + let col = output_indices[2]; + let row = output_indices[1]; + let batch = output_indices[0]; + let nBlocksPerCol = uniforms.b_shape[1]; + + for (var block = local_id.x; block < nBlocksPerCol; block += ${workgroupSize}) { + //process one block + var word_offset: u32 = block * ${attributes.blockSize / aComponents}; + ${prepareScaleAndZeroPoint()} + for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) { + ${prepareBData()} + for (var i: u32 = 0; i < ${bComponents}; i++) { + ${processOneWord()} + word_offset += ${8 / aComponents}; + } } - for (var k: u32 = 0u; k < ${outputNumber}u; k++) { - ${output.indicesSet('output_indices', outputRank - 2, `${outputNumber} * row + k`)}; - ${output.setByIndices('output_indices', 'output_values[k]')} + } + workgroupBarrier(); + + if (local_id.x < ${outputNumber}) { + var output_value: ${output.type.value} = ${output.type.value}(0); + var workgroup_shared_offset: u32 = local_id.x; + for (var b: u32 = 0u; b < ${workgroupSize}u; b++) { + output_value += workgroup_shared[workgroup_shared_offset]; + workgroup_shared_offset += ${outputNumber}; } + ${output.setByIndices(`${output.type.indices}(batch, row, col + local_id.x)`, 'output_value')}; + } }`; }; return { - name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits', + name: 'MatMulNBits', shaderCache: { - hint: `${attributes.cacheKey};${dimAOuter};${dataType};${inputs.length}`, + hint: `${attributes.blockSize};${attributes.bits};${aComponents};${bComponents};${components};${outputNumber};${workgroupSize}`, inputDependencies: Array(inputs.length).fill('rank'), }, getRunData: () => ({ outputs: [{ dims: outputShape, dataType }], - name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits', - dispatchGroup: useBlockwiseMatMulNBits - ? { x: 1, y: Math.ceil(dimBOuter / components), z: batchSize } - : { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + dispatchGroup: { x: dispatchSize }, programUniforms, }), getShaderSource, @@ -327,11 +268,7 @@ export const createMatMulNBitsProgramInfo = ( export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { validateInputs(context.inputs, attributes); - const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes(); - const maxComputeWorkgroupStorageSize = context.getMaxComputeWorkgroupStoragesize(); - context.compute( - createMatMulNBitsProgramInfo(context.inputs, attributes, maxComputeWorkgroupSizes, maxComputeWorkgroupStorageSize), - ); + context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); }; export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => From 9a70475622659bd9afb86847facc996d6a01c0b2 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Sun, 25 Aug 2024 01:01:08 -0400 Subject: [PATCH 5/6] [MIGraphX EP Support]Remove default noopt for Migraphx EP in Benchmark.py (#21843) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ripts (#58) ### Description Removes the heavy handed no opt for all MIGraphX using the benchmark.py scripts ### Motivation and Context Finding this hurts performance if we remove all optimizations. Let the fine tuning occur at the script level instead of a blanket NoOPT being selected Co-authored-by: Ted Themistokleous --- onnxruntime/python/tools/transformers/benchmark.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index dc0bb55212e28..4800c48744236 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -135,15 +135,6 @@ def run_onnxruntime( ) return results - if provider == "migraphx": - optimizer_info = OptimizerInfo.NOOPT - warm_up_repeat = 5 - if "MIGraphXExecutionProvider" not in onnxruntime.get_available_providers(): - logger.error( - "Please install onnxruntime-rocm package, and use a machine with GPU for testing gpu performance." - ) - return results - if optimizer_info == OptimizerInfo.NOOPT: logger.warning( f"OptimizerInfo is set to {optimizer_info}, graph optimizations specified in FusionOptions are not applied." From 983c4d57a4f1e35568545ab3b9971e83c26ecb7c Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Sun, 25 Aug 2024 19:05:11 -0700 Subject: [PATCH 6/6] Fix typo for react native pipeline (#21845) ### Description fix typo ### Motivation and Context [RN pipeline failing](https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=188&_a=summary) since #21578 with this error: ![image](https://github.com/user-attachments/assets/75e5b968-572f-42cc-9816-7940de464cfa) --- .../templates/android-dump-logs-from-steps.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/android-dump-logs-from-steps.yml b/tools/ci_build/github/azure-pipelines/templates/android-dump-logs-from-steps.yml index 2d91c605bf382..f8d7f6f1cae45 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-dump-logs-from-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-dump-logs-from-steps.yml @@ -6,7 +6,7 @@ parameters: steps: - task: CmdLine@2 - input: + inputs: script: | if [ -f $(Build.BinariesDirectory)/emulator.pid ]; then echo "Emulator is running." @@ -14,7 +14,7 @@ steps: else echo "Emulator is not running." fi - name: Determine if emulator is running + displayName: "Determine if emulator is running" - task: CmdLine@2 inputs: