Skip to content

Commit

Permalink
Adjust UpcastMXFPOp op verification
Browse files Browse the repository at this point in the history
  • Loading branch information
antiagainst committed Oct 25, 2024
1 parent 27d403b commit 88d2739
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,26 @@ LogicalResult UpcastMXFPOp::verify() {
 "all dimensions except the last must match between operands");
 }
 
-auto layoutX = xTy.getEncoding();
-if (!layoutX || !isa<DotOperandEncodingAttr>(layoutX)) {
+auto dotEncoding =
+dyn_cast_or_null<DotOperandEncodingAttr>(xTy.getEncoding());
+if (!dotEncoding) {
 return emitOpError("Expected a DotOperandEncodingAttr for values");
 }
-auto layoutScale = scaleTy.getEncoding();
-if (!layoutScale || !isa<BlockedEncodingAttr>(layoutScale)) {
+
+auto blockedScale =
+dyn_cast_or_null<BlockedEncodingAttr>(scaleTy.getEncoding());
+if (!blockedScale) {
 return emitOpError("Expected a BlockOperandEncoding for scales");
 }
-auto blockedScale = cast<BlockedEncodingAttr>(layoutScale);
+
+if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
+// Necessary to keep all of the scales of a given block of values in the
+// same warp
+auto threadsPerWarp = blockedScale.getThreadsPerWarp();
+if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
+return emitOpError("Expected threads per warp to be {16, 2}");
+}
+}
 
 return success();
 }
Expand Down

0 comments on commit 88d2739

Please sign in to comment.