Skip to content

Commit

Permalink
Slice layout wtf
Browse files Browse the repository at this point in the history
  • Loading branch information
lezcano committed Nov 20, 2024
1 parent 5b7dc8b commit 3841712
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
auto sizePerThread = distributedLayout.getSizePerThread();
auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
// ThreadsPerWarp does not align with this function for slice layout
if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
}
auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
assert(sizePerThread.size() == threadsPerWarp.size() &&
sizePerThread.size() == warpsPerCTA.size());
Expand Down

0 comments on commit 3841712

Please sign in to comment.