diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 70987a9b0d048..3836f5d2985e3 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3350,13 +3350,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Only supports integer operand type, because for the floating point operand // type result tensor has to be of type `f64` which is not supported in the // tosa. - int64_t initValue; - if (!matchPattern(op.getA(), m_TorchConstantInt(&initValue))) - return rewriter.notifyMatchFailure( - op, "unimplemented: input should be a torch constant int"); + double doubleValue; + auto isFloat = matchPattern(op.getA(), m_TorchConstantFloat(&doubleValue)); + int64_t intValue; + auto isInt = matchPattern(op.getA(), m_TorchConstantInt(&intValue)); + if (!isFloat && !isInt) + return rewriter.notifyMatchFailure(op, + "Unable to extract the scalar constant"); + float floatValue = static_cast(doubleValue); + + auto outElemTy = resultType.getElementType(); + if (outElemTy.isa() || outElemTy.isF32()) { + DenseElementsAttr constAttr = + isInt ? DenseElementsAttr::get(resultType, {intValue}) + : DenseElementsAttr::get(resultType, {floatValue}); + rewriter.replaceOpWithNewOp(op, resultType, constAttr); + } else if (outElemTy.isF64()) { + return rewriter.notifyMatchFailure(op, "Float64 is not supported in tosa"); + } - DenseElementsAttr constAttr = DenseElementsAttr::get(resultType, {initValue}); - rewriter.replaceOpWithNewOp(op, resultType, constAttr); return success(); }