Skip to content

Commit

Permalink
Split SIToFP and UIToFP conversion patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
vinayakdsci committed Dec 12, 2024
1 parent a9f58de commit 9dd5a3d
Showing 1 changed file with 78 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -526,37 +526,93 @@ struct TruncateIOpConversion : public OpConversionPattern<arith::TruncIOp> {
}
};

template <typename OpTy, typename ExtOpTy, typename CastOpTy>
struct IntToFPOpConversion : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
struct SIToFPOpConversion : public OpConversionPattern<arith::SIToFPOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(OpTy srcOp, typename OpTy::Adaptor adaptor,
matchAndRewrite(arith::SIToFPOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = srcOp.getIn().getType();
auto input = srcOp.getIn();
auto srcType = input.getType();
auto dstType = srcOp.getResult().getType();
if (!dstType.isF32() ||
!(srcType.isSignedInteger() || srcType.isSignlessInteger())) {
auto resultType = getTypeConverter()->convertType(dstType);

if (!(dstType.isF32() || dstType.isF64())) {
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
}
Value input = srcOp.getIn();
auto resultType = this->getTypeConverter()->convertType(dstType);
if (llvm::isa<arith::SIToFPOp>(srcOp) &&
(srcType.isSignlessInteger(64) || srcType.isSignedInteger(64))) {
rewriter.replaceOpWithNewOp<IREE::VM::CastSI64F32Op>(srcOp, resultType,
input);

if (srcType.isSignedInteger(32) || srcType.isSignlessInteger(32)) {
if (dstType.isF32()) {
rewriter.replaceOpWithNewOp<IREE::VM::CastSI32F32Op>(srcOp, resultType,
input);
return success();
}
if (dstType.isF64()) {
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
}
}
if (srcType.isSignedInteger(64) || srcType.isSignlessInteger(64)) {
if (dstType.isF32()) {
rewriter.replaceOpWithNewOp<IREE::VM::CastSI64F32Op>(srcOp, resultType,
input);
} else {
rewriter.replaceOpWithNewOp<IREE::VM::CastSI64F64Op>(srcOp, resultType,
input);
}
return success();
}

if (!(srcType.isSignlessInteger(32) || srcType.isSignedInteger(32))) {
if (srcType.getIntOrFloatBitWidth() < 32) {
input = rewriter.create<ExtOpTy>(
srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input);
} else {
if (srcType.getIntOrFloatBitWidth() < 32) {
input = rewriter.create<arith::ExtSIOp>(
srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input);
}

rewriter.replaceOpWithNewOp<IREE::VM::CastSI32F32Op>(srcOp, resultType,
input);
return success();
}
};

struct UIToFPOpConversion : public OpConversionPattern<arith::UIToFPOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::UIToFPOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto input = srcOp.getIn();
auto srcType = input.getType();
auto dstType = srcOp.getResult().getType();

if (!(dstType.isF32() || dstType.isF64())) {
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
}

auto resultType = getTypeConverter()->convertType(dstType);
if (srcType.isUnsignedInteger(32) || srcType.isSignlessInteger(32)) {
if (dstType.isF32()) {
rewriter.replaceOpWithNewOp<IREE::VM::CastUI32F32Op>(srcOp, resultType,
input);
return success();
}
if (dstType.isF64()) {
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
}
}
if (srcType.isUnsignedInteger(64) || srcType.isSignlessInteger(64)) {
if (dstType.isF32()) {
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
}

rewriter.replaceOpWithNewOp<IREE::VM::CastUI64F64Op>(srcOp, resultType,
input);
return success();
}

if (srcType.getIntOrFloatBitWidth() < 32) {
input = rewriter.create<arith::ExtUIOp>(
srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input);
}

rewriter.replaceOpWithNewOp<CastOpTy>(srcOp, resultType, input);
rewriter.replaceOpWithNewOp<IREE::VM::CastUI32F32Op>(srcOp, resultType,
input);
return success();
}
};
Expand Down Expand Up @@ -749,12 +805,9 @@ void populateArithToVMPatterns(MLIRContext *context,
IREE::VM::MaxF64Op>>(typeConverter, context);

// Floating-point conversion ops.
patterns.insert<IntToFPOpConversion<arith::SIToFPOp, arith::ExtSIOp,
IREE::VM::CastSI32F32Op>,
IntToFPOpConversion<arith::UIToFPOp, arith::ExtUIOp,
IREE::VM::CastUI32F32Op>,
FPToSIOpConversion, FPToUIOpConversion, BitcastOpConversion>(
typeConverter, context);
patterns.insert<SIToFPOpConversion, UIToFPOpConversion, FPToSIOpConversion,
FPToUIOpConversion, BitcastOpConversion>(typeConverter,
context);

// Shift ops.
patterns
Expand Down

0 comments on commit 9dd5a3d

Please sign in to comment.