Skip to content

Commit

Permalink
quick fix on the storeDistributedToShared
Browse files Browse the repository at this point in the history
  • Loading branch information
jtang10 committed Dec 4, 2024
1 parent 6c7e670 commit 8db3652
Showing 1 changed file with 59 additions and 1 deletion.
60 changes: 59 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,65 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
RewriterBase &rewriter,
const TargetInfoBase &target, bool crossGrain,
std::pair<size_t, Type> *const llvmOpCount) {
bool success = emitTransferBetweenRegistersAndShared(
bool success;
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback;
// callback for every situation except the non-KContig dotOperand
// blocked->shared on AMD platform
perVectorCallback = [&](VectorType vecTy, Value vecAddr) {
ArrayRef<Value> vals = srcVals.take_front(vecTy.getNumElements());
srcVals = srcVals.drop_front(vecTy.getNumElements());

Value vec = undef(vecTy);
for (int i = 0; i < vals.size(); i++) {
vec = insert_element(vec, vals[i], i32_val(i));
}
store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
if (llvmOpCount) {
++(llvmOpCount->first);
llvmOpCount->second = vecTy;
}
};
// This section is only for inThreadTranspose on the AMD path, where we want to
// transpose during the blocked->shared transfer.
// For example, the thread-local register holds a [4, 8] section of matrix,
// where it is contiguous on the dim of 8. We want the perVectorCallback to
// access the column of 4 elements, 8 times, instead of row of 8 elements,
// 4 times like the callback above. For the specific example, the variables
// accessed or derived below will be the following:
// sizePerThread: [4, 8]
// order: [1, 0]
// numElemsPerIter: 4 x 8 = 32
// colIndex: initialized as 0, incremented by 1 every time the callback is
//           called (values 0..7 cover the 8 columns of the first iteration)
// innerVecSize: 8, since it is the vector size of inner dimension
auto blockedEncoding = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
unsigned int colIndex = 0;
if (crossGrain && blockedEncoding) {
auto sizePerThread = blockedEncoding.getSizePerThread();
auto order = blockedEncoding.getOrder();
unsigned int numElemsPerIter = product<unsigned>(sizePerThread);
unsigned int innerVecSize = sizePerThread[order[0]];
perVectorCallback = [&](VectorType vecTy, Value vecAddr) {
Value vec = undef(vecTy);
auto startPos = colIndex / innerVecSize *
numElemsPerIter + // start pos of different iter
colIndex % innerVecSize; // start pos of single iter
for (int i = 0; i < vecTy.getNumElements(); i++) {
auto idx = startPos + i * innerVecSize; // iterate within a vector
vec = insert_element(vec, srcVals[idx], i32_val(i));
}
colIndex++;
store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
if (llvmOpCount) {
++(llvmOpCount->first);
llvmOpCount->second = vecTy;
}
};
}
success = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
dstStrides, loc, rewriter, target, crossGrain, perVectorCallback);

Expand Down

0 comments on commit 8db3652

Please sign in to comment.