From 3841712e9e7091092393f4427114a840d2fafcc7 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 20 Nov 2024 21:15:13 +0000 Subject: [PATCH] Slice layout wtf --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 65313f577428..008d6872f7ee 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -206,6 +206,11 @@ SmallVector getShapePerCTATile(Attribute layout) { mlir::dyn_cast(layout)) { auto sizePerThread = distributedLayout.getSizePerThread(); auto threadsPerWarp = distributedLayout.getThreadsPerWarp(); + // ThreadsPerWarp does not align with this function for slice layout + if (auto sliceLayout = mlir::dyn_cast(layout)) { + threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent()); + threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); + } auto warpsPerCTA = distributedLayout.getWarpsPerCTA(); assert(sizePerThread.size() == threadsPerWarp.size() && sizePerThread.size() == warpsPerCTA.size());