Skip to content

Commit

Permalink
[TOSA] Add torch.prim.NumToTensor.Scalar float support
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis authored and AmosLewis committed Jan 18, 2023
1 parent 3f49ba9 commit ab4fdc6
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3350,13 +3350,25 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
// Only supports integer operand type, because for the floating point operand
// type result tensor has to be of type `f64` which is not supported in the
// tosa.
int64_t initValue;
if (!matchPattern(op.getA(), m_TorchConstantInt(&initValue)))
return rewriter.notifyMatchFailure(
op, "unimplemented: input should be a torch constant int");
double doubleValue;
auto isFloat = matchPattern(op.getA(), m_TorchConstantFloat(&doubleValue));
int64_t intValue;
auto isInt = matchPattern(op.getA(), m_TorchConstantInt(&intValue));
if (!isFloat && !isInt)
return rewriter.notifyMatchFailure(op,
"Unable to extract the scalar constant");
float floatValue = static_cast<float>(doubleValue);

auto outElemTy = resultType.getElementType();
if (outElemTy.isa<mlir::IntegerType>() || outElemTy.isF32()) {
DenseElementsAttr constAttr =
isInt ? DenseElementsAttr::get(resultType, {intValue})
: DenseElementsAttr::get(resultType, {floatValue});
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, constAttr);
} else if (outElemTy.isF64()) {
return rewriter.notifyMatchFailure(op, "Float64 is not supported in tosa");
}

DenseElementsAttr constAttr = DenseElementsAttr::get(resultType, {initValue});
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, constAttr);
return success();
}

Expand Down

0 comments on commit ab4fdc6

Please sign in to comment.