From 8f74affdbfa2e6dcf3260e46740da5eacf204357 Mon Sep 17 00:00:00 2001 From: Sang Ik Lee Date: Tue, 28 Jan 2025 11:51:05 -0800 Subject: [PATCH] Update xetile block op fallback pass (#1012) xetile block op fallback pass: skip the fallback rewrite if the pitch is not a multiple of the tile width, since a mask is likely required and the current pass does not create a correct mask. --- .../imex/Dialect/XeTile/Transforms/Passes.h | 4 +- .../imex/Dialect/XeTile/Transforms/Passes.td | 5 ++ include/imex/Utils/XeArch.h | 3 + .../XeTile/Transforms/BlockOpFallback.cpp | 64 +++++++++++++++++-- lib/Utils/XeArch.cpp | 5 +- .../XeTile/Transforms/block_op_fallback.mlir | 15 ++++- 6 files changed, 86 insertions(+), 10 deletions(-) diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.h b/include/imex/Dialect/XeTile/Transforms/Passes.h index c24b30184..c97266963 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.h +++ b/include/imex/Dialect/XeTile/Transforms/Passes.h @@ -40,12 +40,14 @@ std::unique_ptr createXeTileBlockingPass(const std::string &device = "pvc"); std::unique_ptr createXeTileWgToSgPass(); std::unique_ptr createXeTileCanonicalizationPass(); -std::unique_ptr createXeTileBlockOpFallbackPass(); +std::unique_ptr +createXeTileBlockOpFallbackPass(const std::string &device = "pvc"); #define GEN_PASS_DECL_XETILEBLOCKING #define GEN_PASS_DECL_XETILECANONICALIZATION #define GEN_PASS_DECL_XETILEINITDUPLICATE #define GEN_PASS_DECL_XETILEWGTOSG +#define GEN_PASS_DECL_XETILEBLOCKOPFALLBACK #include //===----------------------------------------------------------------------===// diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.td b/include/imex/Dialect/XeTile/Transforms/Passes.td index 6381c8887..c1128d4eb 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.td +++ b/include/imex/Dialect/XeTile/Transforms/Passes.td @@ -111,6 +111,11 @@ def XeTileBlockOpFallback : Pass<"xetile-blockop-fallback", "::mlir::gpu::GPUMod "mlir::index::IndexDialect", "mlir::memref::MemRefDialect", 
"mlir::vector::VectorDialect"]; + let options = [ + Option<"device", "device", "std::string", + /*default=*/"\"pvc\"", + "gpu platform architecture where these ops are running"> + ]; } #endif // _XeTile_PASSES_TD_INCLUDED_ diff --git a/include/imex/Utils/XeArch.h b/include/imex/Utils/XeArch.h index 5d1f79517..a0862899a 100644 --- a/include/imex/Utils/XeArch.h +++ b/include/imex/Utils/XeArch.h @@ -55,6 +55,9 @@ struct LoadStore2DConfig { llvm::SmallVector array_length; // # of blocks to read/write memory int restriction; // Max Width in bytes GRFSize GRFDataSize; // Max GRF Data for load and store + int minPitch; // Min pitch in bytes + int pitchMultiple; // Pitch must be multiple in bytes of + // this value }; /// This Base class provides uArch interface for defining HW supported configs diff --git a/lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp b/lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp index 1cdee2266..e298642c7 100644 --- a/lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp +++ b/lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp @@ -19,6 +19,7 @@ #include "imex/Dialect/XeTile/IR/XeTileOps.h" #include "imex/Dialect/XeTile/Transforms/Passes.h" +#include "imex/Utils/XeArch.h" #include "imex/Utils/XeCommon.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" @@ -80,8 +81,11 @@ static imex::xetile::TileType addScatterAttr(imex::xetile::TileType tileTy) { struct InitTileOpPattern final : public mlir::OpRewritePattern { - InitTileOpPattern(mlir::MLIRContext *context) - : OpRewritePattern(context) {} + InitTileOpPattern(mlir::MLIRContext *context, + std::shared_ptr uArch) + : OpRewritePattern(context) { + uArchInterface = uArch; + } mlir::LogicalResult matchAndRewrite(imex::xetile::InitTileOp initTileOp, mlir::PatternRewriter &rewriter) const override { @@ -121,11 +125,19 @@ struct InitTileOpPattern final auto elemBitwidth = initTileOp.getSourceMemrefElemType().getIntOrFloatBitWidth(); auto pitchNumBytes = pitchNumElems 
* elemBitwidth / 8; - isValidPitch = pitchNumBytes >= 64 && (pitchNumBytes % 16 == 0); + auto config = uArchInterface->get2DPrefetchConfig(initTileOp.getOperation(), + elemBitwidth); + auto conf = config.value(); + isValidPitch = (pitchNumBytes >= conf.minPitch) && + (pitchNumBytes % conf.pitchMultiple == 0); // If memspace is not SLM and pitch is valid, no need to rewrite if (!isSLM && isValidPitch) { return mlir::failure(); } + bool mayNeedMask = (pitchNumElems % tileTy.getShape().back() != 0); + if (mayNeedMask) { + return mlir::failure(); + } // Get flat shape size int64_t flatSize = 1; for (auto dim : srcShape) { @@ -229,6 +241,9 @@ struct InitTileOpPattern final return mlir::success(); } + +private: + std::shared_ptr uArchInterface = nullptr; }; struct LoadTileOpPattern final @@ -414,30 +429,65 @@ struct SCFForOpPattern final : public mlir::OpRewritePattern { } }; -struct XeTileBlockOpFallbackPass final +class XeTileBlockOpFallbackPass final : public imex::impl::XeTileBlockOpFallbackBase { +public: + XeTileBlockOpFallbackPass() { + uArchInterface = std::make_shared(); + } + + XeTileBlockOpFallbackPass(const std::string &deviceName) { + if (deviceName == "pvc") { + uArchInterface = std::make_shared(); + } + } + + mlir::LogicalResult + initializeOptions(mlir::StringRef options, + mlir::function_ref + errorHandler) override { + if (failed(Pass::initializeOptions(options, errorHandler))) + return mlir::failure(); + if (device == "pvc") + uArchInterface = std::make_shared(); + else + return errorHandler(llvm::Twine("Invalid device: ") + device); + return mlir::success(); + } + void runOnOperation() override { auto *context = &getContext(); mlir::Operation *op = getOperation(); + if (!uArchInterface) { + op->emitOpError("Can not get GPU Arch Definition for given Arch param"); + return signalPassFailure(); + } + mlir::RewritePatternSet patterns(context); mlir::GreedyRewriteConfig config; config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; 
config.useTopDownTraversal = true; config.strictMode = mlir::GreedyRewriteStrictness::ExistingAndNewOps; - patterns.add(context, uArchInterface); + patterns.add(context); if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { return signalPassFailure(); } } + +private: + std::shared_ptr uArchInterface = nullptr; }; } // namespace blockopfallback namespace imex { -std::unique_ptr createXeTileBlockOpFallbackPass() { - return std::make_unique(); +std::unique_ptr +createXeTileBlockOpFallbackPass(const std::string &deviceName) { + return std::make_unique( + deviceName); } } // namespace imex diff --git a/lib/Utils/XeArch.cpp b/lib/Utils/XeArch.cpp index 649faaa3a..d013f20dc 100644 --- a/lib/Utils/XeArch.cpp +++ b/lib/Utils/XeArch.cpp @@ -155,6 +155,8 @@ XePVCuArch::get2DLoadConfig(mlir::Operation *op, int element_data_size, break; } loadParams.GRFDataSize.load = 2048; + loadParams.minPitch = 64; + loadParams.pitchMultiple = 16; return loadParams; } @@ -188,7 +190,8 @@ XePVCuArch::get2DStoreConfig(int element_data_size) { } storeParams.GRFDataSize.store = 512; - + storeParams.minPitch = 64; + storeParams.pitchMultiple = 16; return storeParams; } diff --git a/test/Dialect/XeTile/Transforms/block_op_fallback.mlir b/test/Dialect/XeTile/Transforms/block_op_fallback.mlir index bcf4d76a2..021d2ca4b 100644 --- a/test/Dialect/XeTile/Transforms/block_op_fallback.mlir +++ b/test/Dialect/XeTile/Transforms/block_op_fallback.mlir @@ -1,4 +1,17 @@ -// RUN: imex-opt --split-input-file --xetile-blockop-fallback %s -verify-diagnostics -o -| FileCheck %s +// RUN: imex-opt --split-input-file --xetile-blockop-fallback=device=pvc %s -verify-diagnostics -o -| FileCheck %s + +gpu.module @test_module { + // CHECK-LABEL: @test_pitch_not_multiple_of_tile_width + gpu.func @test_pitch_not_multiple_of_tile_width(%arg0: memref<512x250xf32>) { + // CHECK: %[[VAR0:.*]] = xetile.init_tile %arg0[0, 0] : memref<512x250xf32> -> !xetile.tile<32x16xf32 + %0 = xetile.init_tile %arg0 [0, 0] : 
memref<512x250xf32> -> !xetile.tile<32x16xf32, #xetile.tile_attr> + // CHECK: %[[VAR1:.*]] = xetile.load_tile %[[VAR0]] + %1 = xetile.load_tile %0 : !xetile.tile<32x16xf32, #xetile.tile_attr> -> vector<32x16xf32> + gpu.return + } +} + +// ----- gpu.module @test_module { // CHECK-LABEL: @test_pitch_one_elems_and_offset_attr