Revert "[Dispatch] extend CollapseDimensionsPass to more cases" (#19441)
This regresses int8 SDXL with f8 attention. I suspect it's mostly a tuning
issue, since the change only affects the shapes of dispatches. I'll reland
it after investigating further (and most likely after fixing the tuning).

Reverts #19326
IanWood1 authored Dec 10, 2024
1 parent eff0671 commit 7177c29
Showing 3 changed files with 32 additions and 161 deletions.
@@ -101,84 +101,6 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
// Matmul tuning
//===----------------------------------------------------------------------===//

transform.named_sequence @match_mmt_i8_i8_i32(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
  transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
  // transform.print %root {name = "Generic"} : !transform.any_op
  %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
  ^bb0(%lhs: tensor<?x?xi8>, %rhs: tensor<?x?xi8>, %out: tensor<?x?xi32>):
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                          affine_map<(d0, d1, d2) -> (d1, d2)>,
                                          affine_map<(d0, d1, d2) -> (d0, d1)>],
                         iterator_types = ["parallel", "parallel", "reduction"]}
        ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%out : tensor<?x?xi32>) {
    ^bb0(%in: i8, %in_0: i8, %acc: i32):
      %18 = arith.extsi %in : i8 to i32
      %19 = arith.extsi %in_0 : i8 to i32
      %20 = arith.muli %18, %19 : i32
      %21 = arith.addi %acc, %20 : i32
      linalg.yield %21 : i32
    } -> tensor<?x?xi32>
  } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
  transform.yield %root : !transform.any_op
}

transform.named_sequence @match_mmt_2048x10240x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
  transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xi8> : !transform.any_value
  transform.iree.match.cast_compatible_type %rhs = tensor<10240x1280xi8> : !transform.any_value
  %config = transform.param.constant #iree_codegen.compilation_info<
    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
                                                 subgroup_m_count = 4, subgroup_n_count = 2,
                                                 reduction = [0, 0, 128],
                                                 workgroup = [128, 320, 0]}>,
    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
      workgroup_size = [128, 4, 1] subgroup_size = 64,
      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
      }>> -> !transform.any_param
  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
}
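
// For reference: matchers like @match_mmt_2048x10240x1280 yield the matched
// op plus a config param. A minimal sketch of the consuming action (assuming
// the conventional @apply_op_config sequence used by these tuning specs)
// simply attaches the config as an attribute:
transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly},
                                          %config: !transform.any_param {transform.readonly}) {
  transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
  transform.yield
}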

transform.named_sequence @match_mmt_2048x1280x5120(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
  transform.iree.match.cast_compatible_type %lhs = tensor<2048x5120xi8> : !transform.any_value
  transform.iree.match.cast_compatible_type %rhs = tensor<1280x5120xi8> : !transform.any_value
  %config = transform.param.constant #iree_codegen.compilation_info<
    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
                                                 subgroup_m_count = 4, subgroup_n_count = 1,
                                                 reduction = [0, 0, 256],
                                                 workgroup = [128, 80, 0]}>,
    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
      workgroup_size = [64, 4, 1] subgroup_size = 64,
      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
      }>> -> !transform.any_param
  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
}

transform.named_sequence @match_mmt_2048x1280x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
  transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xi8> : !transform.any_value
  transform.iree.match.cast_compatible_type %rhs = tensor<1280x1280xi8> : !transform.any_value
  %config = transform.param.constant #iree_codegen.compilation_info<
    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
                                                 subgroup_m_count = 2, subgroup_n_count = 2,
                                                 reduction = [0, 0, 128],
                                                 workgroup = [64, 160, 0]}>,
    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
      workgroup_size = [128, 2, 1] subgroup_size = 64,
      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
      }>> -> !transform.any_param
  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
}

//===----------------------------------------------------------------------===//
// Convolution tuning
//===----------------------------------------------------------------------===//
@@ -430,9 +352,6 @@ transform.named_sequence @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32
// TUNING_MATCH_BEGIN DO NOT REMOVE

// Matmul.
, @match_mmt_2048x10240x1280 -> @apply_op_config
, @match_mmt_2048x1280x5120 -> @apply_op_config
, @match_mmt_2048x1280x1280 -> @apply_op_config

// Convolution.

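// For context: the matcher -> action pairs above are entries in a
// transform.foreach_match. A minimal sketch of the enclosing construct
// (entry-point name assumed; the real spec also lists the elided
// attention and convolution matchers):
transform.named_sequence @__kernel_config(%variant: !transform.any_op {transform.consumed}) {
  %res = transform.foreach_match in %variant
      @match_mmt_2048x10240x1280 -> @apply_op_config,
      @match_mmt_2048x1280x5120 -> @apply_op_config,
      @match_mmt_2048x1280x1280 -> @apply_op_config
    : (!transform.any_op) -> !transform.any_op
  transform.yield
}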
75 changes: 32 additions & 43 deletions compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
@@ -12,7 +12,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Utils.h"
@@ -122,45 +121,36 @@ static SmallVector<ReassociationIndices> getCollapsibleLoops(Operation *op) {
(rDimsSet.count(prePos) && rDimsSet.count(nextPos));
};

ReassociationIndices range;
AffineExpr preExpr;
// Find the largest sequence of dimensions that are
// - Either preserved in all maps, or
// - are completely absent
// This sequence can be collapsed. To find the sequence,
// 1) For each indexing map, take the result expressions
// 2) Find a sequence of 2 that is found in all maps (or absent)
// 1) Take the result expressions of one of the indexing maps
// 2) Find a sequence of 2 that is found in all maps
// 3) Then take last element of this sequence and the next
// result expression, and check if this sequence of 2 is
// found in all maps. If so, add to sequence (to get a sequence of 3)
// and repeat till the last element of sequence and the next result
// expression is not found as a sequence in all maps.
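//
// Worked example (illustration, taken from the test case removed below):
// with indexing maps
//   (d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)
//   (d0, d1, d2, d3, d4) -> (d3, d4)
//   (d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)
// the pairs (d0, d1) and (d1, d2) appear contiguously in every map that
// uses them, so [d0, d1, d2] collapses into a single loop. d4 cannot extend
// the sequence: (d2, d4) is not contiguous in the second map, and d4 is a
// reduction loop while d2 is parallel.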

llvm::SmallSetVector<unsigned, 8> seenLoops;
for (auto map : fusionInterfaceOp.getIndexingMapsArray()) {
ReassociationIndices range;
AffineExpr preExpr;

auto appendAndClearRange = [&]() {
if (range.size() > 1) {
contiguousLoops.push_back(range);
}
range.clear();
};

for (auto nextExpr : map.getResults()) {
unsigned position = cast<AffineDimExpr>(nextExpr).getPosition();
if (seenLoops.contains(position)) {
appendAndClearRange();
continue;
}
for (auto nextExpr :
fusionInterfaceOp.getIndexingMapsArray().front().getResults()) {
unsigned position = cast<AffineDimExpr>(nextExpr).getPosition();
if (!range.empty()) {
if (!hasAllMapsSameSequence(preExpr, nextExpr) ||
!hasSameIteratorType(preExpr, nextExpr)) {
appendAndClearRange();
if (range.size() > 1) {
contiguousLoops.push_back({range.begin(), range.end()});
}
range.clear();
}
range.push_back(position);
seenLoops.insert(position);
preExpr = nextExpr;
}
appendAndClearRange();
range.push_back(position);
preExpr = nextExpr;
}
if (range.size() > 1) {
contiguousLoops.push_back(range);
}

return contiguousLoops;
@@ -202,20 +192,21 @@ static bool isEligibleForCollapse(Operation *op) {
}

// TODO(#17948) GPU codegen fails when we collapse the dimensions of softmax.
auto isPossiblySoftmax = [&](OpOperand *operand) -> bool {
auto genericOperand = operand->get().getDefiningOp<linalg::GenericOp>();
if (!genericOperand) {
return false;
}

if (genericOperand.getNumReductionLoops() == 0) {
return false;
}

auto map = genericOp.getMatchingIndexingMap(operand);
return !map.isPermutation() && map.isProjectedPermutation();
};
if (llvm::any_of(genericOp.getDpsInputOperands(), isPossiblySoftmax)) {
if (llvm::any_of(genericOp.getDpsInputOperands(),
[&](OpOperand *operand) -> bool {
auto genericOperand =
operand->get().getDefiningOp<linalg::GenericOp>();
if (!genericOperand) {
return false;
}

if (genericOperand.getNumReductionLoops() == 0) {
return false;
}

return genericOp.getMatchingIndexingMap(operand)
.isProjectedPermutation();
})) {
return false;
}
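// Illustration (hypothetical shapes): in a softmax-like chain the consumer
// generic reads the result of a reduction producer (e.g. a row-wise max)
// through a broadcast map such as (d0, d1) -> (d0), which is a projected
// permutation. Collapsing the dimensions of such dispatches currently
// breaks GPU codegen, hence the bail-out above (see #17948).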

@@ -624,7 +615,6 @@ hoistTensorReshapesOutOfDispatchRegion(
// 1. Get the slice of operations within `dispatchOp` that produce the yielded
// value.
BackwardSliceOptions sliceOptions;
sliceOptions.omitBlockArguments = true;
sliceOptions.filter = [&](Operation *op) {
return op->getParentOfType<IREE::Flow::DispatchRegionOp>();
};
@@ -878,7 +868,6 @@ collapseDimensionsForDispatch(IRRewriter &rewriter,
BackwardSliceOptions sliceOptions;
sliceOptions.inclusive = true;
sliceOptions.omitBlockArguments = true;
sliceOptions.omitUsesFromAbove = false;
sliceOptions.filter = [&](Operation *op) -> bool {
auto parentOp = op->getParentOfType<IREE::Flow::DispatchRegionOp>();
return isEligibleForCollapse(op) && parentOp == regionOp;
@@ -619,40 +619,3 @@ util.func public @collapse_attention_with_truncf(%arg0: tensor<20x4096x16xf32>,
// CHECK: %[[TRUNC:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ATTN]] : tensor<20x4096x64xf32>
// CHECK: flow.return %[[TRUNC]] : tensor<20x4096x64xf16>

// -----

util.func public @collapse(%10: tensor<2x32x32x1280xi8>, %11 : tensor<10240x1280xi8>, %12 : tensor<10240xi32>, %13 : tensor<10240xf32>) -> (tensor<2x32x32x10240xf16>) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %14 = tensor.empty() : tensor<2x32x32x10240xf16>
  %15 = tensor.empty() : tensor<2x32x32x10240xi32>
  %16 = linalg.fill ins(%c0_i32 : i32) outs(%15 : tensor<2x32x32x10240xi32>) -> tensor<2x32x32x10240xi32>
  %dispatch = flow.dispatch.region -> (tensor<2x32x32x10240xf16>) {
    %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%10, %11 : tensor<2x32x32x1280xi8>, tensor<10240x1280xi8>) outs(%16 : tensor<2x32x32x10240xi32>) {
    ^bb0(%in: i8, %in_0: i8, %out: i32):
      %19 = arith.extsi %in : i8 to i32
      %20 = arith.extsi %in_0 : i8 to i32
      %21 = arith.muli %19, %20 : i32
      %22 = arith.addi %out, %21 : i32
      linalg.yield %22 : i32
    } -> tensor<2x32x32x10240xi32>
    %18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%17, %12, %13 : tensor<2x32x32x10240xi32>, tensor<10240xi32>, tensor<10240xf32>) outs(%14 : tensor<2x32x32x10240xf16>) {
    ^bb0(%in: i32, %in_0: i32, %in_1: f32, %out: f16):
      %19 = arith.addi %in, %in_0 : i32
      %20 = arith.sitofp %19 : i32 to f32
      %21 = arith.mulf %20, %in_1 : f32
      %22 = arith.truncf %21 : f32 to f16
      linalg.yield %22 : f16
    } -> tensor<2x32x32x10240xf16>
    flow.return %18 : tensor<2x32x32x10240xf16>
  }
  util.return %dispatch : tensor<2x32x32x10240xf16>
}

// CHECK-LABEL: util.func public @collapse
// CHECK: %[[GEN0:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK: %[[GEN1:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK: flow.return %[[GEN1]] : tensor<2048x10240xf16>
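// Note (illustration): the three leading parallel dims collapse into a
// single dim of size 2*32*32 = 2048, which the CHECK lines above verify via
// the tensor<2048x10240xf16> result type.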
