[BACKEND] Switch back to use llvm.load for shared memory load (triton-lang#4776)

When we don't have predicates, we can use llvm.load. Using inline asm for
i8 types can cause inefficient code generation in LLVM due to the
interaction with the DAG legalizer.
ThomasRaoux authored and bertmaher committed Dec 6, 2024
1 parent e369046 commit 7cd75ec
Showing 2 changed files with 46 additions and 59 deletions.
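To illustrate the effect, here is a schematic before/after of the lowered IR (a minimal sketch; the SSA names, offset, and constraint string are illustrative, not taken from this diff):

  // Before: every shared-memory load went through predicated PTX inline asm,
  // even when the predicate was a compile-time constant true.
  %v0 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = []
        "@$1 ld.shared.b32 $0, [ $2 + 0 ];", "=r,b,r" %pred, %ptr : (i1, !llvm.ptr<3>) -> i32
  // After: with no (or a constant-true) predicate, a plain load is emitted,
  // which LLVM can analyze and optimize normally.
  %v1 = llvm.load %ptr : !llvm.ptr<3> -> i32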
62 changes: 13 additions & 49 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -709,39 +709,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
   // CHECK-LABEL: convert_layout_blocked_blocked
   tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
     // CHECK: llvm.mlir.addressof @global_smem
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared
+    // CHECK: nvvm.barrier0
+    // CHECK-COUNT-8: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
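Each of the eight llvm.load matches above corresponds to a scalar 32-bit load of the form (names illustrative):

  %v = llvm.load %ptr : !llvm.ptr<3> -> i32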
@@ -761,10 +731,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -782,18 +750,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     // CHECK: nvvm.barrier0
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -851,7 +815,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // CHECK: llvm.inline_asm
     // CHECK-SAME: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: ld.shared
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0>
     tt.return
   }
@@ -891,7 +855,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
   tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) {
     // CHECK-COUNT-128: st.shared.b8
     // CHECK: nvvm.barrier0
-    // CHECK-COUNT-8: ld.shared.v4.b32
+    // CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32>
     %0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked>
     tt.return
   }
@@ -920,7 +884,7 @@
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_blocked1d_to_slice0
   tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
-    // CHECK: inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} ld.shared.v4.b32
+    // CHECK: llvm.load {{.*}} -> vector<4xi32>
     %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
     tt.return
   }
@@ -933,7 +897,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_blocked1d_to_slice1
   tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
-    // CHECK-COUNT-8: inline_asm{{.*}}ld.shared.b32
+    // CHECK-COUNT-8: llvm.load {{.*}} -> i32
     %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
     tt.return
   }
43 changes: 33 additions & 10 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -290,6 +290,13 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) {
   }
 }
 
+static bool isConstantTruePred(Value pred) {
+  if (auto constOp = pred.getDefiningOp<LLVM::ConstantOp>()) {
+    return cast<IntegerAttr>(constOp.getValue()).getInt() != 0;
+  }
+  return false;
+}
+
 void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
                               std::optional<Value> ctaId, Value val,
                               Value pred) const {
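The new isConstantTruePred helper returns true only when the predicate was materialized as a nonzero LLVM-dialect integer constant, for example (illustrative):

  %true = llvm.mlir.constant(true) : i1

Any other defining op conservatively keeps the predicated inline-asm path in loadDShared below.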
@@ -501,16 +508,32 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
                 .v(vec, /*predicate=*/vec > 1)
                 .b(elemBitwidth);
 
-  std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth);
-  auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
-                          : builder.newListOperand(vec, elemConstraint);
-  ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");
-
-  Type resultTy =
-      vec == 1 ? Type(int_ty(elemBitwidth))
-               : Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
-  Value load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
-
+  Value load;
+  if (isConstantTruePred(pred)) {
+    Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth))
+                             : Type(vec_ty(int_ty(elemBitwidth), vec));
+    load = load(resultTy, ptr);
+    if (vec > 1) {
+      Type structTy = struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth)));
+      Value structValue = undef(structTy);
+      for (int i = 0; i < vec; i++) {
+        structValue = insert_val(structTy, structValue,
+                                 extract_element(load, i32_val(i)), i);
+      }
+      load = structValue;
+    }
+  } else {
+    std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth);
+    auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
+                            : builder.newListOperand(vec, elemConstraint);
+    ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");
+
+    Type resultTy =
+        vec == 1
+            ? Type(int_ty(elemBitwidth))
+            : Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
+    load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
+  }
   SmallVector<Value> resultVals = unpackLLElements(loc, load, rewriter);
   return packLLVector(loc, resultVals, rewriter);
 }