Skip to content

Commit

Permalink
Update xetile block op fallback pass (#1012)
Browse files Browse the repository at this point in the history
xetile block op fallback pass:
skip pass if pitch is not a multiple of tile width since mask is likely
required. current pass does not create correct mask.
  • Loading branch information
silee2 authored Jan 28, 2025
1 parent 7e5ec06 commit 8f74aff
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 10 deletions.
4 changes: 3 additions & 1 deletion include/imex/Dialect/XeTile/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ std::unique_ptr<mlir::Pass>
createXeTileBlockingPass(const std::string &device = "pvc");
std::unique_ptr<mlir::Pass> createXeTileWgToSgPass();
std::unique_ptr<mlir::Pass> createXeTileCanonicalizationPass();
std::unique_ptr<mlir::Pass> createXeTileBlockOpFallbackPass();
std::unique_ptr<mlir::Pass>
createXeTileBlockOpFallbackPass(const std::string &device = "pvc");

#define GEN_PASS_DECL_XETILEBLOCKING
#define GEN_PASS_DECL_XETILECANONICALIZATION
#define GEN_PASS_DECL_XETILEINITDUPLICATE
#define GEN_PASS_DECL_XETILEWGTOSG
#define GEN_PASS_DECL_XETILEBLOCKOPFALLBACK
#include <imex/Dialect/XeTile/Transforms/Passes.h.inc>

//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions include/imex/Dialect/XeTile/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def XeTileBlockOpFallback : Pass<"xetile-blockop-fallback", "::mlir::gpu::GPUMod
"mlir::index::IndexDialect",
"mlir::memref::MemRefDialect",
"mlir::vector::VectorDialect"];
let options = [
Option<"device", "device", "std::string",
/*default=*/"\"pvc\"",
"gpu platform architecture where these ops are running">
];
}

#endif // _XeTile_PASSES_TD_INCLUDED_
3 changes: 3 additions & 0 deletions include/imex/Utils/XeArch.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ struct LoadStore2DConfig {
llvm::SmallVector<int> array_length; // # of blocks to read/write memory
int restriction; // Max Width in bytes
GRFSize GRFDataSize; // Max GRF Data for load and store
int minPitch; // Min pitch in bytes
int pitchMultiple; // Pitch must be multiple in bytes of
// this value
};

/// This Base class provides uArch interface for defining HW supported configs
Expand Down
64 changes: 57 additions & 7 deletions lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "imex/Dialect/XeTile/IR/XeTileOps.h"
#include "imex/Dialect/XeTile/Transforms/Passes.h"
#include "imex/Utils/XeArch.h"
#include "imex/Utils/XeCommon.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
Expand Down Expand Up @@ -80,8 +81,11 @@ static imex::xetile::TileType addScatterAttr(imex::xetile::TileType tileTy) {

struct InitTileOpPattern final
: public mlir::OpRewritePattern<imex::xetile::InitTileOp> {
InitTileOpPattern(mlir::MLIRContext *context)
: OpRewritePattern<imex::xetile::InitTileOp>(context) {}
InitTileOpPattern(mlir::MLIRContext *context,
std::shared_ptr<imex::XeuArchInterface> uArch)
: OpRewritePattern<imex::xetile::InitTileOp>(context) {
uArchInterface = uArch;
}
mlir::LogicalResult
matchAndRewrite(imex::xetile::InitTileOp initTileOp,
mlir::PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -121,11 +125,19 @@ struct InitTileOpPattern final
auto elemBitwidth =
initTileOp.getSourceMemrefElemType().getIntOrFloatBitWidth();
auto pitchNumBytes = pitchNumElems * elemBitwidth / 8;
isValidPitch = pitchNumBytes >= 64 && (pitchNumBytes % 16 == 0);
auto config = uArchInterface->get2DPrefetchConfig(initTileOp.getOperation(),
elemBitwidth);
auto conf = config.value();
isValidPitch = (pitchNumBytes >= conf.minPitch) &&
(pitchNumBytes % conf.pitchMultiple == 0);
// If memspace is not SLM and pitch is valid, no need to rewrite
if (!isSLM && isValidPitch) {
return mlir::failure();
}
bool mayNeedMask = (pitchNumElems % tileTy.getShape().back() != 0);
if (mayNeedMask) {
return mlir::failure();
}
// Get flat shape size
int64_t flatSize = 1;
for (auto dim : srcShape) {
Expand Down Expand Up @@ -229,6 +241,9 @@ struct InitTileOpPattern final

return mlir::success();
}

private:
std::shared_ptr<imex::XeuArchInterface> uArchInterface = nullptr;
};

struct LoadTileOpPattern final
Expand Down Expand Up @@ -414,30 +429,65 @@ struct SCFForOpPattern final : public mlir::OpRewritePattern<mlir::scf::ForOp> {
}
};

struct XeTileBlockOpFallbackPass final
class XeTileBlockOpFallbackPass final
: public imex::impl::XeTileBlockOpFallbackBase<XeTileBlockOpFallbackPass> {
public:
XeTileBlockOpFallbackPass() {
uArchInterface = std::make_shared<imex::XePVCuArch>();
}

XeTileBlockOpFallbackPass(const std::string &deviceName) {
if (deviceName == "pvc") {
uArchInterface = std::make_shared<imex::XePVCuArch>();
}
}

mlir::LogicalResult
initializeOptions(mlir::StringRef options,
mlir::function_ref<mlir::LogicalResult(const llvm::Twine &)>
errorHandler) override {
if (failed(Pass::initializeOptions(options, errorHandler)))
return mlir::failure();
if (device == "pvc")
uArchInterface = std::make_shared<imex::XePVCuArch>();
else
return errorHandler(llvm::Twine("Invalid device: ") + device);
return mlir::success();
}

void runOnOperation() override {
auto *context = &getContext();
mlir::Operation *op = getOperation();

if (!uArchInterface) {
op->emitOpError("Can not get GPU Arch Definition for given Arch param");
return signalPassFailure();
}

mlir::RewritePatternSet patterns(context);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.useTopDownTraversal = true;
config.strictMode = mlir::GreedyRewriteStrictness::ExistingAndNewOps;
patterns.add<InitTileOpPattern, LoadTileOpPattern, StoreTileOpPattern,
patterns.add<InitTileOpPattern>(context, uArchInterface);
patterns.add<LoadTileOpPattern, StoreTileOpPattern,
UpdateTileOffsetOpPattern, SCFForOpPattern>(context);
if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
return signalPassFailure();
}
}

private:
std::shared_ptr<imex::XeuArchInterface> uArchInterface = nullptr;
};

} // namespace blockopfallback

