Skip to content

Commit

Permalink
[Global] Check the dtype in pad to intrinsics
Browse files Browse the repository at this point in the history
  • Loading branch information
pashu123 committed Jul 23, 2024
1 parent 446f5b9 commit 836a821
Showing 1 changed file with 12 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,22 @@ static void padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp) {
int64_t nSize = bounds[nDim];
int64_t kSize = bounds[kDim];

auto inpElemType =
cast<ShapedType>(linalgOp.getDpsInputOperand(0)->get().getType())
.getElementType();
auto kernelElemType =
cast<ShapedType>(linalgOp.getDpsInputOperand(1)->get().getType())
.getElementType();

// TODO: Generalize to other dimensions.
// Try to search for pad value and check only filter dimension is blocked.
SmallVector<std::array<int64_t, 3>> mnkPaddingCandidates;
for (const GPUMatmulShapeType &intrinsic : intrinsics) {

if (inpElemType != intrinsic.aType && kernelElemType != intrinsic.bType) {
continue;
}

std::optional<int64_t> mPadding, nPadding, kPadding;
auto getPadding = [](int64_t value, int64_t padTo) {
return llvm::divideCeil(value, padTo) * padTo - value;
Expand Down

0 comments on commit 836a821

Please sign in to comment.