Skip to content

Commit

Permalink
[TOSA] Add torch.prim.NumToTensor.Scalar float support
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis authored and AmosLewis committed Jan 26, 2023
1 parent 23aa690 commit dde6e4c
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3459,13 +3459,27 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
// Only supports integer operand type, because for the floating point operand
// type result tensor has to be of type `f64` which is not supported in the
// tosa.
int64_t initValue;
if (!matchPattern(op.getA(), m_TorchConstantInt(&initValue)))
return rewriter.notifyMatchFailure(
op, "unimplemented: input should be a torch constant int");
double doubleValue;
auto isFloat = matchPattern(op.getA(), m_TorchConstantFloat(&doubleValue));
int64_t intValue;
auto isInt = matchPattern(op.getA(), m_TorchConstantInt(&intValue));
if (!isFloat && !isInt)
return rewriter.notifyMatchFailure(op,
"Unable to extract the scalar constant");
float floatValue = static_cast<float>(doubleValue);

auto outElemTy = resultType.getElementType();
if (outElemTy.isa<mlir::IntegerType>() || outElemTy.isF32()) {
DenseElementsAttr constAttr =
isInt ? DenseElementsAttr::get(resultType, {intValue})
: DenseElementsAttr::get(resultType, {floatValue});
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, constAttr);
} else if (outElemTy.isF64()) {
auto resultF32 =
tosa::getConstTensor<float>(rewriter, op, floatValue, {}).value();
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, resultF32);
}

DenseElementsAttr constAttr = DenseElementsAttr::get(resultType, {initValue});
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, constAttr);
return success();
}

Expand Down

0 comments on commit dde6e4c

Please sign in to comment.