[TOSA] Add logit, log1p, log10 and add promote type to unary fponly ops (#3900)

* Add Torch to TOSA legalization for the following ops:
    - torch.aten.logit
    - torch.aten.log1p
    - torch.aten.log10
* Add input promotion to FP for FP-only TOSA ops like log and exp
* Update xfail_sets with new e2e results
* Add new LIT tests to basic.mlir


Change-Id: I1cd7ec6964373dbaf08d419a806b3d735b830655

Signed-off-by: Justin Ngo <[email protected]>
justin-ngo-arm authored Dec 2, 2024
1 parent 8711d3e commit 92d0f04
Showing 3 changed files with 387 additions and 35 deletions.
224 changes: 200 additions & 24 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -34,10 +34,10 @@ using namespace mlir::torch::Torch;

namespace {

-// These legalizations are for unary ops with only for floating point datatypes.
-// There is no supported quantized integer mode for these.
+// These legalizations are for unary ops with promoting input to floating-point
+// datatypes only. There is no supported quantized integer mode for these.
template <typename AtenOpT, typename TosaOpT>
-class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
+class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
@@ -51,17 +51,22 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");

-    if (isa<mlir::FloatType>(selfTy.getElementType())) {
-      rewriter.replaceOpWithNewOp<TosaOpT>(
-          op,
-          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-              op.getType()),
-          self);
-      return success();
-    } else {
-      return rewriter.notifyMatchFailure(
-          op, "Only floating-point datatype legalization supported");
-    }
+    auto resultTy = dyn_cast<TensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
+
+    if (!isa<mlir::FloatType>(resultTy.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point datatype result types are supported");
+
+    // Non floating point inputs are not supported in TOSA so we cast the input
+    // to result type
+    if (!isa<mlir::FloatType>(selfTy.getElementType()))
+      self = tosa::promoteType(rewriter, self, resultTy);
+
+    rewriter.replaceOpWithNewOp<TosaOpT>(op, resultTy, self);
+
+    return success();
}
};
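For reference, a minimal scalar sketch (hypothetical code, not part of the commit) of what this promote-then-apply lowering computes for an integer input; the real pattern emits a tosa.cast on tensors via tosa::promoteType, followed by the unary op:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical reference: an integer input to an FP-only op such as tosa.log
// is first promoted (cast) to the FP result element type, then the op runs.
float logOfIntRef(int32_t x) {
  float promoted = static_cast<float>(x); // mirrors tosa::promoteType's cast
  return std::log(promoted);              // mirrors the tosa.log on the result
}

int main() {
  std::printf("%f\n", logOfIntRef(8)); // prints ~2.079442 (= ln 8)
  return 0;
}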

@@ -2922,24 +2927,32 @@ template <>
LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
AtenLog2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
+
  // Not a tensor type.
-  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
+  auto selfType = dyn_cast<TensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor types are currently supported");

+  auto outType =
+      dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
+
+  // If input is not a float type then cast it to output type
+  auto selfElemTy = selfType.getElementType();
+  if (!isa<mlir::FloatType>(selfElemTy))
+    self = tosa::promoteType(rewriter, self, outType);
+
  // Constant value of ln2.
  SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
  auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056f},
-                                           ln2Shape, selfType.getElementType())
+                                           ln2Shape, outType.getElementType())
                   .value();

  auto rcpOp =
      rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);

-  auto outType = getTypeConverter()->convertType(op.getType());
-  auto logOp =
-      rewriter.create<tosa::LogOp>(op.getLoc(), outType, adaptor.getSelf());
+  auto logOp = rewriter.create<tosa::LogOp>(op.getLoc(), outType, self);
  rewriter.replaceOpWithNewOp<tosa::MulOp>(op, outType, logOp, rcpOp,
                                           /*shift=*/0);
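A scalar sketch (assumed reference code, not from the diff) of the decomposition this log2 lowering emits, log2(x) = ln(x) * (1 / ln 2), with the ln 2 constant materialized by getConstTensor:

#include <cmath>
#include <cstdio>

float log2Ref(float x) {
  const float ln2 = 0.69314718056f;  // same constant as the lowering
  return std::log(x) * (1.0f / ln2); // tosa.log, tosa.reciprocal, tosa.mul
}

int main() {
  std::printf("%f\n", log2Ref(8.0f)); // prints ~3.0
  return 0;
}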

Expand Down Expand Up @@ -8025,6 +8038,166 @@ class ConvertUpsampleNearest2dForward : public OpConversionPattern<AtenOpT> {
}
};

// Legalization for aten.logit
template <>
LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
AtenLogitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Logit formula:
// result = log(zi / (1 - zi))
// Where: if eps is not None:
//            zi = input clamped to [eps, 1 - eps]
// else:
// zi = input
auto self = adaptor.getSelf();

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultElemTy = resultType.getElementType();

if (!isa<mlir::FloatType>(resultElemTy))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

// If input is not a float type then cast it to result element type
auto selfElemTy = selfType.getElementType();
if (!isa<mlir::FloatType>(selfElemTy))
self = tosa::promoteType(rewriter, self, resultType);

bool isEpsNone = isa<Torch::NoneType>(op.getEps().getType());

double eps;
if (!isEpsNone && !matchPattern(op.getEps(), m_TorchConstantFloat(&eps)))
return rewriter.notifyMatchFailure(op,
"Non-const eps value is not supported");

auto zi = self;

// Clamp input to [eps, 1 - eps] when eps is not None
if (!isEpsNone) {
zi = rewriter
.create<tosa::ClampOp>(
op->getLoc(), resultType, self,
rewriter.getI64IntegerAttr(static_cast<int64_t>(eps)),
rewriter.getI64IntegerAttr(static_cast<int64_t>(1 - eps)),
rewriter.getF32FloatAttr(static_cast<float>(eps)),
rewriter.getF32FloatAttr(static_cast<float>(1 - eps)))
.getResult();
}

auto one =
tosa::getConstTensor<float>(rewriter, op, 1.0f, {}, resultElemTy).value();

auto oneMinusZi =
rewriter.create<tosa::SubOp>(op->getLoc(), resultType, one, zi);

auto oneMinusZiReciprocal = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), resultType, oneMinusZi.getResult());

auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), resultType, zi,
oneMinusZiReciprocal.getResult(),
/*shift=*/0);

auto result =
rewriter.create<tosa::LogOp>(op->getLoc(), resultType, mulOp.getResult());

rewriter.replaceOp(op, {result.getResult()});

return success();
}
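A scalar sketch of the logit decomposition above (hypothetical reference code under the same eps-clamping assumption; needs C++17 for std::clamp):

#include <algorithm>
#include <cmath>
#include <cstdio>

double logitRef(double x, double eps) {
  double zi = std::clamp(x, eps, 1.0 - eps); // tosa.clamp when eps is given
  return std::log(zi / (1.0 - zi));          // sub, reciprocal, mul, log
}

int main() {
  std::printf("%f\n", logitRef(0.25, 1e-6)); // prints ~-1.098612 (= ln(1/3))
  return 0;
}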

// Legalization for aten.log1p
template <>
LogicalResult ConvertAtenOp<AtenLog1pOp>::matchAndRewrite(
AtenLog1pOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// log1p formula:
// yi = log(xi + 1)
auto self = adaptor.getSelf();

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultElemTy = resultType.getElementType();

if (!isa<mlir::FloatType>(resultElemTy))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

// If input is not a float type then cast it to result element type
auto selfElemTy = selfType.getElementType();
if (!isa<mlir::FloatType>(selfElemTy))
self = tosa::promoteType(rewriter, self, resultType);

auto one =
tosa::getConstTensor<float>(rewriter, op, 1.0f, {}, resultElemTy).value();

auto addOp =
rewriter.create<tosa::AddOp>(op->getLoc(), resultType, self, one);

auto result =
rewriter.create<tosa::LogOp>(op->getLoc(), resultType, addOp.getResult());

rewriter.replaceOp(op, {result.getResult()});

return success();
}
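A scalar sketch (assumed) of the log1p decomposition above. Note that this lowering literally computes log(x + 1), which loses accuracy for tiny x compared with a fused library log1p:

#include <cmath>
#include <cstdio>

double log1pRef(double x) {
  return std::log(x + 1.0); // tosa.add with a constant 1, then tosa.log
}

int main() {
  // Compare against the numerically stable library log1p for a tiny input.
  std::printf("%.17g vs %.17g\n", log1pRef(1e-7), std::log1p(1e-7));
  return 0;
}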

// Legalization for aten.log10
template <>
LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
AtenLog10Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// log10 formula (using the change-of-base formula, since TOSA doesn't have a
// builtin log10 op):
// yi = log(xi) / log(10)
auto self = adaptor.getSelf();

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultElemTy = resultType.getElementType();

if (!isa<mlir::FloatType>(resultElemTy))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

// If input is not a float type then cast it to result element type
auto selfElemTy = selfType.getElementType();
if (!isa<mlir::FloatType>(selfElemTy))
self = tosa::promoteType(rewriter, self, resultType);

auto ten = tosa::getConstTensor<float>(rewriter, op, 10.0f, {}, resultElemTy)
.value();

auto logOfSelf = rewriter.create<tosa::LogOp>(op->getLoc(), resultType, self);

auto constType = RankedTensorType::get({}, resultElemTy);

auto logOfTen = rewriter.create<tosa::LogOp>(op->getLoc(), constType, ten);

auto reciprocalOp = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), constType, logOfTen.getResult());

auto result = rewriter.create<tosa::MulOp>(
op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(),
/*shift=*/0);

rewriter.replaceOp(op, {result.getResult()});

return success();
}
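A scalar sketch (assumed) of the change-of-base decomposition this log10 lowering emits, log10(x) = log(x) * reciprocal(log(10)):

#include <cmath>
#include <cstdio>

float log10Ref(float x) {
  return std::log(x) * (1.0f / std::log(10.0f)); // log, log, reciprocal, mul
}

int main() {
  std::printf("%f\n", log10Ref(1000.0f)); // prints ~3.0
  return 0;
}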

} // namespace

// -----------------------------------------------------------------------------
@@ -8069,13 +8242,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {

RewritePatternSet patterns(context);

-#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
+#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, TosaOp) \
  target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
-                                                         context);
-INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
-INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp)
-#undef INSERT_UNARY_FPONLY_PATTERN
+  patterns.add<ConvertAtenUnaryPromoteToFPOp<AtenOp, TosaOp>>(typeConverter, \
+                                                              context);
+INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp)
+INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp)
+#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN

#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
@@ -8364,6 +8537,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp);
INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
INSERT_ATENOP_PATTERN(AtenOuterOp);
+  INSERT_ATENOP_PATTERN(AtenLogitOp);
+  INSERT_ATENOP_PATTERN(AtenLog1pOp);
+  INSERT_ATENOP_PATTERN(AtenLog10Op);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
35 changes: 24 additions & 11 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1729,6 +1729,22 @@
"RandIntPinMemoryModule_basic",
"RenormModuleFloat16_basic",
"SplitDimStaticModule_basic",
"Deg2radModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog10Module_basic",
"ElementwiseLog1pModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMishModule_basic",
"L1LossMeanReductionModule_basic",
"L1LossNoReductionModule_basic",
"L1LossSumReductionModule_basic",
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"SoftplusModule_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
"ReflectionPad1dModule3dInput_Left",
@@ -3416,6 +3432,8 @@
}

FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenFftRfft2DLastDim_basic",
"AtenFftRfft2DMiddleDim_basic",
"IsInfiniteModule_basic",
"LayerNormFwAndBwModule_basic",
"LayerNormManualFwAndBwModule_basic",
@@ -3627,17 +3645,9 @@
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog10Module_basic",
"ElementwiseLog1pModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMishModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
@@ -3755,6 +3765,7 @@
"NumelModule_basic",
"NumelZeroRankModule_basic",
"OnesLikeModule_falsePinMemory",
"PowIntIntModule_basic",
"PowIntFloatModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
@@ -3822,7 +3833,6 @@
"SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceSizeTwoStepModule_basic",
"SoftplusModule_basic",
"SortIntListReverse_basic",
"SortIntList_basic",
"SortTensorDescending_basic",
@@ -3902,6 +3912,11 @@
}

ONNX_TOSA_XFAIL_SET = {
"AtenFftRfft2DLastDim_basic",
"AtenFftRfft2DMiddleDim_basic",
"PowFloatIntModule_basic",
"PowIntFloatModule_basic",
"PowIntIntModule_basic",
"ColumnStack0dModule_basic",
"ColumnStack1dModule_basic",
"ColumnStackBasicIntModule_basic",
@@ -4311,7 +4326,6 @@
"ElementwiseLog2IntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwiseLtDiffWidthScalarModule_basic",
"ElementwiseMishModule_basic",
"ElementwiseMulScalarModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseMulTensorComplexModule_basic",
@@ -4755,7 +4769,6 @@
"SoftmaxIntModule_basic",
"SoftmaxIntNegDimModule_basic",
"SoftmaxIntNonNoneDtypeModule_basic",
"SoftplusModule_basic",
"SortIntListReverse_basic",
"SortIntList_basic",
"SortTensorDescending_basic",
163 changes: 163 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir (new LIT tests; diff not loaded on this page)
