From 9dd5a3da0099e777a8cd6a7067c128657a74cbfd Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Fri, 13 Dec 2024 03:28:40 +0530 Subject: [PATCH] Split SIToFP and UIToFP conversion patterns --- .../VM/Conversion/ArithToVM/Patterns.cpp | 103 +++++++++++++----- 1 file changed, 78 insertions(+), 25 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp index 3025462fa9f49..6d902e5496111 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp @@ -526,37 +526,93 @@ struct TruncateIOpConversion : public OpConversionPattern { } }; -template -struct IntToFPOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct SIToFPOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(OpTy srcOp, typename OpTy::Adaptor adaptor, + matchAndRewrite(arith::SIToFPOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcType = srcOp.getIn().getType(); + auto input = srcOp.getIn(); + auto srcType = input.getType(); auto dstType = srcOp.getResult().getType(); - if (!dstType.isF32() || - !(srcType.isSignedInteger() || srcType.isSignlessInteger())) { + auto resultType = getTypeConverter()->convertType(dstType); + + if (!(dstType.isF32() || dstType.isF64())) { return rewriter.notifyMatchFailure(srcOp, "unsupported type"); } - Value input = srcOp.getIn(); - auto resultType = this->getTypeConverter()->convertType(dstType); - if (llvm::isa(srcOp) && - (srcType.isSignlessInteger(64) || srcType.isSignedInteger(64))) { - rewriter.replaceOpWithNewOp(srcOp, resultType, - input); + + if (srcType.isSignedInteger(32) || srcType.isSignlessInteger(32)) { + if (dstType.isF32()) { + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + return success(); + } + if (dstType.isF64()) { + return rewriter.notifyMatchFailure(srcOp, "unsupported type"); + } + } + if (srcType.isSignedInteger(64) || srcType.isSignlessInteger(64)) { + if (dstType.isF32()) { + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + } else { + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + } return success(); } - if (!(srcType.isSignlessInteger(32) || srcType.isSignedInteger(32))) { - if (srcType.getIntOrFloatBitWidth() < 32) { - input = rewriter.create( - srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input); - } else { + if (srcType.getIntOrFloatBitWidth() < 32) { + input = rewriter.create( + srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input); + } + + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + return success(); + } +}; + +struct UIToFPOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::UIToFPOp srcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto input = srcOp.getIn(); + auto srcType = input.getType(); + auto dstType = srcOp.getResult().getType(); + + if (!(dstType.isF32() || dstType.isF64())) { + return rewriter.notifyMatchFailure(srcOp, "unsupported type"); + } + + auto resultType = getTypeConverter()->convertType(dstType); + if (srcType.isUnsignedInteger(32) || srcType.isSignlessInteger(32)) { + if (dstType.isF32()) { + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + return success(); + } + if (dstType.isF64()) { + return rewriter.notifyMatchFailure(srcOp, "unsupported type"); + } + } + if (srcType.isUnsignedInteger(64) || srcType.isSignlessInteger(64)) { + if (dstType.isF32()) { return rewriter.notifyMatchFailure(srcOp, "unsupported type"); } + + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + return success(); + } + + if (srcType.getIntOrFloatBitWidth() < 32) { + input = rewriter.create( + srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input); } - rewriter.replaceOpWithNewOp(srcOp, resultType, input); + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); return success(); } }; @@ -749,12 +805,9 @@ void populateArithToVMPatterns(MLIRContext *context, IREE::VM::MaxF64Op>>(typeConverter, context); // Floating-point conversion ops. - patterns.insert, - IntToFPOpConversion, - FPToSIOpConversion, FPToUIOpConversion, BitcastOpConversion>( - typeConverter, context); + patterns.insert(typeConverter, + context); // Shift ops. patterns