[AMD] Add alignment information to maskedLoad/maskedStore (#4816)
I think we should always set the right alignment on the `maskedload`/`maskedstore` instructions.
giuseros authored Sep 30, 2024
1 parent 80947a2 commit 1df64d1
Showing 3 changed files with 24 additions and 15 deletions.
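
Note on the arithmetic this commit introduces: the alignment coming out of axis analysis is in units of elements, so both conversion patterns multiply it by the element byte width before passing it on (see `ptrAlignmentBytes` in the hunks below). A small worked example with hypothetical numbers, assuming `getPtrAlignment` returns an element-granular alignment:

// Illustrative values only, not from this commit: suppose axis analysis
// reports the pointer as 4-element aligned and the element type is f16.
unsigned elemAlignment = 4;                         // getPtrAlignment(ptr)
const int valueElemNBits = std::max(8u, 16u);       // sub-byte types clamp to 8
const size_t valueElemNBytes = valueElemNBits / 8;  // 2 bytes per f16
int64_t ptrAlignmentBytes = elemAlignment * valueElemNBytes; // = 8 bytes
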
third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp (13 additions & 6 deletions)
@@ -119,6 +119,10 @@ struct LoadStoreConversionBase {
     return axisAnalysisPass.getMaskAlignment(mask);
   }
 
+  unsigned getPtrAlignment(Value ptr) const {
+    return axisAnalysisPass.getPtrAlignment(ptr);
+  }
+
 protected:
   const AMD::TargetInfo &targetInfo;
   ModuleAxisInfoAnalysis &axisAnalysisPass;
@@ -193,7 +197,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
     // vectorized iteration through all the pointer/mask/other elements
     const int valueElemNBits =
         std::max(8u, valueElemTy.getIntOrFloatBitWidth());
+    const size_t valueElemNBytes = valueElemNBits / 8;
     const int numVecs = numElems / vec;
+    int64_t ptrAlignmentBytes = getPtrAlignment(ptr) * valueElemNBytes;
 
     auto cacheMod = op.getCache();
     SmallVector<Value> loadedVals;
@@ -230,8 +236,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
           falseVal = v;
         }
 
-        auto loadVal =
-            llLoad(rewriter, loc, ptr, vecTy, pred, falseVal, cacheMod);
+        Value loadVal = llLoad(rewriter, loc, ptr, vecTy, pred, falseVal,
+                               ptrAlignmentBytes, cacheMod);
         for (size_t ii = 0; ii < vec; ++ii) {
           Value vecIdx = createIndexAttrConstant(
               rewriter, loc, this->getTypeConverter()->getIndexType(), ii % vec);
@@ -294,9 +300,10 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
       vec = std::min(vec, maskAlign);
     }
 
-    const size_t dtsize =
-        std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
-    const size_t valueElemNBits = dtsize * 8;
+    const size_t valueElemNBits =
+        std::max<int>(8, valueElemTy.getIntOrFloatBitWidth());
+    const size_t valueElemNBytes = valueElemNBits / 8;
+    int64_t ptrAlignmentBytes = getPtrAlignment(ptr) * valueElemNBytes;
 
     auto cacheMod = op.getCache();
     const int numVecs = elemsPerThread / vec;
@@ -328,7 +335,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
             rewriter, loc, this->getTypeConverter()->getIndexType(), s);
         storeVal = insert_element(vecTy, storeVal, otherElem, indexVal);
       }
-      llStore(rewriter, loc, ptr, storeVal, pred, cacheMod);
+      llStore(rewriter, loc, ptr, storeVal, pred, ptrAlignmentBytes, cacheMod);
     } // end vec
     rewriter.eraseOp(op);
     return success();
third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp (9 additions & 7 deletions)
@@ -189,12 +189,14 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
 }
 
 Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
-             Value pred, Value falseVal, triton::CacheModifier cm) {
+             Value pred, Value falseVal, int64_t alignmentBytes,
+             triton::CacheModifier cm) {
 
   // Try to emit llvm.intr.masked.load if we can. In theory the backend should
   // be happier because we emit less branchy code to optimize. The backend will
   // lower it down however it wants at some point.
-  if (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE) {
+  if (alignmentBytes &&
+      (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE)) {
     // `llvm.intr.masked.load` only accepts vectors. If we see a scalar we need
     // to bitcast to `vector<1xelemTy>` (and back)
     int64_t vecSize = getNumElements(elemTy);
@@ -203,7 +205,7 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
     Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize);
     bool nt = (cm == triton::CacheModifier::CG);
     Value vecData = rewriter.create<LLVM::MaskedLoadOp>(
-        loc, vecType, ptr, maskVal, falseVal, vecSize, nt);
+        loc, vecType, ptr, maskVal, falseVal, alignmentBytes, nt);
     // If it is not a vector, remember to bitcast back to a scalar
     vecData = bitcast(vecData, elemTy);
     return vecData;
@@ -237,20 +239,20 @@
 }
 
 void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
-             Value pred, triton::CacheModifier cm) {
+             Value pred, int64_t alignmentBytes, triton::CacheModifier cm) {
   // Try to emit llvm.intr.masked.store if we can. In theory the backend should
   // be happier because we emit less branchy code to optimize. The backend will
   // lower it down however it wants at some point.
-  if (cm == triton::CacheModifier::NONE) {
+  if (alignmentBytes && cm == triton::CacheModifier::NONE) {
     // `llvm.intr.masked.store` only accepts vectors. If we see a scalar we need
     // to bitcast to `vector<1xelemTy>`
     Type elemTy = val.getType();
     int64_t vecSize = getNumElements(elemTy);
     Type vecType = castToVectorType(elemTy);
     val = bitcast(val, vecType);
     Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize);
-    auto op =
-        rewriter.create<LLVM::MaskedStoreOp>(loc, val, ptr, maskVal, vecSize);
+    auto op = rewriter.create<LLVM::MaskedStoreOp>(loc, val, ptr, maskVal,
+                                                   alignmentBytes);
     return;
   }
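
Two things to note in this file. First, the builder argument that previously received vecSize sits in the alignment position of MaskedLoadOp / MaskedStoreOp, so the old lowering appears to have stamped the intrinsics with the vector length as their alignment; the commit replaces it with the real byte alignment. Second, the new guard makes zero mean "alignment unknown" and falls back to the predicated path. A simplified restatement of the dispatch, reusing the names from the code above:

// Sketch only: condensed from the guards added in this commit.
bool useMaskedLoad =
    alignmentBytes != 0 &&
    (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE);
bool useMaskedStore = alignmentBytes != 0 && cm == triton::CacheModifier::NONE;
// When the flag is false, llLoad/llStore fall through to the branchy
// predicated lowering later in this file.
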
third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h (2 additions & 2 deletions)
@@ -30,12 +30,12 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
 // Loads from shared or global memory with predication.
 // `otherElems` is used to mask out the elements that are not loaded
 Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
-             Value pred, Value falseVal,
+             Value pred, Value falseVal, int64_t alignmentBytes = 0,
              triton::CacheModifier cm = triton::CacheModifier::NONE);
 
 // Stores to shared or global memory with predication.
 void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
-             Value pred,
+             Value pred, int64_t alignmentBytes = 0,
              triton::CacheModifier cm = triton::CacheModifier::NONE);
 } // namespace mlir::LLVM::AMD
 
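Because alignmentBytes is defaulted to 0, call sites not touched by this commit keep compiling and keep the old (predicated) lowering. A hypothetical pair of callers to illustrate the two behaviors (the alignment value is made up):

// Updated caller: a known 16-byte alignment enables the masked intrinsic.
Value loaded = llLoad(rewriter, loc, ptr, vecTy, pred, falseVal,
                      /*alignmentBytes=*/16, cacheMod);
// Legacy caller: default alignmentBytes = 0 selects the predicated path.
llStore(rewriter, loc, ptr, storeVal, pred);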