namespace imex {
std::unique_ptr<mlir::Pass> createXeTileBlockOpFallbackPass() {
return std::make_unique<blockopfallback::XeTileBlockOpFallbackPass>();
std::unique_ptr<mlir::Pass>
createXeTileBlockOpFallbackPass(const std::string &deviceName) {
return std::make_unique<blockopfallback::XeTileBlockOpFallbackPass>(
deviceName);
}
} // namespace imex
5 changes: 4 additions & 1 deletion lib/Utils/XeArch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ XePVCuArch::get2DLoadConfig(mlir::Operation *op, int element_data_size,
break;
}
loadParams.GRFDataSize.load = 2048;
loadParams.minPitch = 64;
loadParams.pitchMultiple = 16;
return loadParams;
}

Expand Down Expand Up @@ -188,7 +190,8 @@ XePVCuArch::get2DStoreConfig(int element_data_size) {
}

storeParams.GRFDataSize.store = 512;

storeParams.minPitch = 64;
storeParams.pitchMultiple = 16;
return storeParams;
}

Expand Down
15 changes: 14 additions & 1 deletion test/Dialect/XeTile/Transforms/block_op_fallback.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
// RUN: imex-opt --split-input-file --xetile-blockop-fallback %s -verify-diagnostics -o -| FileCheck %s
// RUN: imex-opt --split-input-file --xetile-blockop-fallback=device=pvc %s -verify-diagnostics -o -| FileCheck %s

gpu.module @test_module {
// CHECK-LABEL: @test_pitch_not_multiple_of_tile_width
gpu.func @test_pitch_not_multiple_of_tile_width(%arg0: memref<512x250xf32>) {
// CHECK: %[[VAR0:.*]] = xetile.init_tile %arg0[0, 0] : memref<512x250xf32> -> !xetile.tile<32x16xf32
%0 = xetile.init_tile %arg0 [0, 0] : memref<512x250xf32> -> !xetile.tile<32x16xf32, #xetile.tile_attr<order = [1, 0]>>
// CHECK: %[[VAR1:.*]] = xetile.load_tile %[[VAR0]]
%1 = xetile.load_tile %0 : !xetile.tile<32x16xf32, #xetile.tile_attr<order = [1, 0]>> -> vector<32x16xf32>
gpu.return
}
}

// -----

gpu.module @test_module {
// CHECK-LABEL: @test_pitch_one_elems_and_offset_attr
Expand Down

0 comments on commit 8f74aff

Please sign in to comment.