Commit
Allow Layouts to propagate to local_load
mbrookhart committed Nov 21, 2024
1 parent e9db186 commit 1e98de3
Showing 4 changed files with 18 additions and 8 deletions.
4 changes: 0 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -22,10 +22,6 @@ void lowerDistributedToShared(
     std::pair<size_t, Type> *const llvmOpCount = nullptr) {
   auto srcTy = cast<RankedTensorType>(src.getType());
   auto dstTy = cast<MemDescType>(dst.getType());
-  auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
-  assert(srcTy.getShape().size() <= 2 ||
-         (srcTy.getShape().size() == 3 && outOrd[2] == 0) &&
-             "Unexpected rank of ConvertLayout(blocked->shared)");
   auto elemTy = typeConverter->convertType(srcTy.getElementType());
 
   auto smemBase = smemObj.getBase();
3 changes: 2 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -563,7 +563,8 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
   }
   return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
              triton::MakeRangeOp, triton::SplatOp, triton::HistogramOp,
-             triton::gpu::LocalAllocOp, triton::gpu::LocalStoreOp>(op);
+             triton::gpu::LocalAllocOp, triton::gpu::LocalLoadOp,
+             triton::gpu::LocalStoreOp>(op);
 }
 
 scf::ForOp replaceForOpWithNewSignature(
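As background for this hunk, here is a minimal sketch of how a layout-propagation rewrite might consult canFoldIntoConversion. The helper name is hypothetical and the snippet assumes the usual includes and using-declarations of Utility.cpp, so it illustrates the mechanism rather than the actual pass code: with LocalLoadOp on the allow-list, a convert_layout whose operand comes from a local_load can be eliminated by re-emitting the load directly in the target encoding.

// Illustrative sketch only: producerCanAbsorbTargetLayout is a hypothetical helper,
// not a function added by this commit.
static bool producerCanAbsorbTargetLayout(triton::gpu::ConvertLayoutOp cvt) {
  // The op that produced the conversion's operand; after this change it may be
  // a triton_gpu.local_load.
  Operation *producer = cvt.getSrc().getDefiningOp();
  if (!producer)
    return false;
  // The encoding the conversion wants. If the producer can be rebuilt with it,
  // the convert_layout itself becomes redundant.
  auto resultTy = cast<RankedTensorType>(cvt->getResult(0).getType());
  return canFoldIntoConversion(producer, resultTy.getEncoding());
}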
16 changes: 16 additions & 0 deletions test/TritonGPU/combine.mlir
@@ -2685,3 +2685,19 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 2, 1], order = [4, 0, 1, 2, 3]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 2, 2, 1, 1], order = [4, 0, 3, 2, 1]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 0, 1, 2, 3], hasLeadingOffset = false}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:100", "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-NOT: convert_layout
+  tt.func public @lift_convert_to_local_load(%arg0 : !tt.memdesc<2x1x32x4x4xi8, #shared, #triton_gpu.shared_memory, mutable>) -> tensor<2x4x32x1x4xi8, #blocked2> {
+    %1 = triton_gpu.local_load %arg0 : !tt.memdesc<2x1x32x4x4xi8, #shared, #triton_gpu.shared_memory, mutable> -> tensor<2x1x32x4x4xi8, #blocked>
+    %2 = tt.trans %1 {order = array<i32: 0, 3, 2, 1, 4>} : tensor<2x1x32x4x4xi8, #blocked> -> tensor<2x4x32x1x4xi8, #blocked1>
+    %3 = triton_gpu.convert_layout %2 : tensor<2x4x32x1x4xi8, #blocked1> -> tensor<2x4x32x1x4xi8, #blocked2>
+    tt.return %3 : tensor<2x4x32x1x4xi8, #blocked2>
+  }
+}
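The effect on IR like the test above can also be sketched on a smaller 2-D case. Everything below is a hypothetical illustration (the layouts, shapes, and function name are invented; the module attributes mirror the test) and is not part of the commit: once local_load is foldable, remove-layout-conversions can re-emit the load directly in #blockedB and drop the convert_layout.

// Hypothetical example, not from this commit: a local_load immediately followed by a
// layout conversion, which the pass can now fold into the load itself.
#blockedA = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blockedB = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
#sharedX = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:100", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @local_load_then_convert(%mem : !tt.memdesc<32x32xf16, #sharedX, #triton_gpu.shared_memory, mutable>) -> tensor<32x32xf16, #blockedB> {
    %v = triton_gpu.local_load %mem : !tt.memdesc<32x32xf16, #sharedX, #triton_gpu.shared_memory, mutable> -> tensor<32x32xf16, #blockedA>
    // With this change the pass can rewrite %v to load in #blockedB directly,
    // making this convert_layout dead.
    %r = triton_gpu.convert_layout %v : tensor<32x32xf16, #blockedA> -> tensor<32x32xf16, #blockedB>
    tt.return %r : tensor<32x32xf16, #blockedB>
  }
}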
3 changes: 0 additions & 3 deletions
@@ -913,9 +913,6 @@ struct AsyncCopyGlobalToLocalOpConversion
     assert((isa<BlockedEncodingAttr, SliceEncodingAttr>(srcLayout) &&
             "Unexpected srcLayout in AsyncCopyGlobalToLocalOpConversion"));
     auto resSharedLayout = cast<SharedEncodingAttr>(dstTy.getEncoding());
-    auto srcShape = srcTy.getShape();
-    assert((srcShape.size() <= 3) &&
-           "insert_slice_async: Unexpected rank of %src");
 
     Value llDst = adaptor.getResult();
     Value llSrc = adaptor.getSrc();
