diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 0491d2b5de4d..649f7e66746a 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -220,7 +220,7 @@ jobs:
             --goldentime-rocm-unet-ms 419.0 \
             --goldentime-rocm-clip-ms 18.5 \
             --goldentime-rocm-vae-ms 337.0 \
-            --goldendispatch-rocm-unet 1602 \
+            --goldendispatch-rocm-unet 1598 \
             --goldendispatch-rocm-clip 1139 \
             --goldendispatch-rocm-vae 246 \
             --goldensize-rocm-unet-bytes 2280000 \
@@ -242,17 +242,17 @@ jobs:
             --goldentime-rocm-unet-ms 80.0 \
             --goldentime-rocm-clip-ms 15.5 \
             --goldentime-rocm-vae-ms 80.0 \
-            --goldendispatch-rocm-unet 1602 \
+            --goldendispatch-rocm-unet 1598 \
             --goldendispatch-rocm-clip 1139 \
             --goldendispatch-rocm-vae 246 \
             --goldensize-rocm-unet-bytes 2270000 \
             --goldensize-rocm-clip-bytes 860000 \
             --goldensize-rocm-vae-bytes 840000 \
-            --goldentime-rocm-punet-int8-fp16-ms 53 \
-            --goldendispatch-rocm-punet-int8-fp16 1424 \
+            --goldentime-rocm-punet-int8-fp16-ms 51 \
+            --goldendispatch-rocm-punet-int8-fp16 1416 \
             --goldensize-rocm-punet-int8-fp16-bytes 2560000 \
-            --goldentime-rocm-punet-int8-fp8-ms 53 \
-            --goldendispatch-rocm-punet-int8-fp8 1704 \
+            --goldentime-rocm-punet-int8-fp8-ms 51 \
+            --goldendispatch-rocm-punet-int8-fp8 1696 \
             --goldensize-rocm-punet-int8-fp8-bytes 2800000 \
             --rocm-chip gfx942 \
             --log-cli-level=info \
diff --git a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
index 16049fad2543..969f31ab68c2 100644
--- a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
+++ b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
@@ -101,6 +101,142 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
 // Matmul tuning
 //===----------------------------------------------------------------------===//
 
+transform.named_sequence @match_mmt_i8_i8_i32(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+  transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
+  // transform.print %root {name = "Generic"} : !transform.any_op
+  %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+  ^bb0(%lhs: tensor<?x?xi8>, %rhs: tensor<?x?xi8>, %out: tensor<?x?xi32>):
+    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d0, d1)>],
+                         iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%out : tensor<?x?xi32>) {
+    ^bb0(%in: i8, %in_0: i8, %acc: i32):
+      %18 = arith.extsi %in : i8 to i32
+      %19 = arith.extsi %in_0 : i8 to i32
+      %20 = arith.muli %18, %19 : i32
+      %21 = arith.addi %acc, %20 : i32
+      linalg.yield %21 : i32
+    } -> tensor<?x?xi32>
+  } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+  transform.yield %root : !transform.any_op
+}
+
+transform.named_sequence @match_mmt_2048x10240x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<10240x1280xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 4, subgroup_n_count = 2,
+                                                 reduction = [0, 0, 128],
+                                                 workgroup = [128, 320, 0]}>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                                      workgroup_size = [128, 4, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+    }>> -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @match_mmt_2048x1280x5120(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<2048x5120xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<1280x5120xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 4, subgroup_n_count = 1,
+                                                 reduction = [0, 0, 256],
+                                                 workgroup = [128, 80, 0]}>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                                      workgroup_size = [64, 4, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+    }>> -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @match_mmt_2048x1280x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<1280x1280xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 2, subgroup_n_count = 2,
+                                                 reduction = [0, 0, 128],
+                                                 workgroup = [64, 160, 0]}>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                                      workgroup_size = [128, 2, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+    }>> -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @match_mmt_8192x640x640(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<8192x640xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 8, subgroup_n_count = 1,
+                                                 reduction = [0, 0, 64],
+                                                 workgroup = [256, 64, 0]}>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                                      workgroup_size = [64, 8, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+    }>> -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @match_mmt_8192x5120x640(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<8192x640xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<5120x640xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 2, subgroup_n_count = 4,
+                                                 reduction = [0, 0, 64],
+                                                 workgroup = [256, 128, 0]}>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                                      workgroup_size = [256, 2, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+    }>> -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @match_mmt_8192x640x2560(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<8192x2560xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<640x2560xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 8, subgroup_n_count = 1,
+                                                 reduction = [0, 0, 64],
+                                                 workgroup = [256, 64, 0]}>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                                      workgroup_size = [64, 8, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+    }>> -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
 //===----------------------------------------------------------------------===//
 // Convolution tuning
 //===----------------------------------------------------------------------===//
@@ -152,6 +288,65 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
     transform.yield %generic, %config : !transform.any_op, !transform.any_param
   }
 
+  transform.named_sequence @match_broadcast_rhs_mmt_Bx64x1280x2480(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<?x64x2480xi8> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<1280x2480xi8> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                   mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                   subgroup_m_count = 2, subgroup_n_count = 2,
+                                                   reduction = [0, 0, 0, 128],
+                                                   workgroup = [1, 64, 160, 0]}>,
+      translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                                        workgroup_size = [128, 2, 1] subgroup_size = 64,
+        {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+      }>> -> !transform.any_param
+    transform.yield %generic, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_broadcast_rhs_mmt_Bx4960x640x640(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<?x4960x640xi8> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                   mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                   subgroup_m_count = 8, subgroup_n_count = 1,
+                                                   reduction = [0, 0, 0, 64],
+                                                   workgroup = [1, 256, 64, 0]}>,
+      translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                                        workgroup_size = [64, 8, 1] subgroup_size = 64,
+        {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+      }>> -> !transform.any_param
+    transform.yield %generic, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_broadcast_rhs_mmt_Bx64x640x2480(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<?x64x2480xi8> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<640x2480xi8> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                   mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                   subgroup_m_count = 2, subgroup_n_count = 1,
+                                                   reduction = [0, 0, 0, 128],
+                                                   workgroup = [1, 32, 320, 0]}>,
+      translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                                                        workgroup_size = [64, 2, 1] subgroup_size = 64,
+        {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+      }>> -> !transform.any_param
+    transform.yield %generic, %config : !transform.any_op, !transform.any_param
+  }
+
   transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x5120x640(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
     %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
     %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
     %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
@@ -352,6 +547,12 @@ transform.named_sequence @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32
 
         // TUNING_MATCH_BEGIN DO NOT REMOVE
 
        // Matmul.
+        , @match_mmt_2048x10240x1280 -> @apply_op_config
+        , @match_mmt_2048x1280x5120 -> @apply_op_config
+        , @match_mmt_2048x1280x1280 -> @apply_op_config
+        , @match_mmt_8192x640x640 -> @apply_op_config
+        , @match_mmt_8192x5120x640 -> @apply_op_config
+        //, @match_mmt_8192x640x2560 -> @apply_op_config
 
         // Convolution.
 
@@ -363,6 +564,10 @@ transform.named_sequence @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32
 
         // Carried over from SPX.
         , @match_broadcast_rhs_mmt_Bx1024x10240x1280 -> @apply_op_config
         , @match_broadcast_rhs_mmt_Bx1024x1280x1280 -> @apply_op_config
+        , @match_broadcast_rhs_mmt_Bx64x1280x2480 -> @apply_op_config
+        , @match_broadcast_rhs_mmt_Bx4960x640x640 -> @apply_op_config
+        //, @match_broadcast_rhs_mmt_Bx64x640x2480 -> @apply_op_config
+
         // Contraction.
         , @match_matmul_like_Bx20x1024x64x1280_i8xi8xi32 -> @apply_op_config
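Note on how the new entries are consumed: each `@match_* -> @apply_op_config` pair above is an entry in the spec's `transform.foreach_match`, which tries the matchers against every op in the module and, on a successful match, forwards the matched op and its tuned `compilation_info` param to `@apply_op_config`. A minimal sketch of that wiring, assuming the `@apply_op_config` and `@__kernel_config` sequences defined elsewhere in this spec (shown with a single entry for illustration; the real entry point lists every matcher):

transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly},
                                          %config: !transform.any_param {transform.readonly}) {
  // Attach the tuned lowering configuration to the matched op.
  transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
  transform.yield
}

transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {transform.consumed}) {
  %res = transform.foreach_match in %variant_op
      @match_mmt_2048x10240x1280 -> @apply_op_config
    : (!transform.any_op) -> (!transform.any_op)
  transform.yield
}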
diff --git a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
index ba795789d8c5..f1fe77bb3734 100644
--- a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
@@ -12,6 +12,7 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "iree/compiler/DispatchCreation/Passes.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Utils.h"
@@ -121,36 +122,45 @@ static SmallVector<ReassociationIndices> getCollapsibleLoops(Operation *op) {
         (rDimsSet.count(prePos) && rDimsSet.count(nextPos));
   };
 
-  ReassociationIndices range;
-  AffineExpr preExpr;
   // Find the largest sequence of dimensions that are
   //   - Either preserved in all maps, or
   //   - are completely absent
   // This sequence can be collapsed. To find the sequence,
-  // 1) Take the result expressions of one of the indexing maps
-  // 2) Find a sequence of 2 that is found in all maps
+  // 1) For each indexing map, take the result expressions
+  // 2) Find a sequence of 2 that is found in all maps (or absent)
   // 3) Then take last element of this sequence and the next
   //    result expression, and check if this sequence of 2 is
   //    found in all maps. If so, add to sequence (to get a sequence of 3)
   //    and repeat till the last element of sequence and the next result
   //    expression is not found as a sequence in all maps.
-  for (auto nextExpr :
-       fusionInterfaceOp.getIndexingMapsArray().front().getResults()) {
-    unsigned position = cast<AffineDimExpr>(nextExpr).getPosition();
-    if (!range.empty()) {
+
+  llvm::SmallSetVector<unsigned, 8> seenLoops;
+  for (auto map : fusionInterfaceOp.getIndexingMapsArray()) {
+    ReassociationIndices range;
+    AffineExpr preExpr;
+
+    auto appendAndClearRange = [&]() {
+      if (range.size() > 1) {
+        contiguousLoops.push_back(range);
+      }
+      range.clear();
+    };
+
+    for (auto nextExpr : map.getResults()) {
+      unsigned position = cast<AffineDimExpr>(nextExpr).getPosition();
+      if (seenLoops.contains(position)) {
+        appendAndClearRange();
+        continue;
+      }
       if (!hasAllMapsSameSequence(preExpr, nextExpr) ||
           !hasSameIteratorType(preExpr, nextExpr)) {
-        if (range.size() > 1) {
-          contiguousLoops.push_back({range.begin(), range.end()});
-        }
-        range.clear();
+        appendAndClearRange();
       }
+      range.push_back(position);
+      seenLoops.insert(position);
+      preExpr = nextExpr;
     }
-    range.push_back(position);
-    preExpr = nextExpr;
-  }
-  if (range.size() > 1) {
-    contiguousLoops.push_back(range);
+    appendAndClearRange();
   }
 
   return contiguousLoops;
@@ -192,21 +202,20 @@ static bool isEligibleForCollapse(Operation *op) {
   }
 
   // TODO(#17948) GPU codegen fails when we collapse the dimensions of softmax.
-  if (llvm::any_of(genericOp.getDpsInputOperands(),
-                   [&](OpOperand *operand) -> bool {
-                     auto genericOperand =
-                         operand->get().getDefiningOp<linalg::GenericOp>();
-                     if (!genericOperand) {
-                       return false;
-                     }
-
-                     if (genericOperand.getNumReductionLoops() == 0) {
-                       return false;
-                     }
-
-                     return genericOp.getMatchingIndexingMap(operand)
-                         .isProjectedPermutation();
-                   })) {
+  auto isPossiblySoftmax = [&](OpOperand *operand) -> bool {
+    auto genericOperand = operand->get().getDefiningOp<linalg::GenericOp>();
+    if (!genericOperand) {
+      return false;
+    }
+
+    if (genericOperand.getNumReductionLoops() == 0) {
+      return false;
+    }
+
+    auto map = genericOp.getMatchingIndexingMap(operand);
+    return !map.isPermutation() && map.isProjectedPermutation();
+  };
+  if (llvm::any_of(genericOp.getDpsInputOperands(), isPossiblySoftmax)) {
     return false;
   }
 
@@ -615,6 +624,7 @@ hoistTensorReshapesOutOfDispatchRegion(
   // 1. Get the slice of operations within `dispatchOp` that produce the yielded
   //    value.
   BackwardSliceOptions sliceOptions;
+  sliceOptions.omitBlockArguments = true;
   sliceOptions.filter = [&](Operation *op) {
     return op->getParentOfType<IREE::Flow::DispatchRegionOp>();
   };
@@ -868,6 +878,7 @@ collapseDimensionsForDispatch(IRRewriter &rewriter,
   BackwardSliceOptions sliceOptions;
   sliceOptions.inclusive = true;
   sliceOptions.omitBlockArguments = true;
+  sliceOptions.omitUsesFromAbove = false;
   sliceOptions.filter = [&](Operation *op) -> bool {
     auto parentOp = op->getParentOfType<IREE::Flow::DispatchRegionOp>();
     return isEligibleForCollapse(op) && parentOp == regionOp;
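The `isPossiblySoftmax` rewrite above also tightens the predicate: collapse is now blocked only when an input comes from a reduction `linalg.generic` through an indexing map that drops dimensions (a projected permutation that is not a full permutation), which is the softmax access pattern; reads through a full permutation no longer disable collapsing. A hand-written illustration of the pattern that is still rejected, with small static shapes chosen for the example (not taken from the patch):

util.func public @softmax_like(%x: tensor<4x8xf32>, %acc: tensor<4xf32>,
                               %init: tensor<4x8xf32>) -> tensor<4x8xf32> {
  // Reduction producer: row sums of %x.
  %rowsum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%x : tensor<4x8xf32>) outs(%acc : tensor<4xf32>) {
  ^bb0(%in: f32, %out: f32):
    %s = arith.addf %out, %in : f32
    linalg.yield %s : f32
  } -> tensor<4xf32>
  // Consumer reads %rowsum through (d0, d1) -> (d0), a projected permutation
  // that drops d1, so isPossiblySoftmax returns true and collapse is skipped.
  %norm = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%x, %rowsum : tensor<4x8xf32>, tensor<4xf32>) outs(%init : tensor<4x8xf32>) {
  ^bb0(%in: f32, %sum: f32, %out: f32):
    %d = arith.divf %in, %sum : f32
    linalg.yield %d : f32
  } -> tensor<4x8xf32>
  util.return %norm : tensor<4x8xf32>
}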
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
index ae4146fd2b64..880f7f99e0c0 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
@@ -619,3 +619,40 @@ util.func public @collapse_attention_with_truncf(%arg0: tensor<20x4096x16xf32>,
 // CHECK:        %[[TRUNC:.*]] = linalg.generic
 // CHECK-SAME:     ins(%[[ATTN]] : tensor<20x4096x64xf32>
 // CHECK:        flow.return %[[TRUNC]] : tensor<20x4096x64xf16>
+
+// -----
+
+util.func public @collapse(%10: tensor<2x32x32x1280xi8>, %11 : tensor<10240x1280xi8>, %12 : tensor<10240xi32>, %13 : tensor<10240xf32>) -> (tensor<2x32x32x10240xf16>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %14 = tensor.empty() : tensor<2x32x32x10240xf16>
+  %15 = tensor.empty() : tensor<2x32x32x10240xi32>
+  %16 = linalg.fill ins(%c0_i32 : i32) outs(%15 : tensor<2x32x32x10240xi32>) -> tensor<2x32x32x10240xi32>
+  %dispatch = flow.dispatch.region -> (tensor<2x32x32x10240xf16>) {
+    %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%10, %11 : tensor<2x32x32x1280xi8>, tensor<10240x1280xi8>) outs(%16 : tensor<2x32x32x10240xi32>) {
+    ^bb0(%in: i8, %in_0: i8, %out: i32):
+      %19 = arith.extsi %in : i8 to i32
+      %20 = arith.extsi %in_0 : i8 to i32
+      %21 = arith.muli %19, %20 : i32
+      %22 = arith.addi %out, %21 : i32
+      linalg.yield %22 : i32
+    } -> tensor<2x32x32x10240xi32>
+    %18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%17, %12, %13 : tensor<2x32x32x10240xi32>, tensor<10240xi32>, tensor<10240xf32>) outs(%14 : tensor<2x32x32x10240xf16>) {
+    ^bb0(%in: i32, %in_0: i32, %in_1: f32, %out: f16):
+      %19 = arith.addi %in, %in_0 : i32
+      %20 = arith.sitofp %19 : i32 to f32
+      %21 = arith.mulf %20, %in_1 : f32
+      %22 = arith.truncf %21 : f32 to f16
+      linalg.yield %22 : f16
+    } -> tensor<2x32x32x10240xf16>
+    flow.return %18 : tensor<2x32x32x10240xf16>
+  }
+  util.return %dispatch : tensor<2x32x32x10240xf16>
+}
+
+// CHECK-LABEL: util.func public @collapse
+// CHECK:         %[[GEN0:.*]] = linalg.generic
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK:         %[[GEN1:.*]] = linalg.generic
+// CHECK-SAME:      iterator_types = ["parallel", "parallel"]
+// CHECK:         flow.return %[[GEN1]] : tensor<2048x10240xf16>
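For reference, the CHECK lines in the new test correspond to collapsed IR of roughly the following shape (an illustrative sketch, not verbatim pass output; the value names are hypothetical). The three leading parallel dimensions 2x32x32 fold into 2048, so the five-loop broadcast matmul becomes a plain transposed-RHS matmul over 2048x10240x1280, the same shape the new @match_mmt_2048x10240x1280 matcher in the tuning spec targets, followed by a two-loop elementwise epilogue:

// Illustrative collapsed form of the dispatch above.
%mm = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                     affine_map<(d0, d1, d2) -> (d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%lhs, %rhs : tensor<2048x1280xi8>, tensor<10240x1280xi8>)
    outs(%fill : tensor<2048x10240xi32>) {
^bb0(%in: i8, %in_0: i8, %out: i32):
  %0 = arith.extsi %in : i8 to i32
  %1 = arith.extsi %in_0 : i8 to i32
  %2 = arith.muli %0, %1 : i32
  %3 = arith.addi %out, %2 : i32
  linalg.yield %3 : i32
} -> tensor<2048x10240xi32>
%ep = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%mm, %bias, %scale : tensor<2048x10240xi32>, tensor<10240xi32>, tensor<10240xf32>)
    outs(%out : tensor<2048x10240xf16>) {
^bb0(%in: i32, %in_0: i32, %in_1: f32, %out_0: f16):
  %0 = arith.addi %in, %in_0 : i32
  %1 = arith.sitofp %0 : i32 to f32
  %2 = arith.mulf %1, %in_1 : f32
  %3 = arith.truncf %2 : f32 to f16
  linalg.yield %3 : f16
} -> tensor<2048x10240xf16>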