diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 587f9cc41ad8..9fb1a9dcfec7 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5313,9 +5313,10 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [ } def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [ + NoSideEffect, AllowsTypeRefinement, HasValueSemantics, - ReadOnly + ReadOnly, ]> { let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`"; let arguments = (ins diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp index a1b65a42b643..ef084d2dc083 100644 --- a/lib/Conversion/TorchToMhlo/Basic.cpp +++ b/lib/Conversion/TorchToMhlo/Basic.cpp @@ -71,6 +71,25 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { }; } // namespace +// ConvertAtenUnaryConvertOp legalizes general unary ops into Mhlo ConvertOp +namespace { +template +class ConvertAtenUnaryConvertOp: public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::getTypeConverter()->convertType( + op.getType()), + adaptor.self()); + return success(); + } +}; +} // namespace + // aten.ones & aten.zeros // Ref: Error checking based on the Torch to TOSA lowering namespace { @@ -307,6 +326,9 @@ class ConvertAtenCompareOp : public OpConversionPattern { std::is_same()) { compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( op->getContext(), mhlo::ComparisonDirection::GT); + } else if (std::is_same()) { + compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( + op->getContext(), mhlo::ComparisonDirection::GE); } 
else if (std::is_same() || std::is_same()) { compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( @@ -980,6 +1002,75 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } } // namespace +// AtenSizeIntOp +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSizeIntOp op, + OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + // Not a tensor type. + auto selfType = adaptor.self().getType().dyn_cast(); + if (!selfType) + return op.emitError("Only tensor types are currently supported"); + auto dim = rewriter.create( + op.getLoc(), rewriter.getIndexType(), adaptor.dim()); + auto dimSize = rewriter.create( + op.getLoc(), rewriter.getIndexType(), adaptor.self(), dim); + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), dimSize); + + return success(); +} +} // namespace + +// ValsemVariantAtenUniformOp +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + ValsemVariantAtenUniformOp op, + OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + auto inputTy = adaptor.self().getType().template cast(); + auto loc = op.getLoc(); + if (!inputTy) { + op.emitError("input should be ranked tensor type."); + } + auto definingOp = op.self().getDefiningOp(); + auto shape = definingOp->getOperand(0); + SmallVector dimSizes; + getListConstructElements(shape, dimSizes); + std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) { + dSize = rewriter.create(loc, dSize).getResult(); + return dSize; + }); + + auto mhloShape = + rewriter.create(op.getLoc(), dimSizes); + + double fromDoubleValue, toDoubleValue; + if (!matchPattern(op.from(), m_TorchConstantFloat(&fromDoubleValue))) { + op.emitError("operand #1 should be scalar"); + } + if (!matchPattern(op.to(), m_TorchConstantFloat(&toDoubleValue))) { + op.emitError("operand #2 should be scalar"); + } + Value fromTensor = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr(inputTy.getElementType(), fromDoubleValue)); + 
Value toTensor = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr(inputTy.getElementType(), toDoubleValue)); + + auto outType = getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + rewriter.replaceOpWithNewOp( + op, inputTy, fromTensor, toTensor, mhloShape, mhlo::RngDistribution::UNIFORM); + return success(); +} +} void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1005,6 +1096,15 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp); #undef INSERT_UNARY_FPONLY_PATTERN +#define INSERT_UNARY_CONVERT_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, \ + context); + INSERT_UNARY_CONVERT_PATTERN(AtenContiguousOp); + INSERT_UNARY_CONVERT_PATTERN(AtenToDtypeOp); + INSERT_UNARY_CONVERT_PATTERN(AtenTypeAsOp); +#undef INSERT_UNARY_CONVERT_PATTERN + #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ @@ -1038,6 +1138,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp); + INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp); @@ -1063,5 +1164,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenBatchNormOp); INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); + INSERT_ATENOP_PATTERN(AtenSizeIntOp); + INSERT_ATENOP_PATTERN(ValsemVariantAtenUniformOp); #undef INSERT_ATENOP_PATTERN } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index f07ffb19d64e..4a0ff065fc18 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ 
b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1155,6 +1155,47 @@ class DecomposeAtenDropoutOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenNativeDropoutOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeDropoutOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value input = op.input(); + Value prob = op.p(); + bool train = false; + if (!matchPattern(op.train(), m_TorchConstantBool(&train))) + return rewriter.notifyMatchFailure(op, "train must be a boolean constant"); + + BaseTensorType inputType = input.getType().cast(); + if (!train) { + // TODO(yancey.yx): supports inference mode + return op.emitError( + "native_dropout does not support argument train is false"); + } + if (!inputType.hasDtype() || !inputType.getDtype().isa()) + return rewriter.notifyMatchFailure( + op, "only support floating type input for training mode"); + Value noneVal = rewriter.create(loc); + Value floatOne = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value oneMinusP = rewriter.create(loc, floatOne, prob); + Value boolMask = rewriter.create( + loc, inputType, input, oneMinusP, /*generator=*/noneVal); + Value maskedInput = + rewriter.create(loc, inputType, boolMask, input); + Value output = + rewriter.create(loc, inputType, maskedInput, oneMinusP); + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + boolMask = rewriter.create( + loc, op.getResult(1).getType(), boolMask, one); + rewriter.replaceOp(op, {output, boolMask}); + return success(); + } +}; +} // namespace // Decompose aten.var into: aten.var.dim op. 
namespace { class DecomposeAtenVarOp : public OpRewritePattern { @@ -2596,6 +2637,8 @@ class DecomposeComplexOpsPass patterns.add(context); target.addIllegalOp(); patterns.add(context); + patterns.add(context); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); patterns.add(context); diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index c74cc742aaca..1e6285a7efa3 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -139,6 +139,8 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline( TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); pm.addNestedPass(createConvertTorchToMhloPass()); + pm.addNestedPass(createConvertTorchToSCFPass()); + pm.addNestedPass(createConvertTorchToArithPass()); if (options.optimize) { // Clean up any non-canonical code introduced above.. diff --git a/test/Conversion/TorchToMhlo/dropout.mlir b/test/Conversion/TorchToMhlo/dropout.mlir new file mode 100644 index 000000000000..b61a61b3bf83 --- /dev/null +++ b/test/Conversion/TorchToMhlo/dropout.mlir @@ -0,0 +1,47 @@ +// RUN: torch-mlir-opt < %s --torch-function-to-torch-backend-pipeline --torch-backend-to-mhlo-backend-pipeline -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.native_dropout.train( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: f64) -> (tensor, tensor) { +// CHECK: %[[T0:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[CST_0:.*]] = arith.constant 1 : index +// CHECK: %[[CST_1:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f64 +// CHECK: %[[CST_3:.*]] = arith.subf %[[CST_2]], %[[ARG1]] : f64 +// CHECK: %[[T3:.*]] = tensor.from_elements %[[CST_3]] : tensor<1xf64> 
+// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf64>) -> tensor +// CHECK: %[[T5:.*]] = mhlo.convert(%[[ARG0]]) : (tensor) -> tensor +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T5]], %[[CST_1]] : tensor +// CHECK: %[[CST_I64_0:.*]] = arith.index_cast %[[DIM_0]] : index to i64 +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T5]], %[[CST_0]] : tensor +// CHECK: %[[CST_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64 +// CHECK: %[[T6:.*]] = tensor.from_elements %[[CST_I64_0]], %[[CST_I64_1]] : tensor<2xi64> +// CHECK: %[[T7:.*]] = "mhlo.rng"(%[[T2]], %[[T1]], %[[T6]]) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T8:.*]] = shape.shape_of %[[T7]] : tensor -> tensor<2xindex> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T4]], %[[T8]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T10:.*]] = mhlo.compare LT, %[[T7]], %[[T9]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[T11:.*]] = mhlo.convert(%[[T10]]) : (tensor) -> tensor +// CHECK: %[[T12:.*]] = shape.shape_of %[[T11]] : tensor -> tensor<2xindex> +// CHECK: %[[T13:.*]] = shape.shape_of %[[ARG0]] : tensor -> tensor<2xindex> +// CHECK: %[[T14:.*]] = shape.cstr_broadcastable %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> +// CHECK: %[[T15:.*]] = shape.assuming %[[T14]] -> (tensor) { +// CHECK: %[[T16:.*]] = shape.broadcast %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> +// CHECK: %[[T17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T11]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T18:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T19:.*]] = mhlo.multiply %[[T17]], %[[T18]] : tensor +// CHECK: shape.assuming_yield %[[T19]] : tensor +// CHECK: } +// CHECK: %[[T20:.*]] 
= mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32> +// CHECK: %[[T21:.*]] = "mhlo.reshape"(%[[T20]]) : (tensor<1xf32>) -> tensor +// CHECK: %[[T22:.*]] = shape.shape_of %[[T15]] : tensor -> tensor<2xindex> +// CHECK: %[[T23:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T21]], %[[T22]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T24:.*]] = mhlo.multiply %[[T15]], %[[T23]] : tensor +// CHECK: %[[T25:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T12]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T26:.*]] = mhlo.compare GE, %[[T11]], %[[T25]], FLOAT : (tensor, tensor) -> tensor +// CHECK: return %[[T24]], %[[T26]] : tensor, tensor +func.func @torch.aten.native_dropout.train(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>) { + %bool_true = torch.constant.bool true + %result0, %result1 = torch.aten.native_dropout %arg0, %arg1, %bool_true: !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1> + return %result0, %result1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1> +} \ No newline at end of file diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir index db04f201ea19..ce9b6f947fd3 100644 --- a/test/Conversion/TorchToMhlo/view_like.mlir +++ b/test/Conversion/TorchToMhlo/view_like.mlir @@ -360,9 +360,17 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> ! 
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32> // CHECK: %[[INTneg1:.*]] = torch.constant.int -1 // CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[C1_I64:.*]] = torch_c.to_i64 %[[INT1]] // CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int -// CHECK: %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[C2_I64:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[C2_I64]] : i64 to index +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[INDEX_1]] : tensor<2x3x?x?xf32> +// CHECK: %[[DIM_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64 +// CHECK: %[[T1:.*]] = torch_c.from_i64 %[[DIM_I64_1]] +// CHECK: %[[INDEX_2:.*]] = arith.index_cast %[[C1_I64]] : i64 to index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[INDEX_2]] : tensor<2x3x?x?xf32> +// CHECK: %[[DIM_I64_2:.*]] = arith.index_cast %[[DIM_2]] : index to i64 +// CHECK: %[[T2:.*]] = torch_c.from_i64 %[[DIM_I64_2]] // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]] // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]]