Skip to content

Commit

Permalink
Attach Fusion interface to linalg.softmax (#18550)
Browse files Browse the repository at this point in the history
  • Loading branch information
IanWood1 authored Sep 25, 2024
1 parent 6634f0f commit 672ae82
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 3 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pkgci_regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ jobs:
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1551 \
--goldendispatch-rocm-clip 1225 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
Expand All @@ -242,7 +242,7 @@ jobs:
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1551 \
--goldendispatch-rocm-clip 1225 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
Expand Down Expand Up @@ -52,7 +54,6 @@ struct IREELinalgExtInlinerInterface : public DialectInlinerInterface {
};

// Used to register the LinalgFusionOpInterface with the linalg ops.
namespace {
template <typename ConcreteType>
struct LinalgFusionOpInterfaceAdapter
: public LinalgFusionOpInterface::ExternalModel<
Expand Down Expand Up @@ -103,6 +104,48 @@ struct LinalgFusionOpInterfaceAdapter
return inputMaps;
}
};

namespace {
// External model attaching LinalgFusionOpInterface to linalg.softmax so it
// can participate in dispatch-region fusion. Softmax reads and writes every
// element of its equally-shaped input and init tensors, so every operand and
// result gets a full-rank identity indexing map.
struct SoftmaxFusionOpInterfaceAdapter
    : public LinalgFusionOpInterface::ExternalModel<
          SoftmaxFusionOpInterfaceAdapter, linalg::SoftmaxOp> {
public:
  // Returns one identity indexing map per value in `values`, each of the
  // value's tensor rank. Shared by the operand- and result-map queries.
  static SmallVector<AffineMap> getIdentityMaps(mlir::Operation *op,
                                                ValueRange values) {
    Builder b(op->getContext());
    return llvm::to_vector(
        llvm::map_range(values, [&b](Value operand) -> AffineMap {
          auto rank = cast<ShapedType>(operand.getType()).getRank();
          return b.getMultiDimIdentityMap(rank);
        }));
  }

  // Indexing maps for the DPS input operands (one identity map per input).
  SmallVector<AffineMap> getIndexingMapsForOperands(mlir::Operation *op) const {
    return getIdentityMaps(op, llvm::cast<linalg::SoftmaxOp>(op).getDpsInputs());
  }

  // Indexing maps for the DPS init operands / results (one identity map each).
  SmallVector<AffineMap> getIndexingMapsForResults(mlir::Operation *op) const {
    return getIdentityMaps(op, llvm::cast<linalg::SoftmaxOp>(op).getDpsInits());
  }

  AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
                                         OpResult result) const {
    return getIndexingMapsForResults(op)[result.getResultNumber()];
  }

  AffineMap getMatchingIndexingMap(mlir::Operation *op,
                                   OpOperand *operand) const {
    // Operand numbering covers inputs followed by inits. Index into the
    // concatenated map list so the init operand is handled too: the previous
    // version indexed the inputs-only vector, which is out of bounds for
    // softmax's init operand (#1). Input operands still get the same map.
    return getIndexingMapsArray(op)[operand->getOperandNumber()];
  }

  // All indexing maps, operands first, then results — the layout
  // getMatchingIndexingMap relies on.
  SmallVector<AffineMap> getIndexingMapsArray(mlir::Operation *op) const {
    auto maps = getIndexingMapsForOperands(op);
    llvm::append_range(maps, getIndexingMapsForResults(op));
    return maps;
  }
};
} // namespace

template <typename... Args>
Expand All @@ -125,6 +168,8 @@ void IREELinalgExtDialect::initialize() {
registerOpsWithLinalgExtOpInterface<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(context);
linalg::SoftmaxOp::attachInterface<SoftmaxFusionOpInterfaceAdapter>(*context);

addInterfaces<IREELinalgExtInlinerInterface>();

addAttributes<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,26 @@ util.func @mixed_conv(%arg0 : tensor<2x130x130x16xf16>, %arg1 : tensor<3x3x16x32
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: flow.dispatch.tensor.store %[[GENERIC]]
// CHECK: util.return %[[DISPATCH1]]

// Regression test for attaching the fusion interface to linalg.softmax:
// the softmax and its elementwise (truncf) consumer should land in a single
// flow.dispatch.workgroups region instead of two separate dispatches.
util.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf16> {
  %empty0 = tensor.empty() : tensor<2x16x32xf32>
  %empty1 = tensor.empty() : tensor<2x16x32xf16>
  // Softmax over the innermost dimension; same shape in and out.
  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%empty0 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  // Elementwise f32 -> f16 truncation consuming the softmax result; this is
  // the producer/consumer pair the pass is expected to fuse.
  %2 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%1 : tensor<2x16x32xf32>) outs(%empty1 : tensor<2x16x32xf16>){
  ^bb0(%in : f32, %out : f16):
    %3 = arith.truncf %in : f32 to f16
    linalg.yield %3 : f16
  } -> tensor<2x16x32xf16>
  util.return %2 : tensor<2x16x32xf16>
}

// Both ops must appear inside ONE dispatch, with the generic consuming the
// softmax value directly, and only the generic's result stored out.
// CHECK-LABEL: util.func public @softmax
//       CHECK:   %[[DISPATCH1:.+]] = flow.dispatch.workgroups
//       CHECK:     %[[SOFTMAX:.+]] = linalg.softmax
//       CHECK:     %[[GENERIC:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[SOFTMAX]]
//       CHECK:     flow.dispatch.tensor.store %[[GENERIC]]
//       CHECK:   util.return %[[DISPATCH1]]

0 comments on commit 672ae82

Please sign in to comment.